ギブスサンプリング

MikuHatsune2018-05-29

Metropolis-Hastings サンプリングをやったので、ギブスサンプリングをやってみる。
ある変数X=\{x_1,x_2,\dots,x_{i-1},x_i,x_{i+1},\dots,x_m\} について、i 番目を取り除いたX_{-i}
X_i|X_{-i} \sim x_1,x_2,\dots,x_{i-1},~~~x_{i+1},\dots,x_mi 番目が抜けている)
で順次サンプリングして、その値を入れなおしてまたサンプリングする、を1\dots m について行う。
 
結局、あるひとつi 以外のすべてを固定して、ひとつサンプリングすることになる。多次元正規分布の場合は、こちらを参考に
\mu_{i|-i}\leftarrow \mu_{i}+\Sigma_{-i-i}^{-1}(X_{-i}-\mu_{-i})
\Sigma_{i|-i}\leftarrow \Sigma_{ii}-\Sigma_{-i-i}^{-1}\Sigma_{-ii}
とする。ただし\underbrace{\Sigma_{-i-i}^{-1}}_{1\times (m-1)}=\underbrace{\Sigma-ii}_{1\times(m-1)}\underbrace{\Sigma_{-i-i}}_{(m-1)\times(m-1)} である。なので\mu_{i|-i}\Sigma_{i|-i} はそれぞれ1\times 1 の大きさというかスカラーで、各X_i はひとつずつサンプリングされる。
スクリプトここからパクる。
 
初期値が真の分布からは遠いところにしておいて、収束するまでにちょっと時間がかかるようにする。
Metropolis-Hastings(MH)では、真の分布に近づくまでに少し時間がかかる。サンプリングされた分布は、xy 同時にサンプリングしているので、移動方向は斜めである。
Gibbs (GB)は、各X_i でサンプリングしているので、動き方は各x, y 方向にジグザグである。詳細釣り合い条件や移動確率が1であることはここでは取り上げないが、動きやすいのでMH より収束に向かいやすい。
この2次元の例では、一発目から真の分布に取り込まれている様子がわかる。


 
多次元に拡張するが、図示することを考えて3次元にしてみる。MH もGB もひだり下のマゼンタのところからサンプリングを開始したが、HM は最初の50点の時点ではまだ収束していないが、GB は7点目くらいからもう収束している。

MH の移動方向は斜めだが、GB では各点がサンプリングされるまでに、xyz でそれぞれX_i をサンプリングしているので、カクカクになっている。

# 2D
# 初期値
x0 <- c(-5, 2)
iter <- 1500 # 繰り返し回数
x <- matrix(0, iter, 2) # ランダムウォークの座標
dratio <- rep(0, iter) # 確率密度比
out <- rep(0, iter) # 確率密度比にしたがって確率的に動いたか、動かなかったかの記録
x[1,] <- x0

# 2次元正規分布の相関
rho <- 0.7
sig <- matrix(c(1, rho, rho, 1), 2)
mu <- c(4, 5) # 2次元正規分布の各パラメータの平均

# Metrololis-Hastings
for(i in 2:iter){
  walk <- runif(2, -0.5, 0.5) # 歩く量
  v <- x[i-1,] + walk # 次の候補点
  dratio[i-1] <- dmvnorm(v, mean=mu, sigma=sig)/dmvnorm(x[i-1,], mean=mu, sigma=sig) # 確率密度比
  out[i-1] <- rbinom(1, 1, min(1, dratio[i-1])) # 動いたか、動かなかったか
  x[i,] <- x[i-1,] + out[i-1]*walk
}

kd1 <- kde2d(x[,1], x[,2], c(bandwidth.nrd(x[,1]), bandwidth.nrd(x[,2])), n=1000)
cols <- jet.colors(iter)
plot(x, type="p", pch=16, cex=0.6, col=2, xlab="x1", ylab="x2", main="Metropolis-Hastings sampling", xlim=xl, ylim=yl)
for(i in 2:iter){
  segments(x[i-1,1], x[i-1,2], x[i,1], x[i,2], col=cols[i])
}
points(x[1,1], x[1,2], pch="★", col=6)
points(mu[1], mu[2], pch="★", col=5)
contour(kd1, add=TRUE, col=1)

