TWCNB分類器を試す

はじめに

テキスト分類でよく使われるNaive Bayesにはいくつかの厳しい仮定や条件があり、それによって性能が落ちてしまっている。
経験則をいれたりして性能を向上させたTWCNB分類器を試してみる。

多項モデルによるNaiveBayes

  • l_MNB(d) = argmax_c{ log P(c) + Σ_w{f_w * log P(w|c) } }
    • l_MNB(d) : 多項モデルでの文書dの予測クラスラベル
    • P(c) : クラスcである確率
    • f_w : 文書での単語wの出現頻度
    • P(w|c) : クラスcでの単語wの出現確率
  • P(w|c)の推定値=(N_cw + α_w) / (Σ_w {N_cw} + 単語の種類数)
    • N_cw : クラスcで単語wが出現する訓練文書数
    • α_w : パラメータ(=1)
  • 【メモ】P(c)の推定値=(N_c + α_c) / (Σ_c {N_c} + クラス数)
    • N_c : クラスcの訓練文書数
    • α_c : パラメータ(=1)

Transformed Weight-normalized Complement Naive Bayes(TWCNB)とは

  • http://machinelearning.wustl.edu/mlpapers/paper_files/icml2003_RennieSTK03.pdf
  • Naive Bayes(ここでは多項モデル)にはたくさんの問題を含む
    • Skewed Data Bias
      • P(w|c)は、N_cwに依存していて、文書数が多いクラスの時に大きくなりやすい
      • クラスcについて考える時、クラスcを見るのではなく、クラスcでないもの(complement)を使って計算する
      • 一般に、文書数が増加したり、偏りが少なくなる
    • Weight Magnitude Errors
      • Naive Bayesの仮定によって重みベクトルの値が不当に大きくなってしまう
      • 正規化することで対処する
  • 上記の改良を加えたものを「Weight-normalized Complement Naive Bayes」とする
  • さらに、多項モデルのための文書データの変換を施す
    • Term Frequency
      • 経験的に分布がheavier-tails(裾が長い分布)になることを加える
    • Document Frequency
      • あまりよく出ない単語の影響を大きくし、よく出る単語の影響を小さくする
    • Document Length
      • 文書の長さが長いと単語の出現回数が出やすいので、変換する
  • 上記の改良を加えたものを「Transformed Weight-normalized Complement Naive Bayes」とする

使用したデータ

Yahoo! Japanのニュースのトピックス>バックナンバー>アーカイブから2013年6月のトピックスのデータを取得。
mecab(ipadic)で分かち書きして使用。
作成したデータ: https://github.com/jetbead/Prog/tree/master/TWCNB

※(2013/07/02追記)データはランダムにシャッフルして、2つに分割してる。素性は、形態素の表層のみ。

カテゴリは8つ。

番号 カテゴリ名 学習データの個数 テストデータの個数
0 computer 49 52
1 domestic 165 165
2 economy 117 137
3 entertainment 160 155
4 local 167 154
5 science 49 42
6 sports 187 180
7 world 106 102

コード

いつものごとく、どこか間違っているかもしれないけど、とりあえずそれっぽい結果がでているので、、、

 <iostream>
#include <fstream>
#include <set>
#include <map>
#include <vector>
#include <cmath>

class TWCNB {
  //クラス数
  int num_class;

  //ドキュメント情報
  std::vector< std::map<std::string,double> > dij;

  //単語集合
  std::set<std::string> wordset;

  //単語の重み
  std::vector< std::map<std::string,double> > weight;

  //Text Transformations
  void text_transformations(){
    //TF transform
    for(int j=0; j<dij.size(); j++){
      std::map<std::string,double>::iterator itr = dij[j].begin();
      while(itr != dij[j].end()){
        itr->second = log(itr->second + 1.0);
        ++itr;
      }
    }

    //IDF transform
    for(int j=0; j<dij.size(); j++){
      std::map<std::string,double>::iterator itr = dij[j].begin();
      while(itr != dij[j].end()){
        int occurs = 0;
        for(int jj=0; jj<dij.size(); jj++){
          if(dij[jj].count(itr->first)>0) occurs++;
        }
        itr->second = itr->second * log( (double)(dij.size()) / occurs );
        ++itr;
      }
    }
    
    //length norm
    for(int j=0; j<dij.size(); j++){
      double sum = 0.0;
      std::map<std::string,double>::iterator itr = dij[j].begin();
      while(itr != dij[j].end()){
        sum += itr->second * itr->second;
        ++itr;
      }
      itr = dij[j].begin();
      while(itr != dij[j].end()){
        itr->second = itr->second / sqrt( sum );
        ++itr;
      }
    }
  }
  
