Metropolis-Hastings サンプリングをやったので、ギブスサンプリングをやってみる。
ある変数 について、 番目を取り除いた を
( 番目が抜けている)
で順次サンプリングして、その値を入れなおしてまたサンプリングする、を について行う。
結局、あるひとつ 以外のすべてを固定して、ひとつサンプリングすることになる。多次元正規分布の場合は、こちらを参考に
とする。ただし である。なので と はそれぞれ の大きさというかスカラーで、各 はひとつずつサンプリングされる。
スクリプトはここからパクる。
初期値が真の分布からは遠いところにしておいて、収束するまでにちょっと時間がかかるようにする。
Metropolis-Hastings(MH)では、真の分布に近づくまでに少し時間がかかる。サンプリングされた分布は、xy 同時にサンプリングしているので、移動方向は斜めである。
Gibbs (GB)は、各 でサンプリングしているので、動き方は各x, y 方向にジグザグである。詳細釣り合い条件や移動確率が1であることはここでは取り上げないが、動きやすいのでMH より収束に向かいやすい。
この2次元の例では、一発目から真の分布に取り込まれている様子がわかる。
多次元に拡張するが、図示することを考えて3次元にしてみる。MH もGB もひだり下のマゼンタのところからサンプリングを開始したが、HM は最初の50点の時点ではまだ収束していないが、GB は7点目くらいからもう収束している。
MH の移動方向は斜めだが、GB では各点がサンプリングされるまでに、xyz でそれぞれ をサンプリングしているので、カクカクになっている。
# 2D # 初期値 x0 <- c(-5, 2) iter <- 1500 # 繰り返し回数 x <- matrix(0, iter, 2) # ランダムウォークの座標 dratio <- rep(0, iter) # 確率密度比 out <- rep(0, iter) # 確率密度比にしたがって確率的に動いたか、動かなかったかの記録 x[1,] <- x0 # 2次元正規分布の相関 rho <- 0.7 sig <- matrix(c(1, rho, rho, 1), 2) mu <- c(4, 5) # 2次元正規分布の各パラメータの平均 # Metrololis-Hastings for(i in 2:iter){ walk <- runif(2, -0.5, 0.5) # 歩く量 v <- x[i-1,] + walk # 次の候補点 dratio[i-1] <- dmvnorm(v, mean=mu, sigma=sig)/dmvnorm(x[i-1,], mean=mu, sigma=sig) # 確率密度比 out[i-1] <- rbinom(1, 1, min(1, dratio[i-1])) # 動いたか、動かなかったか x[i,] <- x[i-1,] + out[i-1]*walk } kd1 <- kde2d(x[,1], x[,2], c(bandwidth.nrd(x[,1]), bandwidth.nrd(x[,2])), n=1000) cols <- jet.colors(iter) plot(x, type="p", pch=16, cex=0.6, col=2, xlab="x1", ylab="x2", main="Metropolis-Hastings sampling", xlim=xl, ylim=yl) for(i in 2:iter){ segments(x[i-1,1], x[i-1,2], x[i,1], x[i,2], col=cols[i]) } points(x[1,1], x[1,2], pch="★", col=6) points(mu[1], mu[2], pch="★", col=5) contour(kd1, add=TRUE, col=1) # Gibbs X <- matrix(x0, nr=1) for(j in 2:iter){ x <- X[j-1,] for(i in seq(mu)){ s <- sig[-i, i] %*% solve(sig[-i, -i]) # Σ_ab Σ_bb ^ -1 # (PRML 2.81) μ_a|b = μ_a + Σ_ab Σ_bb ^ -1 (x_b - μ_b) mu_a_b <- mu[i] + s %*% (x[-i] - mu[-i]) # (PRML 2.82) Σ_a|b = Σ_aa - Σ_ab Σ_bb ^ -1 Σ_ba sig_a_b <- sig[i, i] - s %*% sig[i, -i] # [Gibbs] x_a 〜 p(x_a|x_{-a}) = N(μ_a|b, Σ_a|b) x[i] <- rnorm(1, mu_a_b, sqrt(sig_a_b)) } X <- rbind(X, x) } colnames(X) <- sprintf("V%d", seq(ncol(X))) kd2 <- kde2d(X[,1], X[,2], c(bandwidth.nrd(X[,1]), bandwidth.nrd(X[,2])), n=1000) cols <- jet.colors(iter) plot(X, type="p", pch=16, cex=0.6, col=2, xlab="x1", ylab="x2", main="Gibbs sampling", xlim=xl, ylim=yl) #lines(x, col=cols) for(i in 2:iter){ for(j in seq(ncol(X))){ y <- rbind(replace(X[i-1,], 0:(j-1), X[i, 0:(j-1)]), replace(X[i-1,], 1:j, X[i, 1:j])) segments(y[1,1], y[1,2], y[2,1], y[2,2], col=cols[i]) } } points(X[1,1], X[1,2], pch="★", col=6) points(mu[1], mu[2], pch="★", col=5) contour(kd2, add=TRUE, col=1)
# 3D # 初期値 x0 <- c(-5, 2, -4) iter <- 1500 # 繰り返し回数 x <- matrix(0, iter, length(x0)) # ランダムウォークの座標 dratio <- rep(0, iter) # 確率密度比 out <- rep(0, iter) # 確率密度比にしたがって確率的に動いたか、動かなかったかの記録 x[1,] <- x0 # 3次元の多次元正規分布を適当に作る rho <- c(0.7, 0.6, 0.9) sig <- diag(0, length(x0)) sig[lower.tri(sig)] <- rho sig <- sig + t(sig) diag(sig) <- rep(1, length(x0)) mu <- c(4, 5, 3) # Metropolis-Hastings for(i in 2:iter){ walk <- runif(length(x0), -0.5, 0.5) # 歩く量 v <- x[i-1,] + walk # 次の候補点 dratio[i-1] <- dmvnorm(v, mean=mu, sigma=sig)/dmvnorm(x[i-1,], mean=mu, sigma=sig) # 確率密度比 out[i-1] <- rbinom(1, 1, min(1, dratio[i-1])) # 動いたか、動かなかったか x[i,] <- x[i-1,] + out[i-1]*walk } colnames(x) <- sprintf("V%d", seq(ncol(x))) cols <- jet.colors(iter) plot3d(x, type="p", pch=16, size=0.01) points3d(head(x, 50), size=5) title3d("Metropolis-Hastings sampling", line=3) for(i in 2:iter){ segments3d(x[(i-1):i,], col=cols[i]) } spheres3d(x[1,], col=6, radius=0.5) # Gibbs X <- matrix(x0, nr=1) for(j in 2:iter){ x <- X[j-1,] for(i in seq(mu)){ s <- sig[-i, i] %*% solve(sig[-i, -i]) # Σ_ab Σ_bb ^ -1 # (PRML 2.81) μ_a|b = μ_a + Σ_ab Σ_bb ^ -1 (x_b - μ_b) mu_a_b <- mu[i] + s %*% (x[-i] - mu[-i]) # (PRML 2.82) Σ_a|b = Σ_aa - Σ_ab Σ_bb ^ -1 Σ_ba sig_a_b <- sig[i, i] - s %*% sig[i, -i] # [Gibbs] x_a 〜 p(x_a|x_{-a}) = N(μ_a|b, Σ_a|b) x[i] <- rnorm(1, mu_a_b, sqrt(sig_a_b)) } X <- rbind(X, x) } colnames(X) <- sprintf("V%d", seq(ncol(X))) # ギザギザの描画 Y <- X[1, , drop=FALSE] for(i in 2:nrow(X)){ for(j in 1:ncol(X)){ y <- rbind(replace(X[i-1,], 0:(j-1), X[i, 0:(j-1)]), replace(X[i-1,], 1:j, X[i, 1:j])) Y <- rbind(Y, y) } } cols <- jet.colors(iter) plot3d(X, size=0.01) points3d(head(X, 50), size=5) title3d("Gibbs sampling", line=3) lines3d(Y, col=rep(cols, each=3)) spheres3d(x[1,], col=6, radius=0.5)