ノンパラベイズな言語モデルを試す

はじめに

最近「言語モデル」がマイブームなので、最近有名になりつつあるというMCMC法を使ったベイズ言語モデルとして、「階層的Pitman-Yor言語モデル(HPYLM)」を試しにちょっと作ってみた。

とりあえず、文字bigramのHPYLMを試してみる。
毎度のことながら勉強用、実験用なのででかいデータはちょっとまずいと思う。。。

(追記3/17 22:30)
コードで最終的な値に使う所がおかしいのを修正とdとθの収束についてを追加

コード

#include <iostream>
#include <sstream>
#include <fstream>
#include <vector>
#include <deque>
#include <map>
#include <algorithm>
#include <string>
#include <cmath>
#include <climits>

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); 
}
//Bernoulli試行(確率pで1、1-pで0を返す)
double bernoulli_rand(double p){
  double r = frand();
  if(r<p) return 1.0;
  return 0.0;
}
//gamma分布に従う乱数
double gamma_rand(double shape, double scale){
  double n, b1, b2, c1, c2;
  if(4.0 < shape) n = 1.0/sqrt(shape);
  else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6;
  else if(0.0 < shape) n = 1.0/shape;
  else return -1;

  b1 = shape - 1.0/n;
  b2 = shape + 1.0/n;

  if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0;
  else c1 = 0;
  c2 = b2 * (log(b2)-1.0) / 2.0;

  while(true){
    double v1 = frand(), v2 = frand();
    double w1, w2, y, x;
    w1 = c1 + log(v1);
    w2 = c2 + log(v2);
    y = n * (b1*w2-b2*w1);
    if(y < 0) continue;
    x = n * (w2-w1);
    if(log(y) < x) continue;
    return exp(x) * scale;
  }
  return -1;
}
//beta分布に従う乱数
double beta_rand(double a, double b){
  double gamma1 = gamma_rand(a,1.0);
  double gamma2 = gamma_rand(b,1.0);
  return gamma1/(gamma1+gamma2);
}
//離散確率でインデックスを返す
int selectProb(const std::vector< std::pair<int,double> > &p){
  double sum = 0.0;
  for(int i=0; i<(int)p.size(); i++) sum += p[i].second;
  double r = frand()*sum, q = 0;
  for(int i=0; i<(int)p.size()-1; i++){
    q += p[i].second;
    if(r<q) return p[i].first;
  }
  return p[p.size()-1].first;
}


//各ngramのレストラン
class Restaurant {
  bool base_flag; //特殊なレストラン(u="")
  double theta, d; //parameters
public:
  std::deque< std::pair<std::string,int> > tables; //各テーブル(料理と人数)

  Restaurant(bool base_flag=false):base_flag(base_flag){
    d = 0.6; //適当
    theta = 1; //適当
  }

  int c_uwd(const std::string &w){
    int cnt = 0;
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){
	cnt += tables[i].second;
      }
    }
    return cnt;
  }
  double d_u(){ return d; }
  int t_uw(const std::string &w){ 
    int cnt = 0;
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){
	cnt++;
      }
    }
    return cnt;
  }
  double theta_u(){ return theta; }
  int c_udd(){
    int cnt = 0;
    for(int i=0; i<(int)tables.size(); i++) cnt += tables[i].second;
    return cnt;
  }
  int t_ud(){ return tables.size(); }
  
  //AddCustomer
  bool AddCustomer(const std::string &w, const double &pw){
    std::vector< std::pair<int,double> > p; //wを持つ各テーブルのインデクスと人数
    //既存のテーブル
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){
	p.push_back(std::make_pair(i,std::max(0.0,tables[i].second - d_u())));
      }
    }
    //未知のテーブル
    p.push_back(std::make_pair(tables.size(),(theta_u() + d_u() * t_ud()) * pw));
    int k = selectProb(p);
    if(k < (int)tables.size()){ //既知のテーブルに座る
      tables[k].second++;
    }else{ //未知のテーブルに座る
      tables.push_back(std::make_pair(w,1));
      if(base_flag) return false;
      return true;
    }
    return false;
  }
  //RemoveCustomer
  bool RemoveCustomer(const std::string &w){
    std::vector< std::pair<int,double> > p; //wを持つ各テーブルのインデクスと人数
    for(int i=0; i<(int)tables.size(); i++){
      if(tables[i].first == w){ 
	p.push_back(std::make_pair(i,tables[i].second));
      }
    }
    int k = selectProb(p);
    tables[k].second--;
    if(tables[k].second == 0){
      tables.erase(tables.begin() + k);
      if(base_flag) return false;
      return true;
    }
    return false;
  }
  //パラメータの更新
  void UpdateParameters(double d_, double theta_){
    d = d_;
    theta = theta_;
  }
};

