Kneser-Ney smoothingで遊ぶ

はじめに

100-nlp-papersで紹介されてた一番最初の論文に、クナイザーネイスムージングのスッキリな実装が載っていたので書いてみる。

Joshua Goodman: A bit of progress in language modeling, MSR Technical Report, 2001.

Kneser-Ney smoothingとは

  • 言語モデルのスムージング(平滑化)手法の一種で、高い言語モデル性能を実現している
  • イデアとしては「(n-1)-gramが出現した文脈での異なり数」を使うこと
    • 頻度を使うと、高頻度なn-gramではその(n-1)-gramも多くなってしまうため、特定文脈でしかでないような(n-1)-gramに対しても高い確率値ことになっていて、歪んだ結果になってしまう
      • 「San Francisco」の頻度が多いと「Francisco」の頻度も高くなるが、P(Francisco|on)とかはあまり出現しないので低くなってほしいところ、「Franciscoの頻度」を使って確率値を推定すると高くなってしまう
    • 頻度ではなく、異なり数で(n-1)-gramの確率を推定することで、補正する
  • 上のレポートでは、Interpolatedな補間方法での実装例を紹介している
    • back-offな方法も考えらえる
    • discount(割引値)パラメータをn-gramごとに分けた方法は「modified Kneser-Ney smoothing」と呼ばれている

UNKの扱い

  • レポートのAppendixのFigure17と18はそのままだと学習データに出現しない単語UNKが出てくると、unigramが0なので、確率も0になってしまう
  • レポートの8ページ目では、一様分布1/|V|(Vは語彙集合)を使ってスムージングしてこれを避けると紹介されている
  • λはどうするの?というのは、以下のページで議論されているように、λ(ε)と考えると、「λ==discountパラメータ」としてもよいかなと思うので、コードではそのようにした

コード

UNKのためにちょっと修正した。あんまりちゃんとチェックできていないけど、それっぽい数値を返しているのでおそらく大丈夫。

#include <iostream>
#include <fstream>
#include <vector>
#include <string>
#include <cmath>
#include <unordered_map>

class InterpolatedKneserNeyTrigram {
  const std::string delim = "\t";
  double discount; //(0,1)の範囲で指定する
  std::unordered_map<std::string,int> TD, TN, TZ, BD, BN, BZ, UN;
  int UD;
public:
  InterpolatedKneserNeyTrigram():discount(0.1),UD(0){}
  InterpolatedKneserNeyTrigram(double d):discount(d),UD(0){}