  //単語重みweightの作成
  void calc_weight(const std::vector< std::pair< int, std::vector<std::string> > >& docs){
    for(int c=0; c<num_class; c++){
      weight.push_back(std::map<std::string,double>());
    }

    //complement class
    for(int c=0; c<num_class; c++){
      std::set<std::string>::iterator itr = wordset.begin();
      while(itr != wordset.end()){
        double sum_dij = 0.0;
        double sum_dkj = 0.0;
        for(int j=0; j<docs.size(); j++){
          if(docs[j].first != c){

            std::map<std::string,double>::iterator tmp = dij[j].begin();
            while(tmp != dij[j].end()){
              if(tmp->first == *itr){
                sum_dij += tmp->second;
              }
              sum_dkj += tmp->second;

              ++tmp;
            }
           
          }
        }
        weight[c][*itr] = (sum_dij + 1.0) / (sum_dkj + wordset.size());
        
        ++itr;
      }
    }
    
    //calc weight & normalization
    for(int c=0; c<num_class; c++){
      //std::cout << "class " << c << "=====================" << std::endl;
      double sum = 0.0;
      std::map<std::string,double>::iterator itr = weight[c].begin();
      while(itr != weight[c].end()){
        itr->second = log( itr->second );
        sum += fabs(itr->second);
        ++itr;
      }
      itr = weight[c].begin();
      while(itr != weight[c].end()){
        itr->second = itr->second / sum;
        //std::cout << itr->first << "\t" << itr->second << std::endl;
        ++itr;
      }
    }

  }

  //クラスcの単語wordの重み
  double get_weight(int c, const std::string& word){
    if(weight[c].count(word)==0) return 0.0;
    return weight[c][word];
  }

public:
  
  TWCNB(int num_class):num_class(num_class){}

  //学習
  void train(const std::vector< std::pair< int, std::vector<std::string> > >& docs){
    
    //単語集合wordsetとドキュメント情報dijの作成
    for(int j=0; j<docs.size(); j++){
      std::map<std::string,double> tmp;
      for(int i=0; i<docs[j].second.size(); i++){
        tmp[docs[j].second[i]] += 1.0;
        wordset.insert(docs[j].second[i]);
      }
      dij.push_back(tmp);
    }

    //Text Transformations
    text_transformations();

    //重みの計算
    calc_weight(docs);

  }

  //予測
  int predict(std::vector<std::string>& data){
    int ret = -1;
    double retv = 100000000.0;
    for(int c=0; c<num_class; c++){
      //std::cout << "class " << c << "===========" << std::endl;
      double val = 0.0;
      for(int i=0; i<data.size(); i++){
        //std::cout << data[i] << " " << get_weight(c, data[i]) << std::endl;
        val += get_weight(c, data[i]);
      }
      if(retv > val){
        ret = c;
        retv = val;
      }
      //std::cout << "class " << c << " : " << val << std::endl;
    }
    return ret;
  }

};


std::vector<std::string> split(const std::string &str){
  std::vector<std::string> ret;
  std::string tmp = "";
  for(int i=0; i<str.length(); i++){
    if(tmp != "" && str[i] == ' '){
      ret.push_back(tmp);
      tmp = "";
    }
    else if(str[i] != ' '){
      tmp += str[i];
    }
  }
  if(tmp != "") ret.push_back(tmp);
  return ret;
}

int main(int argc, char** argv){
  if(argc != 2) return 1;

  TWCNB twcnb(8); //クラス数を指定

  int type;
  std::string line;

  //train
  std::ifstream ifs(argv[1]);
  std::vector< std::pair< int, std::vector<std::string> > > docs;

  while(ifs >> type){
    std::getline(ifs, line);
    std::vector<std::string> doc = split(line);
    
    docs.push_back(std::make_pair< int, std::vector<std::string> >(type, doc));
  }
  
  twcnb.train(docs);
  

  //predict
  int num = 0, corr = 0;
  while(std::cin >> type){
    std::getline(std::cin, line);
    std::vector<std::string> doc = split(line);
    
    int res = twcnb.predict(doc);
    std::cout << res << "(correct:" << type << ")" << std::endl;
    if(res == type) corr++;
    num++;
  }

  std::cout << "Acc:" << (corr * 100.0 / num) << "% ";
  std::cout << "(" << corr << " / " << num << ")" << std::endl;

  return 0;
}

結果

#close
$ ./a.out train < train
...
Acc:99.3% (993 / 1000)

#open
$ ./a.out train < test
...
Acc:63.8298% (630 / 987)

結構正解できている?
トピックスのタイトル文は特殊(長さが同じ、あまり同じ単語が出現しにくい?、など)のようにも思われるので、経験的な部分が少し違ってうまく機能していないかも。

追記(20130702)

ちょっと上のデータの作り方がアレすぎるので、量を増やして、それっぽい感じの設定に直してみた。
素性は形態素の表層のみのまま。精度を上げることが目的ではないので、、(逃)

データセット

学習データ:2013年3月、4月、5月の3ヶ月分のYahoo! Japanのトピックスのデータを全て合わせたもの(5316件)
テストデータ:2013年6月のYahoo! Japanのトピックスのデータすべて(1802件)

結果
$ ./a.out train < test
...
Acc:70.1998% (1265 / 1802)

各クラスのweightの小さいもの・大きいもの10個ずつ見てみる。