//階層Pitman-Yor言語モデル
class HPYLM {
  int n_val; //n-gram modelのnの値
  int num_word_type; //文字種の数
  std::map<std::deque<std::string>,Restaurant> ngrams; //各レストラン
  std::vector<double> a_m, b_m, alpha_m, beta_m, d_m, theta_m; //各ngramで共通のハイパーパラメータ

  //ハイパーパラメータの履歴を保存しておく用
  std::vector< std::vector<double> > history_d_m, history_theta_m;

  //基底測度G_0(w)
  double WordProbBase(const std::string &w){
    return 1.0/num_word_type;
  }
public:
  //ngramのnの値, G0用の文字種の数
  HPYLM(int n, int nwt){
    n_val = n;
    num_word_type = nwt;
    for(int i=0; i<n_val; i++){
      a_m.push_back(1.0);
      b_m.push_back(1.0);
      alpha_m.push_back(1.0);
      beta_m.push_back(1.0);
      d_m.push_back(0.6); //適当
      theta_m.push_back(0.1); //適当

      history_d_m.push_back(std::vector<double>());
      history_theta_m.push_back(std::vector<double>());
    }
    ngrams.insert(std::make_pair(std::deque<std::string>(), Restaurant(true)));
  }
  void show_ngrams(){
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      std::cout << "Restaurant[";
      for(int i=0; i<(int)(itr->first.size()); i++){
	std::cout << itr->first[i];
	if(i!=(int)(itr->first.size()-1)) std::cout << "|";
      }
      std::cout << "]" << "(d=" << itr->second.d_u() << ",theta=" << itr->second.theta_u() << ")" << std::endl;
      for(int i=0; i<(int)(itr->second.tables.size()); i++){
	std::cout << itr->second.tables[i].first << "\t" << itr->second.tables[i].second << std::endl;
      }
    }
  }
  //単語の出現確率
  double WordProbability(std::deque<std::string> u, const std::string &w){
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    std::pair<std::map<std::deque<std::string>,Restaurant>::iterator,bool> res;
    res = ngrams.insert(std::make_pair(u,Restaurant()));
    itr = res.first;
    double c_uwd = itr->second.c_uwd(w);
    double d_u = itr->second.d_u();
    double t_uw = itr->second.t_uw(w);
    double theta_u = itr->second.theta_u();
    double c_udd = itr->second.c_udd();
    double t_ud = itr->second.t_ud();
    if(u.size() == 0) return ((c_uwd - d_u * t_uw) + (theta_u + d_u * t_ud) * WordProbBase(w))/(theta_u + c_udd);
    u.pop_front();
    return ((c_uwd - d_u * t_uw) + (theta_u + d_u * t_ud) * WordProbability(u,w))/(theta_u + c_udd);
  }
  //単語の追加
  void AddCustomer(std::deque<std::string> u, const std::string &w){
    std::deque<std::string> pu(u);
    pu.pop_front();
    while(ngrams[u].AddCustomer(w,WordProbability(pu,w))){
      u.pop_front();
      if(pu.size()>0) pu.pop_front();
    }
  }
  //単語の削除
  void RemoveCustomer(std::deque<std::string> u, const std::string &w){
    while(ngrams[u].RemoveCustomer(w)){
      u.pop_front();
    }
  }
  //ハイパーパラメータの更新
  void UpdateHyperParams(){
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      int m = itr->first.size();
      int t_ud = itr->second.t_ud();
      if(t_ud >= 2){
	double tmp_am = 0, tmp_alpham = 0;
	for(int i=1; i<=t_ud-1; i++){
	  double yui = bernoulli_rand(theta_m[m]/(theta_m[m]+i*d_m[m]));
	  tmp_am += 1.0 - yui;
	  tmp_alpham += yui;
	}
	a_m[m] += tmp_am;
	alpha_m[m] += tmp_alpham;
	double xu = beta_rand(theta_m[m]+1.0, itr->second.c_udd()-1.0);
	beta_m[m] -= log(xu);
      }
      for(int i=0; i<(int)itr->second.tables.size(); i++){
	if(itr->second.tables[i].second >= 2){
	  double tmp_bm = 0;
	  for(int j=1; j<=(int)itr->second.tables[i].second-1; j++){
	    double zuwkj = bernoulli_rand((j-1.0)/(j-d_m[m]));
	    tmp_bm += 1.0 - zuwkj;
	  }
	  b_m[m] += tmp_bm;
	}
      }
    }
    //dとthetaの更新
    for(int m=0; m<n_val; m++){
      d_m[m] = beta_rand(a_m[m], b_m[m]);
      theta_m[m] = gamma_rand(alpha_m[m], 1.0/beta_m[m]);
    }
    //各レストランに反映
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      int m = itr->first.size();
      itr->second.UpdateParameters(d_m[m],theta_m[m]);
    }
  }

  //モデルの推定(Blocked Gibbs Sampler)
  void inference(const std::string &filename, int t_iteration, int t_burnin){
    for(int t=0; t<t_iteration; t++){
      if(t%1000==0) std::cout << "." << std::flush;
      std::ifstream ifs(filename.c_str()); //毎回ファイルから読み込むorz
      std::string text;
      //ngram読み込み
      while(std::getline(ifs, text)){
        std::stringstream ss(text);
        std::deque<std::string> u;
        std::string tmp, w;
        for(int i=0; i<n_val-1; i++){
          ss >> tmp;
          u.push_back(tmp);
        }
        ss >> w;

        if(t > 0) RemoveCustomer(u, w);
        AddCustomer(u, w);
      }
      //ハイパーパラメータの更新
      UpdateHyperParams();
      
      if(t >= t_burnin){//パラメータの値を保存しておく
        for(int i=0; i<n_val; i++){
          history_d_m[i].push_back(d_m[i]);
          history_theta_m[i].push_back(theta_m[i]);
        }
      }
    }
    std::cout << std::endl;

    /*//確認用dとthetaの出力
    for(int i=0; i<t_iteration-t_burnin; i++){
      std::cerr << i << "\t";
      for(int m=0; m<n_val; m++){
        std::cerr << history_d_m[m][i] << "\t" << history_theta_m[m][i] << "\t";        
      }
      std::cerr << std::endl;
    }
    */
    
    //burnin期間を除いたパラメータの平均値で最終的なd_mとtheta_mを決めてみる
    std::vector<double> ave_d_m(n_val), ave_theta_m(n_val);
    for(int m=0; m<n_val; m++){
      int num = history_d_m[m].size();
      for(int i=0; i<num; i++){
        ave_d_m[m] += history_d_m[m][i];
        ave_theta_m[m] += history_theta_m[m][i];
      }
      ave_d_m[m] /= num;
      ave_theta_m[m] /= num;
      //平均値を使う
      d_m[m] = ave_d_m[m];
      theta_m[m] = ave_theta_m[m];
    }
    //全てのレストランに反映
    std::map<std::deque<std::string>,Restaurant>::iterator itr;
    for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){
      int m = itr->first.size();
      itr->second.UpdateParameters(d_m[m],theta_m[m]);
    }    
  }
};


