ギブスサンプリングによるベイズ推定

はじめに

MCMCによるベイズ推定として、正規分布に従うデータが与えられたとき、その正規分布のパラメータ(平均と分散)が従う分布および推定値を求める。

尤度関数が正規分布の場合、共役事前分布はそれぞれ、平均は正規分布、分散は逆ガンマ分布になるので、ギブスサンプリングを使うことができ、これでパラメータの推定する。

コード

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
static const double PI = 3.14159265358979323846264338;

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); 
}
//ガンマ乱数
double gamma_rand(double shape, double scale){
  double n, b1, b2, c1, c2;
  if(4.0 < shape) n = 1.0/sqrt(shape);
  else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6;
  else if(0.0 < shape) n = 1.0/shape;
  else return -1;

  b1 = shape - 1.0/n;
  b2 = shape + 1.0/n;

  if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0;
  else c1 = 0;
  c2 = b2 * (log(b2)-1.0) / 2.0;

  while(true){
    double v1 = frand(), v2 = frand();
    double w1, w2, y, x;
    w1 = c1 + log(v1);
    w2 = c2 + log(v2);
    y = n * (b1*w2-b2*w1);
    if(y < 0) continue;
    x = n * (w2-w1);
    if(log(y) < x) continue;
    return exp(x) * scale;
  }
  return -1;
}
//逆ガンマ乱数
double invgamma_rand(double shape, double scale){
  return 1.0/gamma_rand(shape, scale);
}
//正規乱数
double normal_rand(double mu, double sigma2){
  double sigma = sqrt(sigma2);
  double u1 = frand(), u2 = frand();
  double z1 = sqrt(-2*log(u1)) * cos(2*PI*u2);
  //double z2 = sqrt(-2*log(u1)) * sin(2*PI*u2);
  return mu + sigma*z1;
}




//正規分布に従うデータに対するギブスサンプリング
class Gibbs_Normal {
  double mu, sigma2; //未知パラメータの平均muと分散sigma2
  double n, x_bar; //データ数n、データの平均値x_bar
  //事後分布を見やすくするための自由度を与える変数など
  double m_0, n_0, S_0;
  double m_1, n_1, mu_1, nS_1; 

public:
  Gibbs_Normal(double mu_0, double sigma2_0, double alpha_0, double lambda_0){
    mu = mu_0;
    m_0 = 1.0;
    sigma2 = sigma2_0 / m_0;
    n_0 = alpha_0 * 2.0;
    S_0 = lambda_0 * 2.0 / n_0;
  }
  
  void set_data(const std::vector<double>& data){
    n = data.size();
    x_bar = 0.0;
    for(int i=0; i<n; i++) x_bar += data[i];
    x_bar /= n;
    m_1 = m_0 + n;
    n_1 = n_0 + n;
    mu_1 = (m_0 * mu + n * x_bar) / (m_0 + n);
    nS_1 = n_0 * S_0;
    for(int i=0; i<n; i++) nS_1 += (data[i] - x_bar) * (data[i] - x_bar);
    nS_1 += (m_0 * n) / (m_0 + n) * (x_bar - mu) * (x_bar - mu);

    //std::cerr << "m_0 : " << m_0 << std::endl;
    //std::cerr << "m_1 : " << m_1 << std::endl;
    //std::cerr << "n_1 : " << n_1 << std::endl;
    //std::cerr << "n_1*S_1 : " << nS_1 << std::endl;
    //std::cerr << "mu_1 : " << mu_1 << std::endl;
    //std::cerr << "x_bar : " << x_bar << std::endl;
  }

  void sampling(){
    mu = normal_rand(mu_1, sigma2/m_1);
    sigma2 = invgamma_rand((n_1+1)/2.0, 2.0/(nS_1+m_1*(mu-mu_1)*(mu-mu_1)));
                                                                  //精度を渡したいので逆数にする
  }

  double get_mu_mean(){ return mu; }
  double get_mu_var(){ return sigma2/m_1; }
  double get_sigma2_shape(){ return (n_1+1)/2.0; }
  double get_sigma2_scale(){ return (nS_1+m_1*(mu-mu_1)*(mu-mu_1))/2.0; }
};


int main(){

  Gibbs_Normal gn(0, 1000, 0.001, 0.001); //事前分布のパラメータ

  //データ生成(適当に平均5,分散4の正規分布に従うデータを1000個作成)
  std::vector<double> data;
  for(int i=0; i<1000; i++){
    double r = normal_rand(5, 4);
    data.push_back(r);
  }
  
  //生成したデータの詳細
  double ave_data = 0.0, sigma2_data = 0.0;
  for(int i=0; i<data.size(); i++){
    ave_data += data[i];
  }
  ave_data /= data.size();
  for(int i=0; i<data.size(); i++){
    sigma2_data += (data[i]-ave_data)*(data[i]-ave_data);
  }
  sigma2_data /= data.size();
  std::cerr << "data ave: " << ave_data << std::endl;
  std::cerr << "data sigma2: " << sigma2_data << std::endl;

  //ギブスサンプリング
  int cnt = 0;
  double ave_mu_mean = 0.0; //平均の事後分布の平均の平均値
  double ave_mu_var = 0.0; //平均の事後分布の分散の平均値
  double ave_sigma2_shape = 0.0; //分散の事後分布のshapeの平均値
  double ave_sigma2_scale = 0.0; //分散の事後分布のscaleの平均値

  int iter_N = 1000000; //イテレーション回数
  int burnin = 300000; //バーンイン期間

  gn.set_data(data); //データをセット

  for(int t=0; t<iter_N; t++){
    gn.sampling(); //サンプリング
    if(t >= burnin){
      ave_mu_mean += gn.get_mu_mean();
      ave_mu_var += gn.get_mu_var();
      ave_sigma2_shape += gn.get_sigma2_shape();
      ave_sigma2_scale += gn.get_sigma2_scale();
      cnt++;
    }
  }

  ave_mu_mean /= cnt;
  ave_mu_var /= cnt;
  ave_sigma2_shape /= cnt;
  ave_sigma2_scale /= cnt;

  //事後分布
  std::cerr << "-------" << std::endl;
  std::cerr << "N(" << ave_mu_mean << "," << ave_mu_var << ")" << std::endl;
  std::cerr << "IG(" << ave_sigma2_shape << "," << ave_sigma2_scale << ")" << std::endl;

  //事後分布の平均値(期待値)
  std::cerr << "-------" << std::endl;
  std::cerr << "mean : " << ave_mu_mean << std::endl;  
  std::cerr << "var : " << ave_sigma2_scale/(ave_sigma2_shape-1) << std::endl;

  return 0;
}

結果

#データの平均分散
data ave: 5.01414
data sigma2: 4.04538
-------
#パラメータの事後分布
N(5.00913,0.00407451)
IG(500.501,2037.29)
-------
#事後分布の代表値(平均値)
mean : 5.00913
var : 4.07864

参考文献