CRFを試す

はじめに

条件付き確率場(Conditional Random Fields)を実装してみた。
本の式導出がわからなくて、夜な夜なmac book airを涙で濡らしながら書いたので、あやしい。

説明

  • 基本的に「言語処理のための機械学習入門」の本の通りに書いた(つもり、、、)
    • すごく自信ない、勉強用
  • ダミーラベルのB、EはそれぞれBOS、EOS
  • φ(x,yt,yt-1)ってなってるけど、とりあえずφ_i(xt,yt,yt-1)を素性として利用(「単語」と「その品詞」と「その一つ前の単語の品詞」)
  • 最急勾配法
  • L2正則化
  • forward-backwardの部分は何もやってない

コード

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>

using namespace std;

class CRF {
  //重みベクトル
  std::map<std::string,double> w; //素性リスト(Φ(x,yt,yt-1))
  //出現した単語およびラベル
  std::set<std::string> words; //単語リスト
  std::vector<std::string> labels; //素性リスト(0:BOS 1:EOS 2以降:label)

  ///学習////////////////////////////////////
  //周辺確率P(yt, yt-1|x)の計算
  double marginal_prob(int t, int ytt, int yt, const std::vector<std::string>& x, const std::vector<std::string>& y, const std::vector< std::vector<double> >& alpha, const std::vector< std::vector<double> >& beta){
    double phi_t = exp( w[x[t]+"_"+labels[ytt]+"_"+labels[yt]] );
    double Z = 0.0;

    for(int i=0; i<labels.size(); i++){
      Z += alpha[i][y.size()-1];
    }

    return (phi_t / Z) * alpha[ytt][t-1] * beta[yt][t];
  }

  //各学習データ(xi,yi)から勾配を計算
  void train_from_data(std::map<std::string,double> &dw, const std::vector<std::string>& x, const std::vector<std::string>& y){
    std::vector< std::vector<double> > alpha(labels.size(), std::vector<double>(y.size(), 0.0));
    std::vector< std::vector<double> > beta(labels.size(), std::vector<double>(y.size(), 0.0));
 
    //Φ(x^(i),y^(i))
    for(int j=1; j<y.size(); j++){
      dw[x[j]+"_"+y[j-1]+"_"+y[j]] += 1.0;
    }
       
    //forward
    alpha[0][0] = 1.0;
    for(int tt=1; tt<y.size(); tt++){
      for(int i=0; i<labels.size(); i++){
	for(int j=0; j<labels.size(); j++){
	  alpha[i][tt] += exp( w[x[tt]+"_"+labels[j]+"_"+labels[i]] ) * alpha[j][tt-1];
	}
      }
    }

    //backward
    beta[1][y.size()-1] = 1.0;
    for(int tt=y.size()-2; tt>=0; tt--){
      for(int i=0; i<labels.size(); i++){
	for(int j=0; j<labels.size(); j++){
	  beta[i][tt] += exp( w[x[tt+1]+"_"+labels[i]+"_"+labels[j]] ) * beta[j][tt+1];
	}
      }
    }

    //[問題の部分]alpha,betaがオーバーフローする
    std::cerr << "=== alpha ===" << std::endl;
    for(int i=0; i<labels.size(); i++){
      for(int tt=0; tt<y.size(); tt++){
	std::cerr << alpha[i][tt] << "\t";
      }
      std::cerr << std::endl;
    }

    //Σ_y{P(y|x^(i))Φ(x^(i),y)}
    for(int t=1; t<x.size(); t++){
      for(int yt=0; yt<labels.size(); yt++){
	for(int ytt=0; ytt<labels.size(); ytt++){
	  double mprob = marginal_prob(t, ytt, yt, x, y, alpha, beta);
	  dw[x[t]+"_"+labels[ytt]+"_"+labels[yt]] -= mprob;
	}
      }
    }
  }