====== class 0=======
## MIN ##
の	-6.59209e-05
に	-6.748e-05
で	-6.80728e-05
が	-6.92568e-05
「	-7.12277e-05
」	-7.12845e-05
、	-7.23602e-05
を	-7.23651e-05
は	-7.23863e-05
へ	-7.24419e-05
## MAX ##
台数	-0.000125664
野望	-0.000125664
友達	-0.000125664
協業	-0.000125664
量販	-0.000125664
包囲	-0.000125664
かまっ	-0.000125664
勢い	-0.000125664
刷新	-0.000125664
初代	-0.000125664

====== class 1=======
## MIN ##
の	-6.68606e-05
に	-6.80059e-05
で	-6.86516e-05
が	-6.92855e-05
、	-7.22937e-05
「	-7.23023e-05
」	-7.23779e-05
を	-7.33895e-05
は	-7.38445e-05
へ	-7.48796e-05
## MAX ##
投開票	-0.000124776
付与	-0.000124776
貧困	-0.000124776
抜い	-0.000124776
矢	-0.000124776
貫け	-0.000124776
貯水	-0.000124776
貯蓄	-0.000124776
貯蔵	-0.000124776
貴子	-0.000124776

====== class 2=======
## MIN ##
の	-6.61113e-05
に	-6.78181e-05
で	-6.84166e-05
が	-6.9281e-05
「	-7.09723e-05
」	-7.10321e-05
は	-7.25993e-05
を	-7.26664e-05
、	-7.38554e-05
へ	-7.40872e-05
## MAX ##
紳士	-0.000125089
左遷	-0.000125089
巨額	-0.000125089
常	-0.000125089
乳幼児	-0.000125089
メーデー	-0.000125089
わり	-0.000125089
アイ	-0.000125089
乱高下	-0.000125089
モス	-0.000125089

====== class 3=======
## MIN ##
の	-6.61881e-05
で	-6.80801e-05
に	-6.80891e-05
が	-6.98152e-05
」	-7.18953e-05
「	-7.19192e-05
へ	-7.21485e-05
は	-7.26386e-05
を	-7.2776e-05
、	-7.35003e-05
## MAX ##
公演	-0.000124642
公式	-0.000124642
八田	-0.000124642
八方	-0.000124642
八幡	-0.000124642
全開	-0.000124642
全裸	-0.000124642
全力	-0.000124642
歌声	-0.000124642
入館	-0.000124642

====== class 4=======
## MIN ##
の	-6.62503e-05
に	-6.83131e-05
が	-6.94653e-05
で	-6.95163e-05
「	-7.07731e-05
」	-7.08328e-05
は	-7.18399e-05
へ	-7.2331e-05
を	-7.27115e-05
、	-7.2759e-05
## MAX ##
スキミング	-0.000124898
スクープ	-0.000124898
スタジアム	-0.000124898
血	-0.000124898
看守	-0.000124898
構図	-0.000124898
血痕	-0.000124898
危篤	-0.000124898
相模原	-0.000124898
道頓堀	-0.000124898

====== class 5=======
## MIN ##
の	-6.59822e-05
に	-6.75336e-05
で	-6.8434e-05
が	-6.8839e-05
「	-7.07585e-05
」	-7.08116e-05
を	-7.20987e-05
は	-7.21595e-05
へ	-7.23278e-05
、	-7.23724e-05
## MAX ##
一元	-0.000125608
通院	-0.000125608
通説	-0.000125608
両目	-0.000125608
逆流	-0.000125608
追わ	-0.000125608
交配	-0.000125608
人工	-0.000125608
今世紀	-0.000125608
近親	-0.000125608

====== class 6=======
## MIN ##
の	-6.64465e-05
に	-6.79847e-05
で	-6.89817e-05
が	-7.03666e-05
を	-7.19457e-05
「	-7.22785e-05
」	-7.23553e-05
へ	-7.26176e-05
は	-7.30116e-05
、	-7.38334e-05
## MAX ##
申し立て	-0.000124588
ラミ	-0.000124588
ラフプレー	-0.000124588
憲	-0.000124588
ラスト	-0.000124588
甲斐	-0.000124588
甲子園	-0.000124588
ライト	-0.000124588
由	-0.000124588
ライアン	-0.000124588

====== class 7=======
## MIN ##
の	-6.64092e-05
に	-6.79787e-05
で	-6.89882e-05
が	-6.99935e-05
「	-7.0927e-05
」	-7.09624e-05
は	-7.25251e-05
へ	-7.26899e-05
、	-7.27488e-05
を	-7.29493e-05
## MAX ##
闇	-0.000125231
閲覧	-0.000125231
ミッソーニ	-0.000125231
標的	-0.000125231
ムシャラフ	-0.000125231
ムバラク	-0.000125231
案内	-0.000125231
参謀	-0.000125231
メンネア	-0.000125231
モスクワ	-0.000125231

語彙力。

その他

TWCNBは経験的にわかっている現象を取り入れることで精度を向上しているが、
数学的に導出されたものではない。
以下など、数学的に導出したものを取り扱うものが提案されているよう。
http://www.ninjal.ac.jp/event/specialists/project-meeting/files/JCLWorkshop_no1_papers/JCLWorkshop2012_15.pdf