# Gibbs
X <- matrix(x0, nr=1)
for(j in 2:iter){
  x <- X[j-1,]
  for(i in seq(mu)){
    s <- sig[-i, i] %*% solve(sig[-i, -i])   # Σ_ab Σ_bb ^ -1
    # (PRML 2.81) μ_a|b = μ_a + Σ_ab Σ_bb ^ -1 (x_b - μ_b)
    mu_a_b <- mu[i] + s %*% (x[-i] - mu[-i])
    # (PRML 2.82) Σ_a|b = Σ_aa - Σ_ab Σ_bb ^ -1 Σ_ba
    sig_a_b <- sig[i, i] - s %*% sig[i, -i]
   # [Gibbs] x_a 〜 p(x_a|x_{-a}) = N(μ_a|b, Σ_a|b)
    x[i] <- rnorm(1, mu_a_b, sqrt(sig_a_b))
  }
  X <- rbind(X, x)
}
colnames(X) <- sprintf("V%d", seq(ncol(X)))

kd2 <- kde2d(X[,1], X[,2], c(bandwidth.nrd(X[,1]), bandwidth.nrd(X[,2])), n=1000)
cols <- jet.colors(iter)
plot(X, type="p", pch=16, cex=0.6, col=2, xlab="x1", ylab="x2", main="Gibbs sampling", xlim=xl, ylim=yl)
#lines(x, col=cols)
for(i in 2:iter){
  for(j in seq(ncol(X))){
    y <- rbind(replace(X[i-1,], 0:(j-1), X[i, 0:(j-1)]), replace(X[i-1,], 1:j, X[i, 1:j]))
    segments(y[1,1], y[1,2], y[2,1], y[2,2], col=cols[i])
  }
}
points(X[1,1], X[1,2], pch="★", col=6)
points(mu[1], mu[2], pch="★", col=5)
contour(kd2, add=TRUE, col=1)
# 3D
# 初期値
x0 <- c(-5, 2, -4)

iter <- 1500 # 繰り返し回数
x <- matrix(0, iter, length(x0)) # ランダムウォークの座標
dratio <- rep(0, iter) # 確率密度比
out <- rep(0, iter) # 確率密度比にしたがって確率的に動いたか、動かなかったかの記録
x[1,] <- x0
# 3次元の多次元正規分布を適当に作る
rho <- c(0.7, 0.6, 0.9)
sig <- diag(0, length(x0))
sig[lower.tri(sig)] <- rho
sig <- sig + t(sig)
diag(sig) <- rep(1, length(x0))
mu <- c(4, 5, 3)

# Metropolis-Hastings
for(i in 2:iter){
  walk <- runif(length(x0), -0.5, 0.5) # 歩く量
  v <- x[i-1,] + walk # 次の候補点
  dratio[i-1] <- dmvnorm(v, mean=mu, sigma=sig)/dmvnorm(x[i-1,], mean=mu, sigma=sig) # 確率密度比
  out[i-1] <- rbinom(1, 1, min(1, dratio[i-1])) # 動いたか、動かなかったか
  x[i,] <- x[i-1,] + out[i-1]*walk
}
colnames(x) <- sprintf("V%d", seq(ncol(x)))

cols <- jet.colors(iter)
plot3d(x, type="p", pch=16, size=0.01)
points3d(head(x, 50), size=5)
title3d("Metropolis-Hastings sampling", line=3)
for(i in 2:iter){
  segments3d(x[(i-1):i,], col=cols[i])
}
spheres3d(x[1,], col=6, radius=0.5)

# Gibbs
X <- matrix(x0, nr=1)
for(j in 2:iter){
  x <- X[j-1,]
  for(i in seq(mu)){
    s <- sig[-i, i] %*% solve(sig[-i, -i])   # Σ_ab Σ_bb ^ -1
    # (PRML 2.81) μ_a|b = μ_a + Σ_ab Σ_bb ^ -1 (x_b - μ_b)
    mu_a_b <- mu[i] + s %*% (x[-i] - mu[-i])
    # (PRML 2.82) Σ_a|b = Σ_aa - Σ_ab Σ_bb ^ -1 Σ_ba
    sig_a_b <- sig[i, i] - s %*% sig[i, -i]
   # [Gibbs] x_a 〜 p(x_a|x_{-a}) = N(μ_a|b, Σ_a|b)
    x[i] <- rnorm(1, mu_a_b, sqrt(sig_a_b))
  }
  X <- rbind(X, x)
}
colnames(X) <- sprintf("V%d", seq(ncol(X)))
# ギザギザの描画
Y <- X[1, , drop=FALSE]
for(i in 2:nrow(X)){
  for(j in 1:ncol(X)){
    y <- rbind(replace(X[i-1,], 0:(j-1), X[i, 0:(j-1)]), replace(X[i-1,], 1:j, X[i, 1:j]))
    Y <- rbind(Y, y)
  }
}

cols <- jet.colors(iter)
plot3d(X, size=0.01)
points3d(head(X, 50), size=5)
title3d("Gibbs sampling", line=3)
lines3d(Y, col=rep(cols, each=3))
spheres3d(x[1,], col=6, radius=0.5)