ギブスサンプリングによるベイズ推定
はじめに
MCMCによるベイズ推定として、正規分布に従うデータが与えられたとき、その正規分布のパラメータ(平均と分散)が従う分布および推定値を求める。
尤度関数が正規分布の場合、共役事前分布はそれぞれ、平均は正規分布、分散は逆ガンマ分布になるので、ギブスサンプリングを使うことができ、これでパラメータの推定する。
コード
#include <iostream> #include <vector> #include <algorithm> #include <cmath> static const double PI = 3.14159265358979323846264338; //xorshift // 注意: longではなくint(32bit)にすべき unsigned long xor128(){ static unsigned long x=123456789, y=362436069, z=521288629, w=88675123; unsigned long t; t=(x^(x<<11)); x=y; y=z; z=w; return w=(w^(w>>19))^(t^(t>>8)); } //[0,1)の一様乱数 // 注意: int_maxぐらいで割るべき double frand(){ return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); } //ガンマ乱数 double gamma_rand(double shape, double scale){ double n, b1, b2, c1, c2; if(4.0 < shape) n = 1.0/sqrt(shape); else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6; else if(0.0 < shape) n = 1.0/shape; else return -1; b1 = shape - 1.0/n; b2 = shape + 1.0/n; if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0; else c1 = 0; c2 = b2 * (log(b2)-1.0) / 2.0; while(true){ double v1 = frand(), v2 = frand(); double w1, w2, y, x; w1 = c1 + log(v1); w2 = c2 + log(v2); y = n * (b1*w2-b2*w1); if(y < 0) continue; x = n * (w2-w1); if(log(y) < x) continue; return exp(x) * scale; } return -1; } //逆ガンマ乱数 double invgamma_rand(double shape, double scale){ return 1.0/gamma_rand(shape, scale); } //正規乱数 double normal_rand(double mu, double sigma2){ double sigma = sqrt(sigma2); double u1 = frand(), u2 = frand(); double z1 = sqrt(-2*log(u1)) * cos(2*PI*u2); //double z2 = sqrt(-2*log(u1)) * sin(2*PI*u2); return mu + sigma*z1; } //正規分布に従うデータに対するギブスサンプリング class Gibbs_Normal { double mu, sigma2; //未知パラメータの平均muと分散sigma2 double n, x_bar; //データ数n、データの平均値x_bar //事後分布を見やすくするための自由度を与える変数など double m_0, n_0, S_0; double m_1, n_1, mu_1, nS_1; public: Gibbs_Normal(double mu_0, double sigma2_0, double alpha_0, double lambda_0){ mu = mu_0; m_0 = 1.0; sigma2 = sigma2_0 / m_0; n_0 = alpha_0 * 2.0; S_0 = lambda_0 * 2.0 / n_0; } void set_data(const std::vector<double>& data){ n = data.size(); x_bar = 0.0; for(int i=0; i<n; i++) x_bar += data[i]; x_bar /= n; m_1 = m_0 + n; n_1 = n_0 + n; mu_1 = (m_0 * mu + n * x_bar) / (m_0 + n); nS_1 = n_0 * S_0; for(int i=0; i<n; i++) nS_1 += (data[i] - x_bar) * (data[i] - x_bar); nS_1 += (m_0 * n) / (m_0 + n) * (x_bar - mu) * (x_bar - mu); //std::cerr << "m_0 : " << m_0 << std::endl; //std::cerr << "m_1 : " << m_1 << std::endl; //std::cerr << "n_1 : " << n_1 << std::endl; //std::cerr << "n_1*S_1 : " << nS_1 << std::endl; //std::cerr << "mu_1 : " << mu_1 << std::endl; //std::cerr << "x_bar : " << x_bar << std::endl; } void sampling(){ mu = normal_rand(mu_1, sigma2/m_1); sigma2 = invgamma_rand((n_1+1)/2.0, 2.0/(nS_1+m_1*(mu-mu_1)*(mu-mu_1))); //精度を渡したいので逆数にする } double get_mu_mean(){ return mu; } double get_mu_var(){ return sigma2/m_1; } double get_sigma2_shape(){ return (n_1+1)/2.0; } double get_sigma2_scale(){ return (nS_1+m_1*(mu-mu_1)*(mu-mu_1))/2.0; } }; int main(){ Gibbs_Normal gn(0, 1000, 0.001, 0.001); //事前分布のパラメータ //データ生成(適当に平均5,分散4の正規分布に従うデータを1000個作成) std::vector<double> data; for(int i=0; i<1000; i++){ double r = normal_rand(5, 4); data.push_back(r); } //生成したデータの詳細 double ave_data = 0.0, sigma2_data = 0.0; for(int i=0; i<data.size(); i++){ ave_data += data[i]; } ave_data /= data.size(); for(int i=0; i<data.size(); i++){ sigma2_data += (data[i]-ave_data)*(data[i]-ave_data); } sigma2_data /= data.size(); std::cerr << "data ave: " << ave_data << std::endl; std::cerr << "data sigma2: " << sigma2_data << std::endl; //ギブスサンプリング int cnt = 0; double ave_mu_mean = 0.0; //平均の事後分布の平均の平均値 double ave_mu_var = 0.0; //平均の事後分布の分散の平均値 double ave_sigma2_shape = 0.0; //分散の事後分布のshapeの平均値 double ave_sigma2_scale = 0.0; //分散の事後分布のscaleの平均値 int iter_N = 1000000; //イテレーション回数 int burnin = 300000; //バーンイン期間 gn.set_data(data); //データをセット for(int t=0; t<iter_N; t++){ gn.sampling(); //サンプリング if(t >= burnin){ ave_mu_mean += gn.get_mu_mean(); ave_mu_var += gn.get_mu_var(); ave_sigma2_shape += gn.get_sigma2_shape(); ave_sigma2_scale += gn.get_sigma2_scale(); cnt++; } } ave_mu_mean /= cnt; ave_mu_var /= cnt; ave_sigma2_shape /= cnt; ave_sigma2_scale /= cnt; //事後分布 std::cerr << "-------" << std::endl; std::cerr << "N(" << ave_mu_mean << "," << ave_mu_var << ")" << std::endl; std::cerr << "IG(" << ave_sigma2_shape << "," << ave_sigma2_scale << ")" << std::endl; //事後分布の平均値(期待値) std::cerr << "-------" << std::endl; std::cerr << "mean : " << ave_mu_mean << std::endl; std::cerr << "var : " << ave_sigma2_scale/(ave_sigma2_shape-1) << std::endl; return 0; }
結果
#データの平均分散 data ave: 5.01414 data sigma2: 4.04538 ------- #パラメータの事後分布 N(5.00913,0.00407451) IG(500.501,2037.29) ------- #事後分布の代表値(平均値) mean : 5.00913 var : 4.07864
参考文献
- 涌井「道具としてのベイズ統計」
- http://d.hatena.ne.jp/yokoken00/20100128