TWCNB分類器を試す
はじめに
テキスト分類でよく使われるNaive Bayesにはいくつかの厳しい仮定や条件があり、それによって性能が落ちてしまっている。
経験則をいれたりして性能を向上させたTWCNB分類器を試してみる。
多項モデルによるNaiveBayes
- l_MNB(d) = argmax_c{ log P(c) + Σ_w{f_w * log P(w|c) } }
- l_MNB(d) : 多項モデルでの文書dの予測クラスラベル
- P(c) : クラスcである確率
- f_w : 文書での単語wの出現頻度
- P(w|c) : クラスcでの単語wの出現確率
- P(w|c)の推定値=(N_cw + α_w) / (Σ_w {N_cw} + 単語の種類数)
- N_cw : クラスcで単語wが出現する訓練文書数
- α_w : パラメータ(=1)
- 【メモ】P(c)の推定値=(N_c + α_c) / (Σ_c {N_c} + クラス数)
- N_c : クラスcの訓練文書数
- α_c : パラメータ(=1)
Transformed Weight-normalized Complement Naive Bayes(TWCNB)とは
- http://machinelearning.wustl.edu/mlpapers/paper_files/icml2003_RennieSTK03.pdf
- Naive Bayes(ここでは多項モデル)にはたくさんの問題を含む
- Skewed Data Bias
- P(w|c)は、N_cwに依存していて、文書数が多いクラスの時に大きくなりやすい
- クラスcについて考える時、クラスcを見るのではなく、クラスcでないもの(complement)を使って計算する
- 一般に、文書数が増加したり、偏りが少なくなる
- Weight Magnitude Errors
- Naive Bayesの仮定によって重みベクトルの値が不当に大きくなってしまう
- 正規化することで対処する
- Skewed Data Bias
- 上記の改良を加えたものを「Weight-normalized Complement Naive Bayes」とする
- さらに、多項モデルのための文書データの変換を施す
- Term Frequency
- 経験的に分布がheavier-tails(裾が長い分布)になることを加える
- Document Frequency
- あまりよく出ない単語の影響を大きくし、よく出る単語の影響を小さくする
- Document Length
- 文書の長さが長いと単語の出現回数が出やすいので、変換する
- Term Frequency
- 上記の改良を加えたものを「Transformed Weight-normalized Complement Naive Bayes」とする
使用したデータ
Yahoo! Japanのニュースのトピックス>バックナンバー>アーカイブから2013年6月のトピックスのデータを取得。
mecab(ipadic)で分かち書きして使用。
作成したデータ: https://github.com/jetbead/Prog/tree/master/TWCNB
※(2013/07/02追記)データはランダムにシャッフルして、2つに分割してる。素性は、形態素の表層のみ。
カテゴリは8つ。
番号 | カテゴリ名 | 学習データの個数 | テストデータの個数 |
---|---|---|---|
0 | computer | 49 | 52 |
1 | domestic | 165 | 165 |
2 | economy | 117 | 137 |
3 | entertainment | 160 | 155 |
4 | local | 167 | 154 |
5 | science | 49 | 42 |
6 | sports | 187 | 180 |
7 | world | 106 | 102 |
コード
いつものごとく、どこか間違っているかもしれないけど、とりあえずそれっぽい結果がでているので、、、
<iostream> #include <fstream> #include <set> #include <map> #include <vector> #include <cmath> class TWCNB { //クラス数 int num_class; //ドキュメント情報 std::vector< std::map<std::string,double> > dij; //単語集合 std::set<std::string> wordset; //単語の重み std::vector< std::map<std::string,double> > weight; //Text Transformations void text_transformations(){ //TF transform for(int j=0; j<dij.size(); j++){ std::map<std::string,double>::iterator itr = dij[j].begin(); while(itr != dij[j].end()){ itr->second = log(itr->second + 1.0); ++itr; } } //IDF transform for(int j=0; j<dij.size(); j++){ std::map<std::string,double>::iterator itr = dij[j].begin(); while(itr != dij[j].end()){ int occurs = 0; for(int jj=0; jj<dij.size(); jj++){ if(dij[jj].count(itr->first)>0) occurs++; } itr->second = itr->second * log( (double)(dij.size()) / occurs ); ++itr; } } //length norm for(int j=0; j<dij.size(); j++){ double sum = 0.0; std::map<std::string,double>::iterator itr = dij[j].begin(); while(itr != dij[j].end()){ sum += itr->second * itr->second; ++itr; } itr = dij[j].begin(); while(itr != dij[j].end()){ itr->second = itr->second / sqrt( sum ); ++itr; } } } //単語重みweightの作成 void calc_weight(const std::vector< std::pair< int, std::vector<std::string> > >& docs){ for(int c=0; c<num_class; c++){ weight.push_back(std::map<std::string,double>()); } //complement class for(int c=0; c<num_class; c++){ std::set<std::string>::iterator itr = wordset.begin(); while(itr != wordset.end()){ double sum_dij = 0.0; double sum_dkj = 0.0; for(int j=0; j<docs.size(); j++){ if(docs[j].first != c){ std::map<std::string,double>::iterator tmp = dij[j].begin(); while(tmp != dij[j].end()){ if(tmp->first == *itr){ sum_dij += tmp->second; } sum_dkj += tmp->second; ++tmp; } } } weight[c][*itr] = (sum_dij + 1.0) / (sum_dkj + wordset.size()); ++itr; } } //calc weight & normalization for(int c=0; c<num_class; c++){ //std::cout << "class " << c << "=====================" << std::endl; double sum = 0.0; std::map<std::string,double>::iterator itr = weight[c].begin(); while(itr != weight[c].end()){ itr->second = log( itr->second ); sum += fabs(itr->second); ++itr; } itr = weight[c].begin(); while(itr != weight[c].end()){ itr->second = itr->second / sum; //std::cout << itr->first << "\t" << itr->second << std::endl; ++itr; } } } //クラスcの単語wordの重み double get_weight(int c, const std::string& word){ if(weight[c].count(word)==0) return 0.0; return weight[c][word]; } public: TWCNB(int num_class):num_class(num_class){} //学習 void train(const std::vector< std::pair< int, std::vector<std::string> > >& docs){ //単語集合wordsetとドキュメント情報dijの作成 for(int j=0; j<docs.size(); j++){ std::map<std::string,double> tmp; for(int i=0; i<docs[j].second.size(); i++){ tmp[docs[j].second[i]] += 1.0; wordset.insert(docs[j].second[i]); } dij.push_back(tmp); } //Text Transformations text_transformations(); //重みの計算 calc_weight(docs); } //予測 int predict(std::vector<std::string>& data){ int ret = -1; double retv = 100000000.0; for(int c=0; c<num_class; c++){ //std::cout << "class " << c << "===========" << std::endl; double val = 0.0; for(int i=0; i<data.size(); i++){ //std::cout << data[i] << " " << get_weight(c, data[i]) << std::endl; val += get_weight(c, data[i]); } if(retv > val){ ret = c; retv = val; } //std::cout << "class " << c << " : " << val << std::endl; } return ret; } }; std::vector<std::string> split(const std::string &str){ std::vector<std::string> ret; std::string tmp = ""; for(int i=0; i<str.length(); i++){ if(tmp != "" && str[i] == ' '){ ret.push_back(tmp); tmp = ""; } else if(str[i] != ' '){ tmp += str[i]; } } if(tmp != "") ret.push_back(tmp); return ret; } int main(int argc, char** argv){ if(argc != 2) return 1; TWCNB twcnb(8); //クラス数を指定 int type; std::string line; //train std::ifstream ifs(argv[1]); std::vector< std::pair< int, std::vector<std::string> > > docs; while(ifs >> type){ std::getline(ifs, line); std::vector<std::string> doc = split(line); docs.push_back(std::make_pair< int, std::vector<std::string> >(type, doc)); } twcnb.train(docs); //predict int num = 0, corr = 0; while(std::cin >> type){ std::getline(std::cin, line); std::vector<std::string> doc = split(line); int res = twcnb.predict(doc); std::cout << res << "(correct:" << type << ")" << std::endl; if(res == type) corr++; num++; } std::cout << "Acc:" << (corr * 100.0 / num) << "% "; std::cout << "(" << corr << " / " << num << ")" << std::endl; return 0; }
結果
#close $ ./a.out train < train ... Acc:99.3% (993 / 1000) #open $ ./a.out train < test ... Acc:63.8298% (630 / 987)
結構正解できている?
トピックスのタイトル文は特殊(長さが同じ、あまり同じ単語が出現しにくい?、など)のようにも思われるので、経験的な部分が少し違ってうまく機能していないかも。
追記(20130702)
ちょっと上のデータの作り方がアレすぎるので、量を増やして、それっぽい感じの設定に直してみた。
素性は形態素の表層のみのまま。精度を上げることが目的ではないので、、(逃)
データセット
学習データ:2013年3月、4月、5月の3ヶ月分のYahoo! Japanのトピックスのデータを全て合わせたもの(5316件)
テストデータ:2013年6月のYahoo! Japanのトピックスのデータすべて(1802件)
結果
$ ./a.out train < test ... Acc:70.1998% (1265 / 1802)
各クラスのweightの小さいもの・大きいもの10個ずつ見てみる。
====== class 0======= ## MIN ## の -6.59209e-05 に -6.748e-05 で -6.80728e-05 が -6.92568e-05 「 -7.12277e-05 」 -7.12845e-05 、 -7.23602e-05 を -7.23651e-05 は -7.23863e-05 へ -7.24419e-05 ## MAX ## 台数 -0.000125664 野望 -0.000125664 友達 -0.000125664 協業 -0.000125664 量販 -0.000125664 包囲 -0.000125664 かまっ -0.000125664 勢い -0.000125664 刷新 -0.000125664 初代 -0.000125664 ====== class 1======= ## MIN ## の -6.68606e-05 に -6.80059e-05 で -6.86516e-05 が -6.92855e-05 、 -7.22937e-05 「 -7.23023e-05 」 -7.23779e-05 を -7.33895e-05 は -7.38445e-05 へ -7.48796e-05 ## MAX ## 投開票 -0.000124776 付与 -0.000124776 貧困 -0.000124776 抜い -0.000124776 矢 -0.000124776 貫け -0.000124776 貯水 -0.000124776 貯蓄 -0.000124776 貯蔵 -0.000124776 貴子 -0.000124776 ====== class 2======= ## MIN ## の -6.61113e-05 に -6.78181e-05 で -6.84166e-05 が -6.9281e-05 「 -7.09723e-05 」 -7.10321e-05 は -7.25993e-05 を -7.26664e-05 、 -7.38554e-05 へ -7.40872e-05 ## MAX ## 紳士 -0.000125089 左遷 -0.000125089 巨額 -0.000125089 常 -0.000125089 乳幼児 -0.000125089 メーデー -0.000125089 わり -0.000125089 アイ -0.000125089 乱高下 -0.000125089 モス -0.000125089 ====== class 3======= ## MIN ## の -6.61881e-05 で -6.80801e-05 に -6.80891e-05 が -6.98152e-05 」 -7.18953e-05 「 -7.19192e-05 へ -7.21485e-05 は -7.26386e-05 を -7.2776e-05 、 -7.35003e-05 ## MAX ## 公演 -0.000124642 公式 -0.000124642 八田 -0.000124642 八方 -0.000124642 八幡 -0.000124642 全開 -0.000124642 全裸 -0.000124642 全力 -0.000124642 歌声 -0.000124642 入館 -0.000124642 ====== class 4======= ## MIN ## の -6.62503e-05 に -6.83131e-05 が -6.94653e-05 で -6.95163e-05 「 -7.07731e-05 」 -7.08328e-05 は -7.18399e-05 へ -7.2331e-05 を -7.27115e-05 、 -7.2759e-05 ## MAX ## スキミング -0.000124898 スクープ -0.000124898 スタジアム -0.000124898 血 -0.000124898 看守 -0.000124898 構図 -0.000124898 血痕 -0.000124898 危篤 -0.000124898 相模原 -0.000124898 道頓堀 -0.000124898 ====== class 5======= ## MIN ## の -6.59822e-05 に -6.75336e-05 で -6.8434e-05 が -6.8839e-05 「 -7.07585e-05 」 -7.08116e-05 を -7.20987e-05 は -7.21595e-05 へ -7.23278e-05 、 -7.23724e-05 ## MAX ## 一元 -0.000125608 通院 -0.000125608 通説 -0.000125608 両目 -0.000125608 逆流 -0.000125608 追わ -0.000125608 交配 -0.000125608 人工 -0.000125608 今世紀 -0.000125608 近親 -0.000125608 ====== class 6======= ## MIN ## の -6.64465e-05 に -6.79847e-05 で -6.89817e-05 が -7.03666e-05 を -7.19457e-05 「 -7.22785e-05 」 -7.23553e-05 へ -7.26176e-05 は -7.30116e-05 、 -7.38334e-05 ## MAX ## 申し立て -0.000124588 ラミ -0.000124588 ラフプレー -0.000124588 憲 -0.000124588 ラスト -0.000124588 甲斐 -0.000124588 甲子園 -0.000124588 ライト -0.000124588 由 -0.000124588 ライアン -0.000124588 ====== class 7======= ## MIN ## の -6.64092e-05 に -6.79787e-05 で -6.89882e-05 が -6.99935e-05 「 -7.0927e-05 」 -7.09624e-05 は -7.25251e-05 へ -7.26899e-05 、 -7.27488e-05 を -7.29493e-05 ## MAX ## 闇 -0.000125231 閲覧 -0.000125231 ミッソーニ -0.000125231 標的 -0.000125231 ムシャラフ -0.000125231 ムバラク -0.000125231 案内 -0.000125231 参謀 -0.000125231 メンネア -0.000125231 モスクワ -0.000125231
語彙力。
その他
TWCNBは経験的にわかっている現象を取り入れることで精度を向上しているが、
数学的に導出されたものではない。
以下など、数学的に導出したものを取り扱うものが提案されているよう。
http://www.ninjal.ac.jp/event/specialists/project-meeting/files/JCLWorkshop_no1_papers/JCLWorkshop2012_15.pdf
参考文献
- http://machinelearning.wustl.edu/mlpapers/paper_files/icml2003_RennieSTK03.pdf
- http://d.hatena.ne.jp/tkng/20081217/1229475900
- http://d.hatena.ne.jp/kisa12012/20110520/1305888712
- http://ibisforest.org/index.php?complement%20naive%20Bayes
- http://d.hatena.ne.jp/ruby-U/20100504/1272946467
- http://d.hatena.ne.jp/y_tag/20110213/complement_naive_bayes
- http://cseweb.ucsd.edu/~elkan/254/NaiveBayesForText.pdf