  ///予測////////////////////////////////////
  //ビタビアルゴリズム
  std::vector<std::string> predict_by_viterbi(const std::vector<std::string> &x){
    std::vector<std::string> ret;
    std::vector< std::vector<double> > t(x.size(), std::vector<double>(labels.size(), 0.0)); //その地点までの最大値を保持
    std::vector< std::vector<int> > s(x.size(), std::vector<int>(labels.size(), 2)); //最後に通ったラベル番号を保持

    //DP
    for(int i=1; i<x.size(); i++){
      for(int j=0; j<labels.size(); j++){
	if(i==1){ //from BOS
	  double cost = w[x[i]+"_"+labels[0]+"_"+labels[j]];
	  double tmp = cost;
	  if(t[i][j] < t[i-1][0]+tmp){
	    t[i][j] = t[i-1][0]+tmp;
	    s[i][j] = 0;
	  }
	}
	else if(i==x.size()-1){ //to EOS
	  double cost = w[x[i]+"_"+labels[j]+"_"+labels[1]];
	  double tmp = cost;
	  if(t[i][1] < t[i-1][j]+tmp){
	    t[i][1] = t[i-1][j]+tmp;
	    s[i][1] = j;
	  }
	}
	else{ //others
	  for(int k=0; k<labels.size(); k++){
	    double cost = w[x[i]+"_"+labels[k]+"_"+labels[j]];
	    double tmp = cost;
	    if(t[i][j] < t[i-1][k]+tmp){
	      t[i][j] = t[i-1][k]+tmp;
	      s[i][j] = k;
	    }
	  }
	}
      }
    }

    //backtracking
    int idx = 1; //EOS
    ret.push_back("EOS");
    for(int i=x.size()-1; i>0; i--){
      ret.push_back(labels[s[i][idx]]);
      idx = s[i][idx];
    }
    std::reverse(ret.begin(), ret.end());

    return ret;
  }

public:
  ///学習////////////////////////////////////
  void train(std::string filename, int loop = 1, double epsilon = 0.1, double C = 1.0){
    std::string line, word, tag;
    std::vector<string> sentence;
    std::vector<string> right_tags;
    std::set<std::string> label_set;
    
    //x(単語),y(ラベル)にどんなものがあるか調べて、素性を作成
    {
      words.clear();
      label_set.clear();
      words.insert("BOS"); words.insert("EOS");
      label_set.insert("BOS"); label_set.insert("EOS");

      std::ifstream ifs(filename.c_str());
      //ファイルの読み込み
      while(getline(ifs,line)){
	if(line.length() != 0){
	  std::stringstream ss(line);
	  ss >> word >> tag;
	  words.insert(word);
	  label_set.insert(tag);
	}
      }
      labels.push_back("BOS");
      labels.push_back("EOS");
      for(std::set<std::string>::iterator itr = label_set.begin(); itr != label_set.end(); ++itr){
	if(*itr == "BOS" || *itr == "EOS") continue;
	labels.push_back(*itr);
      }

      //素性を作成 Φ(xt,yt-1,yt)
      std::set<std::string>::iterator ixt, iyt, iytt;
      for(ixt = words.begin(); ixt != words.end(); ++ixt){
	for(iytt = label_set.begin(); iytt != label_set.end(); ++iytt){
	  for(iyt = label_set.begin(); iyt != label_set.end(); ++iyt){
	    w[(*ixt)+"_"+(*iytt)+"_"+(*iyt)] = 0.0;
	  }
	}
      }
      std::cerr << "num of feature : " << w.size() << std::endl; //素性の数
    }

    //学習
    std::map<std::string,double> dw; //∇w L(w_old)
    for(int l=0; l < loop; l++){
      std::cerr << "loop : " << l << std::endl;
      for(std::map<std::string,double>::iterator itr = w.begin(); itr != w.end(); ++itr){
	dw[itr->first] = 0.0;
      }
      //ファイルの読み込み
      std::ifstream ifs(filename.c_str());
      sentence.clear();
      right_tags.clear();      
      sentence.push_back("BOS");
      right_tags.push_back("BOS");
      int cases = 1;
      while(getline(ifs,line)){
	if(line.length() == 0){
	  std::cerr << "case : " << cases++ << std::endl;
	  
	  sentence.push_back("EOS");
	  right_tags.push_back("EOS");

	  //データから勾配の計算
	  train_from_data(dw, sentence, right_tags);

	  sentence.clear();
	  right_tags.clear();
	  sentence.push_back("BOS");
	  right_tags.push_back("BOS");
	}else{
	  std::stringstream ss(line);
	  ss >> word >> tag;
	  sentence.push_back(word);
	  right_tags.push_back(tag);
	}
      }
      
      //重みの更新(w_new = w_old + epsilon * (∇w L(w_old) - 正規化項))
      for(std::map<std::string,double>::iterator itr = w.begin(); itr != w.end(); ++itr){
	itr->second += epsilon * (dw[itr->first] - C * w[itr->first]);
	std::cerr << itr->first << "\t" << itr->second << std::endl;
      }
    }
  }