int main(){
  int n = 2; //bigram言語モデル
  int n_char = 10; //世界には10文字ぐらいあると仮定
  HPYLM hpylm(n, n_char);
  hpylm.inference("ngram.data", 100000, 20000); //100000回反復(うち最初の20000回はburninとして捨てる)

  //モデルの状態を表示
  std::cout << "======================" << std::endl;
  hpylm.show_ngrams();
  std::cout << "======================" << std::endl;

  //各ngramの確率を求める
  std::string text;
  while(std::getline(std::cin, text)){
    std::stringstream ss(text);
    std::deque<std::string> u;
    std::string tmp, w;
    for(int i=0; i<n-1; i++){
      ss >> tmp;
      u.push_back(tmp);
    }
    ss >> w;
    
    std::cout << "Prob : " << hpylm.WordProbability(u, w) << std::endl;
  }

  return 0;
}

結果

学習データ「ngram.data」
  • 半角スペース区切り
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 日
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
今 月
明 日
明 日
明 日
明 日
明 日
実行結果
$ ./a.out
................................................................................
....................
======================
Restaurant[](d=0.553945,theta=0.730763)
月      2
日      1
日      2
日      1
Restaurant[今](d=0.0573542,theta=0.595518)
日      10
月      8
月      12
Restaurant[明](d=0.0573542,theta=0.595518)
日      3
日      1
日      1
======================
今 日
Prob : 0.334784
今 月
Prob : 0.65643
今 年
Prob : 0.00109828
明 日
Prob : 0.916481
明 月
Prob : 0.0354769
明 年
Prob : 0.00600527
あ い
Prob : 0.0437773

P(日|今)は、最尤推定だと0.333333ぐらい
P(月|今)は、最尤推定だと0.666666ぐらい
P(年|今)は、最尤推定だと0.0
P(日|明)は、最尤推定だと1.0
P(月|明)は、最尤推定だと0.0
P(年|明)は、最尤推定だと0.0
P(い|あ)は、最尤推定だと0.0

おぉ、、そこそこそれっぽい結果がでてるようにみえる。。。
学習データ中には「明月」はないけど、「今月」が出てる関係で、他の学習データにないものに比べちょっと確率が高くできてる。
「あい」は未知の語同士だからわからないはずだけど、ちゃんと確率値が与えられてる。

dとθの収束について

モデルが収束してるのかdとθの変化についてプロットしてみる。10万回分。

赤線:d_0
緑線:theta_0
青線:d_1
ピンク線:theta_1

まだバグがあるのかもしれないけど、えらい収束がゆっくりしてるように見える。
(200回程度じゃ全然収束してない、、、ので反復回数を修正)
あと、最終的な値に使うのが生成した値じゃなかったので、コードの方をdとtheta値を保存して平均をとるように変更。
うーん、、難しい。

参考資料