  //ファイルに書き出し
  void save(const std::string& filename){
    std::ofstream fout(filename);
    if(!fout){ std::cerr << "cannot open file" << std::endl; return; }
    fout << discount << std::endl;
    fout << TD.size() << std::endl;
    std::unordered_map<std::string,int>::iterator it;
    for(it=TD.begin(); it!=TD.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << TN.size() << std::endl;
    for(it=TN.begin(); it!=TN.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << TZ.size() << std::endl;
    for(it=TZ.begin(); it!=TZ.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << BD.size() << std::endl;
    for(it=BD.begin(); it!=BD.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << BN.size() << std::endl;
    for(it=BN.begin(); it!=BN.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << BZ.size() << std::endl;
    for(it=BZ.begin(); it!=BZ.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << UN.size() << std::endl;
    for(it=UN.begin(); it!=UN.end(); ++it) fout << it->first << std::endl << it->second << std::endl;
    fout << UD << std::endl;    
    fout.close();
  }
  //ファイルから読み込み
  void load(const std::string& filename){
    std::ifstream fin(filename);
    if(!fin){ std::cerr << "cannot open file" << std::endl; return; }
    fin >> discount;
    int td, tn, tz, bd, bn, bz, un;
    std::string s;
    int c;
    fin >> td;
    for(int i=0; i<td; i++){ getline(fin,s); getline(fin, s); fin >> c; TD[s] = c; }
    fin >> tn;
    for(int i=0; i<tn; i++){ getline(fin,s); getline(fin, s); fin >> c; TN[s] = c; }
    fin >> tz;
    for(int i=0; i<tz; i++){ getline(fin,s); getline(fin, s); fin >> c; TZ[s] = c; }
    fin >> bd;
    for(int i=0; i<bd; i++){ getline(fin,s); getline(fin, s); fin >> c; BD[s] = c; }
    fin >> bn;
    for(int i=0; i<bn; i++){ getline(fin,s); getline(fin, s); fin >> c; BN[s] = c; }
    fin >> bz;
    for(int i=0; i<bz; i++){ getline(fin,s); getline(fin, s); fin >> c; BZ[s] = c; }
    fin >> un;
    for(int i=0; i<un; i++){ getline(fin,s); getline(fin, s); fin >> c; UN[s] = c; }
    fin >> UD;
    fin.close();
  }

  void set_discount(double d){ discount = d; }
  double get_discount() const { return discount; }
  
  void add_sentence(const std::vector<std::string>& sentence){
    std::string w2 = "", w1 = "";

    for(size_t i=0; i<sentence.size(); i++){
      std::string w0 = sentence[i];
      TD[ w2 + delim + w1 ]++;
      if(TN[ w2 + delim + w1 + delim + w0 ]++ == 0){
        TZ[ w2 + delim + w1 ]++;

        BD[ w1 ]++;
        if(BN[ w1 + delim + w0 ]++ == 0){
          BZ[ w1 ]++;

          UD++;
          UN[ w0 ]++;
        }
      }
      w2 = w1;
      w1 = w0;
    }
  }

  double prob(const std::vector<std::string>& sentence){
    std::string w2 = "", w1 = "";
    double ret = 0;

    for(size_t i=0; i<sentence.size(); i++){
      std::string w0 = sentence[i];
      double prob = 0;

      //そのままだとUN[w0]==0のときprob==0になるため、1/|V|を使うように変更
      double uniform = 1.0 / UN.size();
      double unigram = 0.0;
      if(UN.count( w0 ) > 0){
        unigram = (UN[ w0 ] - discount) / (double)UD;
      }
      unigram += discount * uniform;
      if(BD.count( w1 ) > 0){
        double bigram = 0;
        if(BN.count( w1 + delim + w0 ) > 0){
          bigram = (BN[ w1 + delim + w0 ] - discount) / BD[ w1 ];
        }
        bigram += BZ[ w1 ] * discount / BD[ w1 ] * unigram;

        if(TD.count( w2 + delim + w1 ) > 0){
          double trigram = 0;
          if(TN.count( w2 + delim + w1 + delim + w0 ) > 0){
            trigram = (TN[ w2 + delim + w1 + delim + w0 ] - discount) / TD[ w2 + delim + w1 ];
          }
          trigram += TZ[ w2 + delim + w1 ] * discount / TD[ w2 + delim + w1 ] * bigram;
          prob = trigram;
        }else{
          prob = bigram;
        }
      }else{
        prob = unigram;
      }
      ret += log(prob);      
      w2 = w1;
      w1 = w0;
    }
    return ret;
  }
};


int main(){
  InterpolatedKneserNeyTrigram lm;  
  std::vector< std::vector<std::string> > train_v, valid_v;
  
  {//ファイルの読み込み
    std::ifstream trainfs("train.txt");
    std::ifstream validfs("valid.txt");
    std::string w;
    std::vector<std::string> tmp;
    while(trainfs >> w){
      tmp.push_back(w);
      if(w == "EOS"){
        train_v.push_back(tmp);
        tmp.clear();
      }
    }
    
    tmp.clear();
    while(validfs >> w){
      tmp.push_back(w);
      if(w == "EOS"){
        valid_v.push_back(tmp);
        tmp.clear();
      }
    }
  }
  
  {//学習用の文を全部入れる
    for(size_t i=0; i<train_v.size(); i++){
      lm.add_sentence(train_v[i]);
    }
  }
  
  {//よさそうなdを探す
    double best = log(0), best_d = 0;
    double prec = 0.001;
    for(double d=prec; d<1; d+=prec){
      lm.set_discount(d);
      double logq = 0.0;
      for(size_t i=0; i<valid_v.size(); i++){
        logq += lm.prob(valid_v[i]);
      }
      std::cerr << d << "\t" << logq << std::endl;
      if(best < logq){
        best = logq;
        best_d = d;
      }
    }
    lm.set_discount(best_d);
    std::cerr << "best: " << best << " (d = " << best_d << ")" << std::endl;
  }

  lm.save("lm.data");

  return 0;
}

実験

データの準備

「坊ちゃん」の言語モデルを作ってみる。
青空文庫から「坊ちゃん」のテキストを取得し、「≪≫」などで囲まれた部分を削除したものを用意。
全部で470行で、10行ごとをdiscount係数確認用にする。

さらにそれを、mecab+ipadicで1行1単語にした以下のようなテキストを準備する。
(学習用train.txt 424行分、確認用valid.txt 46行分)

親譲り
の
無鉄砲
で
小
供
の
時
から
損
...
と
答え
た
。
EOS
親類
の
もの
...

(1行分の終わりには「EOS」を含む)

最適なパラメータの探索

確認用のデータで一番尤度が高くなるパラメータを採用する。


最適なのは、discount=0.897のとき、対数尤度が-29116ぐらい。


事例

いくつかの例で確率を見てみる。

s=「親譲り の 無鉄砲 だ EOS」 → log(P(s)) = -23.7238
s=「親譲り の ブレイブハート だ EOS」 → log(P(s)) = -33.8758
s=「吾輩 は 猫 だ EOS」 → log(P(s)) = -36.8097
s=「無鉄砲 な フレンズ だ EOS」 → log(P(s)) = -38.3098