  ///予測////////////////////////////////////
  std::vector<std::string> predict(const std::vector<std::string>& sentence){
    std::vector<std::string> ret;
    std::vector<std::string> x;

    x.push_back("BOS");
    for(int i=0; i<sentence.size(); i++){
      x.push_back(sentence[i]);
    }
    x.push_back("EOS");

    std::vector<std::string> y = predict_by_viterbi(x);
    for(int i=1; i<y.size()-1; i++){
      ret.push_back(y[i]);
    }
    return ret;
  }
};

int main(){
  CRF crf;

  //学習
  crf.train("mytrain.txt",100,0.1,1.0);

  //テスト
  {
    ifstream ifs("mytest.txt");
    string line, word, tag;
    vector<string> sentence; //x
    vector<string> right_tags; //y
    int acc=0, num=0;

    while(getline(ifs,line)){
      if(line.length() == 0){
	//予測
	vector<string> predict_tags = crf.predict(sentence);

	//結果の表示
	for(int i=0; i<predict_tags.size(); i++){
	  cout << sentence[i] << "\t" << predict_tags[i] << "(" << right_tags[i] << ")" << endl;
	  if(predict_tags[i]==right_tags[i]){
	    acc++;
	  }
	  num++;
	}

	sentence.clear();
	right_tags.clear();
      }else{
	stringstream ss(line);
	ss >> word >> tag;
	sentence.push_back(word);
	right_tags.push_back(tag);
      }
    }

    //正解率
    cout << "Accuracy = " << (100*acc/(double)num) << "%(" << acc << "/" << num << ")" << endl;
  }
  
  return 0;
}

結果

学習データ・評価データ
  • 適当に短いものを用意(単語\tタグ\n)
    • windowsのときは、ちゃんとタブになってることと改行がLFになってることを確認
  • 「so happy」をチャンキング
I       O
am      O
happy   O
.       O

I       O
am      O
so      B
happy   I
.       O

You     O
are     O
happy   O
.       O

You     O
are     O
so      B
happy   I
.       O

I       O
think   O
so      O
.       O

so      O
much    O
for     O
that    O
.       O

実行結果

  • 「単語\t予想タグ(正解タグ)」
I	O(O)
am	O(O)
happy	O(O)
.	O(O)
I	O(O)
am	O(O)
so	B(B)
happy	I(I)
.	O(O)
You	O(O)
are	O(O)
happy	O(O)
.	O(O)
You	O(O)
are	O(O)
so	B(B)
happy	I(I)
.	O(O)
I	O(O)
think	O(O)
so	O(O)
.	O(O)
so	O(O)
much	O(O)
for	O(O)
that	O(O)
.	O(O)
Accuracy = 100%(27/27)
  • できてるっぽい?
    • 素性の重みの収束は確認できた

問題なこと

  • forward-backwardアルゴリズムで用いるアルファ、ベータの値がオーバーフローしてしまう
    • 品詞の種類が多いほど、入力文が長いほど
	x0	x1		x2	x3	x4	x5	x6
BOS	1	0.895056	5.26513	28.3707	151.775	791.396	5458.81 
EOS	0	0.895058	5.26513	28.3707	151.775	791.396	34264.4	#←でかい
B	0	0.884377	5.26513	28.3707	151.802	791.396	5458.81	
I	0	0.886109	5.26513	28.3707	149.321	791.396	5458.81	
O	0	1.82223		7.87219	39.5866	199.034	2293.23	5458.81
  • 本の演習問題だと各アルファが1.0を超えなそうだったけど、実際は余裕で超えてしまうみたい(inf)
  • 対処としては、以下のような方法があるみたい
    • 対数(logsumexp)で計算
    • スケーリング法
    • オーバーフロー・アンダーフローしない実数型を自分で定義