ナイーブベイズ
ナイーブベイズ分類器とは?
- 古典的な分類器
- 事例dがどのクラスcに属するかを判定する
- 例えば、あるメール(事例)がスパム(クラス)かどうか?
事例dが与えられたとき、各クラスc1,c2,c3に属する確率はP(c1|d),P(c2|d),P(c3|d)になる。これを計算するためにベイズの定理
を使って、P(c|d)が最大になるクラスを求める。すなわち、
を計算する。
問題はP(d|c)で、事例dを1つ1つ調べてそのまま(単語の並びや種類を分解して考えず)計算に使うのはあまりにも非現実すぎる。なので、事例dを簡単にしたモデルで考える。
- 単語の種類だけを考える(多変数ベルヌーイモデル)
- その単語が含まれるかどうか?
- 単語の種類と回数を考える(多項モデル)
- その単語が何回含まれるか?
多変数ベルヌーイモデルのパラメータ推定
その単語が含まれるかどうかを式にしたい。ある単語wがあるクラスcにでてくる確率をとする。
出てこない確率はなので、出てくるかどうかをで表すと、
が事例dの生起確率になる。(スパムならこの単語が出やすいなどが計算される)
最尤推定
教師データからパラメータ、を具体的に求めるために、最尤推定を用いる。最適化問題は、
になる。
ナイーブベイズではこの最適化問題を解いて得られるパラメータ、が、
Pw,c = (クラスcに属する教師データで単語wが含まれる文書数)/(クラスcに属する教師データ数)
Pc = (クラスcに属する教師データ数)/(教師データ数)
と簡単な式になる。(導出は本の4.2)
# 「教師信号 文書」形式で与える ## 教師信号: 0以上の数値はクラスの番号、-1の場合はそれまでの学習で予測を行う 0 good bad good good fine 0 exciting exciting 0 good good exciting boring 1 bad boring boring boring 1 bad good bad 1 bad bad boring exciting -1 bad bad boring boring fine
#include <algorithm> #include <iostream> #include <vector> #include <map> #include <set> using namespace std; //多変数ベルヌーイモデルによるナイーブベイズ分類器(最尤推定) //教師情報付きドキュメント struct Document { int class_no; //クラスの番号 set<string> words; //出現する単語リスト void add(string word){ words.insert(word); } }; //ナイーブベイズ分類器 class NaiveBayes { //クラス数 int m_num_of_class; //出現したすべての単語 set<string> m_all_words; //各クラスに属するドキュメント数 vector<int> m_num_docs; //各クラスにおいてでてくる単語を含む文書数 vector< map<string,int> > m_class; //出現したドキュメントの総数 int m_num_of_all_docs; public: //コンストラクタ // num_of_class : 分類したいクラスの数 NaiveBayes(int num_of_class){ m_num_of_class = num_of_class; m_num_of_all_docs = 0; for(int i=0; i<m_num_of_class; i++){ m_num_docs.push_back(0); m_class.push_back(map<string,int>()); } } //訓練関数 // docs : 訓練したい教師データ付きドキュメント void train(const Document& docs){ if(docs.class_no < 0 || docs.class_no >= m_num_of_class){ cerr << "invalid train-data." << endl; return; } m_num_docs[docs.class_no]++; m_num_of_all_docs++; set<string>::const_iterator itr = docs.words.begin(); for(; itr != docs.words.end(); ++itr){ m_class[docs.class_no][(*itr)]++; m_all_words.insert(*itr); } } //予測関数 // words : 予測したい単語リスト int predict(vector<string>& words){ vector<double> ret(m_num_of_class, 0.0); //各クラスにおけるドキュメントの確率 //各クラスについて、ドキュメントの確率を求める for(int c=0; c<m_num_of_class; c++){ double Pc = m_num_docs[c]/(double)m_num_of_all_docs; double Pwc_all = 1.0; set<string>::iterator itr = m_all_words.begin(); for(; itr != m_all_words.end(); ++itr){ double Pwc = m_class[c][*itr]/(double)m_num_docs[c]; if(find(words.begin(), words.end(), *itr) != words.end()){ //出てきた場合 Pwc_all *= Pwc; }else{ //出てこなかった場合 Pwc_all *= (1.0-Pwc); } } ret[c] = Pc * Pwc_all; } //確率が最大なクラス番号を返す int ret_idx = -1; double maxP = -1.0; for(int i=0; i<m_num_of_class; i++){ cerr << "class " << i << ": " << ret[i] << endl; if(maxP<ret[i]){ maxP = ret[i]; ret_idx = i; } } return ret_idx; } }; vector<string> split(const string &str){ vector<string> ret; string tmp = ""; for(int i=0; i<str.length(); i++){ if(tmp != "" && str[i] == ' '){ ret.push_back(tmp); tmp = ""; } else if(str[i] != ' '){ tmp += str[i]; } } if(tmp != "") ret.push_back(tmp); return ret; } int main(){ NaiveBayes nb(2); int c_type; string str; while(cin >> c_type){ getline(cin, str); //学習 if(c_type >= 0){ Document d; d.class_no = c_type; vector<string> wrds = split(str); for(int i=0; i<wrds.size(); i++){ d.add(wrds[i]); } nb.train(d); } //予測 else{ vector<string> wrds = split(str); cout << nb.predict(wrds) << endl; } } return 0; }
最大事後確率推定
最尤推定は実はまずいことがあって、それは「データに出現しなかった単語はPw,c=0になってしまう」こと。
上の例だと、「fine」という単語がクラス1にでてきていないので、ほとんどの単語がクラス1でもクラス0と判定されてしまう。
そこで、本では、事前分布にディリクレ分布を仮定し最大事後確率推定する方法を紹介している。
なぜディリクレ分布かというと、0や1付近の数値をとりずらい分布になっているので、Pw,cが極端な値になりにくくなる。
最尤推定と同様に、最大事後確率推定での最適化問題
を解くと、
Pw,c = (クラスcに属する教師データで単語wが含まれる文書数 + (α-1) )/(クラスcに属する教師データ数 + 2(α-1) )
Pc = (クラスcに属する教師データ数 + (α-1) )/(教師データ数 + クラス数 * (α-1) )
となる。これは最尤推定の結果にすこし下駄をはかせた結果になる。
結果として、最大事後確率推定では確率が0.0にならず、直感的に正しい結果が得られる。
#include <algorithm> #include <iostream> #include <vector> #include <map> #include <set> using namespace std; //多変数ベルヌーイモデルによるナイーブベイズ分類器(MAP推定) //教師情報付きドキュメント struct Document { int class_no; //クラスの番号 set<string> words; //出現する単語リスト void add(string word){ words.insert(word); } }; //ナイーブベイズ分類器 class NaiveBayes { //クラス数 int m_num_of_class; //出現したすべての単語 set<string> m_all_words; //各クラスに属するドキュメント数 vector<int> m_num_docs; //各クラスにおいてでてくる単語を含む文書数 vector< map<string,int> > m_class; //出現したドキュメントの総数 int m_num_of_all_docs; //ディリクレ分布のパラメータ double m_alpha; public: //コンストラクタ // num_of_class : 分類したいクラスの数 NaiveBayes(int num_of_class, double alpha){ m_alpha = alpha; m_num_of_class = num_of_class; m_num_of_all_docs = 0; for(int i=0; i<m_num_of_class; i++){ m_num_docs.push_back(0); m_class.push_back(map<string,int>()); } } //訓練関数 // docs : 訓練したい教師データ付きドキュメント void train(const Document& docs){ if(docs.class_no < 0 || docs.class_no >= m_num_of_class){ cerr << "invalid train-data." << endl; return; } m_num_docs[docs.class_no]++; m_num_of_all_docs++; set<string>::const_iterator itr = docs.words.begin(); for(; itr != docs.words.end(); ++itr){ m_class[docs.class_no][(*itr)]++; m_all_words.insert(*itr); } } //予測関数 // words : 予測したい単語リスト int predict(vector<string>& words){ vector<double> ret(m_num_of_class, 0.0); //各クラスにおけるドキュメントの確率 //各クラスについて、ドキュメントの確率を求める for(int c=0; c<m_num_of_class; c++){ double Pc = (m_num_docs[c] + (m_alpha - 1.0))/((double)m_num_of_all_docs + m_num_of_class * (m_alpha - 1.0)); double Pwc_all = 1.0; set<string>::iterator itr = m_all_words.begin(); for(; itr != m_all_words.end(); ++itr){ double Pwc = (m_class[c][*itr] + (m_alpha - 1.0))/((double)m_num_docs[c] + 2 * (m_alpha - 1.0)); if(find(words.begin(), words.end(), *itr) != words.end()){ //出てきた場合 Pwc_all *= Pwc; }else{ //出てこなかった場合 Pwc_all *= (1.0-Pwc); } } ret[c] = Pc * Pwc_all; } //確率が最大なクラス番号を返す int ret_idx = -1; double maxP = -1.0; for(int i=0; i<m_num_of_class; i++){ cerr << "class " << i << ": " << ret[i] << endl; if(maxP<ret[i]){ maxP = ret[i]; ret_idx = i; } } return ret_idx; } }; vector<string> split(const string &str){ vector<string> ret; string tmp = ""; for(int i=0; i<str.length(); i++){ if(tmp != "" && str[i] == ' '){ ret.push_back(tmp); tmp = ""; } else if(str[i] != ' '){ tmp += str[i]; } } if(tmp != "") ret.push_back(tmp); return ret; } int main(){ NaiveBayes nb(2, 2); int c_type; string str; while(cin >> c_type){ getline(cin, str); //学習 if(c_type >= 0){ Document d; d.class_no = c_type; vector<string> wrds = split(str); for(int i=0; i<wrds.size(); i++){ d.add(wrds[i]); } nb.train(d); } //予測 else{ vector<string> wrds = split(str); cout << nb.predict(wrds) << endl; } } return 0; }