ノンパラベイズな言語モデルを試す
はじめに
最近「言語モデル」がマイブームなので、最近有名になりつつあるというMCMC法を使ったベイズな言語モデルとして、「階層的Pitman-Yor言語モデル(HPYLM)」を試しにちょっと作ってみた。
とりあえず、文字bigramのHPYLMを試してみる。
毎度のことながら勉強用、実験用なのででかいデータはちょっとまずいと思う。。。
(追記3/17 22:30)
コードで最終的な値に使う所がおかしいのを修正とdとθの収束についてを追加
コード
#include <iostream> #include <sstream> #include <fstream> #include <vector> #include <deque> #include <map> #include <algorithm> #include <string> #include <cmath> #include <climits> //xorshift // 注意: longではなくint(32bit)にすべき unsigned long xor128(){ static unsigned long x=123456789, y=362436069, z=521288629, w=88675123; unsigned long t; t=(x^(x<<11)); x=y; y=z; z=w; return w=(w^(w>>19))^(t^(t>>8)); } //[0,1)の一様乱数 // 注意: int_maxぐらいで割るべき double frand(){ return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX); } //Bernoulli試行(確率pで1、1-pで0を返す) double bernoulli_rand(double p){ double r = frand(); if(r<p) return 1.0; return 0.0; } //gamma分布に従う乱数 double gamma_rand(double shape, double scale){ double n, b1, b2, c1, c2; if(4.0 < shape) n = 1.0/sqrt(shape); else if(0.4 < shape) n = 1.0/shape + (1.0/shape)*(shape-0.4)/3.6; else if(0.0 < shape) n = 1.0/shape; else return -1; b1 = shape - 1.0/n; b2 = shape + 1.0/n; if(0.4 < shape) c1 = b1 * (log(b1)-1.0) / 2.0; else c1 = 0; c2 = b2 * (log(b2)-1.0) / 2.0; while(true){ double v1 = frand(), v2 = frand(); double w1, w2, y, x; w1 = c1 + log(v1); w2 = c2 + log(v2); y = n * (b1*w2-b2*w1); if(y < 0) continue; x = n * (w2-w1); if(log(y) < x) continue; return exp(x) * scale; } return -1; } //beta分布に従う乱数 double beta_rand(double a, double b){ double gamma1 = gamma_rand(a,1.0); double gamma2 = gamma_rand(b,1.0); return gamma1/(gamma1+gamma2); } //離散確率でインデックスを返す int selectProb(const std::vector< std::pair<int,double> > &p){ double sum = 0.0; for(int i=0; i<(int)p.size(); i++) sum += p[i].second; double r = frand()*sum, q = 0; for(int i=0; i<(int)p.size()-1; i++){ q += p[i].second; if(r<q) return p[i].first; } return p[p.size()-1].first; } //各ngramのレストラン class Restaurant { bool base_flag; //特殊なレストラン(u="") double theta, d; //parameters public: std::deque< std::pair<std::string,int> > tables; //各テーブル(料理と人数) Restaurant(bool base_flag=false):base_flag(base_flag){ d = 0.6; //適当 theta = 1; //適当 } int c_uwd(const std::string &w){ int cnt = 0; for(int i=0; i<(int)tables.size(); i++){ if(tables[i].first == w){ cnt += tables[i].second; } } return cnt; } double d_u(){ return d; } int t_uw(const std::string &w){ int cnt = 0; for(int i=0; i<(int)tables.size(); i++){ if(tables[i].first == w){ cnt++; } } return cnt; } double theta_u(){ return theta; } int c_udd(){ int cnt = 0; for(int i=0; i<(int)tables.size(); i++) cnt += tables[i].second; return cnt; } int t_ud(){ return tables.size(); } //AddCustomer bool AddCustomer(const std::string &w, const double &pw){ std::vector< std::pair<int,double> > p; //wを持つ各テーブルのインデクスと人数 //既存のテーブル for(int i=0; i<(int)tables.size(); i++){ if(tables[i].first == w){ p.push_back(std::make_pair(i,std::max(0.0,tables[i].second - d_u()))); } } //未知のテーブル p.push_back(std::make_pair(tables.size(),(theta_u() + d_u() * t_ud()) * pw)); int k = selectProb(p); if(k < (int)tables.size()){ //既知のテーブルに座る tables[k].second++; }else{ //未知のテーブルに座る tables.push_back(std::make_pair(w,1)); if(base_flag) return false; return true; } return false; } //RemoveCustomer bool RemoveCustomer(const std::string &w){ std::vector< std::pair<int,double> > p; //wを持つ各テーブルのインデクスと人数 for(int i=0; i<(int)tables.size(); i++){ if(tables[i].first == w){ p.push_back(std::make_pair(i,tables[i].second)); } } int k = selectProb(p); tables[k].second--; if(tables[k].second == 0){ tables.erase(tables.begin() + k); if(base_flag) return false; return true; } return false; } //パラメータの更新 void UpdateParameters(double d_, double theta_){ d = d_; theta = theta_; } }; //階層Pitman-Yor言語モデル class HPYLM { int n_val; //n-gram modelのnの値 int num_word_type; //文字種の数 std::map<std::deque<std::string>,Restaurant> ngrams; //各レストラン std::vector<double> a_m, b_m, alpha_m, beta_m, d_m, theta_m; //各ngramで共通のハイパーパラメータ //ハイパーパラメータの履歴を保存しておく用 std::vector< std::vector<double> > history_d_m, history_theta_m; //基底測度G_0(w) double WordProbBase(const std::string &w){ return 1.0/num_word_type; } public: //ngramのnの値, G0用の文字種の数 HPYLM(int n, int nwt){ n_val = n; num_word_type = nwt; for(int i=0; i<n_val; i++){ a_m.push_back(1.0); b_m.push_back(1.0); alpha_m.push_back(1.0); beta_m.push_back(1.0); d_m.push_back(0.6); //適当 theta_m.push_back(0.1); //適当 history_d_m.push_back(std::vector<double>()); history_theta_m.push_back(std::vector<double>()); } ngrams.insert(std::make_pair(std::deque<std::string>(), Restaurant(true))); } void show_ngrams(){ std::map<std::deque<std::string>,Restaurant>::iterator itr; for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){ std::cout << "Restaurant["; for(int i=0; i<(int)(itr->first.size()); i++){ std::cout << itr->first[i]; if(i!=(int)(itr->first.size()-1)) std::cout << "|"; } std::cout << "]" << "(d=" << itr->second.d_u() << ",theta=" << itr->second.theta_u() << ")" << std::endl; for(int i=0; i<(int)(itr->second.tables.size()); i++){ std::cout << itr->second.tables[i].first << "\t" << itr->second.tables[i].second << std::endl; } } } //単語の出現確率 double WordProbability(std::deque<std::string> u, const std::string &w){ std::map<std::deque<std::string>,Restaurant>::iterator itr; std::pair<std::map<std::deque<std::string>,Restaurant>::iterator,bool> res; res = ngrams.insert(std::make_pair(u,Restaurant())); itr = res.first; double c_uwd = itr->second.c_uwd(w); double d_u = itr->second.d_u(); double t_uw = itr->second.t_uw(w); double theta_u = itr->second.theta_u(); double c_udd = itr->second.c_udd(); double t_ud = itr->second.t_ud(); if(u.size() == 0) return ((c_uwd - d_u * t_uw) + (theta_u + d_u * t_ud) * WordProbBase(w))/(theta_u + c_udd); u.pop_front(); return ((c_uwd - d_u * t_uw) + (theta_u + d_u * t_ud) * WordProbability(u,w))/(theta_u + c_udd); } //単語の追加 void AddCustomer(std::deque<std::string> u, const std::string &w){ std::deque<std::string> pu(u); pu.pop_front(); while(ngrams[u].AddCustomer(w,WordProbability(pu,w))){ u.pop_front(); if(pu.size()>0) pu.pop_front(); } } //単語の削除 void RemoveCustomer(std::deque<std::string> u, const std::string &w){ while(ngrams[u].RemoveCustomer(w)){ u.pop_front(); } } //ハイパーパラメータの更新 void UpdateHyperParams(){ std::map<std::deque<std::string>,Restaurant>::iterator itr; for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){ int m = itr->first.size(); int t_ud = itr->second.t_ud(); if(t_ud >= 2){ double tmp_am = 0, tmp_alpham = 0; for(int i=1; i<=t_ud-1; i++){ double yui = bernoulli_rand(theta_m[m]/(theta_m[m]+i*d_m[m])); tmp_am += 1.0 - yui; tmp_alpham += yui; } a_m[m] += tmp_am; alpha_m[m] += tmp_alpham; double xu = beta_rand(theta_m[m]+1.0, itr->second.c_udd()-1.0); beta_m[m] -= log(xu); } for(int i=0; i<(int)itr->second.tables.size(); i++){ if(itr->second.tables[i].second >= 2){ double tmp_bm = 0; for(int j=1; j<=(int)itr->second.tables[i].second-1; j++){ double zuwkj = bernoulli_rand((j-1.0)/(j-d_m[m])); tmp_bm += 1.0 - zuwkj; } b_m[m] += tmp_bm; } } } //dとthetaの更新 for(int m=0; m<n_val; m++){ d_m[m] = beta_rand(a_m[m], b_m[m]); theta_m[m] = gamma_rand(alpha_m[m], 1.0/beta_m[m]); } //各レストランに反映 for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){ int m = itr->first.size(); itr->second.UpdateParameters(d_m[m],theta_m[m]); } } //モデルの推定(Blocked Gibbs Sampler) void inference(const std::string &filename, int t_iteration, int t_burnin){ for(int t=0; t<t_iteration; t++){ if(t%1000==0) std::cout << "." << std::flush; std::ifstream ifs(filename.c_str()); //毎回ファイルから読み込むorz std::string text; //ngram読み込み while(std::getline(ifs, text)){ std::stringstream ss(text); std::deque<std::string> u; std::string tmp, w; for(int i=0; i<n_val-1; i++){ ss >> tmp; u.push_back(tmp); } ss >> w; if(t > 0) RemoveCustomer(u, w); AddCustomer(u, w); } //ハイパーパラメータの更新 UpdateHyperParams(); if(t >= t_burnin){//パラメータの値を保存しておく for(int i=0; i<n_val; i++){ history_d_m[i].push_back(d_m[i]); history_theta_m[i].push_back(theta_m[i]); } } } std::cout << std::endl; /*//確認用dとthetaの出力 for(int i=0; i<t_iteration-t_burnin; i++){ std::cerr << i << "\t"; for(int m=0; m<n_val; m++){ std::cerr << history_d_m[m][i] << "\t" << history_theta_m[m][i] << "\t"; } std::cerr << std::endl; } */ //burnin期間を除いたパラメータの平均値で最終的なd_mとtheta_mを決めてみる std::vector<double> ave_d_m(n_val), ave_theta_m(n_val); for(int m=0; m<n_val; m++){ int num = history_d_m[m].size(); for(int i=0; i<num; i++){ ave_d_m[m] += history_d_m[m][i]; ave_theta_m[m] += history_theta_m[m][i]; } ave_d_m[m] /= num; ave_theta_m[m] /= num; //平均値を使う d_m[m] = ave_d_m[m]; theta_m[m] = ave_theta_m[m]; } //全てのレストランに反映 std::map<std::deque<std::string>,Restaurant>::iterator itr; for(itr = ngrams.begin(); itr != ngrams.end(); ++itr){ int m = itr->first.size(); itr->second.UpdateParameters(d_m[m],theta_m[m]); } } }; int main(){ int n = 2; //bigram言語モデル int n_char = 10; //世界には10文字ぐらいあると仮定 HPYLM hpylm(n, n_char); hpylm.inference("ngram.data", 100000, 20000); //100000回反復(うち最初の20000回はburninとして捨てる) //モデルの状態を表示 std::cout << "======================" << std::endl; hpylm.show_ngrams(); std::cout << "======================" << std::endl; //各ngramの確率を求める std::string text; while(std::getline(std::cin, text)){ std::stringstream ss(text); std::deque<std::string> u; std::string tmp, w; for(int i=0; i<n-1; i++){ ss >> tmp; u.push_back(tmp); } ss >> w; std::cout << "Prob : " << hpylm.WordProbability(u, w) << std::endl; } return 0; }
結果
学習データ「ngram.data」
- 半角スペース区切り
今 日 今 日 今 日 今 日 今 日 今 日 今 日 今 日 今 日 今 日 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 今 月 明 日 明 日 明 日 明 日 明 日
実行結果
$ ./a.out ................................................................................ .................... ====================== Restaurant[](d=0.553945,theta=0.730763) 月 2 日 1 日 2 日 1 Restaurant[今](d=0.0573542,theta=0.595518) 日 10 月 8 月 12 Restaurant[明](d=0.0573542,theta=0.595518) 日 3 日 1 日 1 ====================== 今 日 Prob : 0.334784 今 月 Prob : 0.65643 今 年 Prob : 0.00109828 明 日 Prob : 0.916481 明 月 Prob : 0.0354769 明 年 Prob : 0.00600527 あ い Prob : 0.0437773
P(日|今)は、最尤推定だと0.333333ぐらい
P(月|今)は、最尤推定だと0.666666ぐらい
P(年|今)は、最尤推定だと0.0
P(日|明)は、最尤推定だと1.0
P(月|明)は、最尤推定だと0.0
P(年|明)は、最尤推定だと0.0
P(い|あ)は、最尤推定だと0.0
おぉ、、そこそこそれっぽい結果がでてるようにみえる。。。
学習データ中には「明月」はないけど、「今月」が出てる関係で、他の学習データにないものに比べちょっと確率が高くできてる。
「あい」は未知の語同士だからわからないはずだけど、ちゃんと確率値が与えられてる。
dとθの収束について
モデルが収束してるのかdとθの変化についてプロットしてみる。10万回分。
赤線:d_0
緑線:theta_0
青線:d_1
ピンク線:theta_1
まだバグがあるのかもしれないけど、えらい収束がゆっくりしてるように見える。
(200回程度じゃ全然収束してない、、、ので反復回数を修正)
あと、最終的な値に使うのが生成した値じゃなかったので、コードの方をdとtheta値を保存して平均をとるように変更。
うーん、、難しい。