Minimal Acyclic Subsequential Transducerで遊ぶ

はじめに

https://pycon.jp/2015/ja/proposals/vote/11/
Pycon2015で発表された「Pythonで作って学ぶ形態素解析」で紹介されていた辞書データ構造の「Minimal Acyclic Subsequential Transducer」について、勉強のために書いてみた。

Minimal Acyclic Subsequential Transducerとは

  • Finite State Transducerの一種
  • Transducerにおいて、initial stateが一つで、同じ入力ラベルを共有する同じ状態からのの遷移が2つ以上なく、各最終状態での最終出力文字列が高々p個のとき、p-subsequentialで、pが整数ならfinitely subsequentialというらしい
  • minimal(状態数が最少)、Acyclic(サイクルが無い)
    • Cyril Allauzen, Mehryar Mohri, Finitely Subsequential Transducersとp-Subsequentiable Transducersあたりを読む
  • Lucene/Solrで使われているFST」はこれの事
    • Kagome(Golang), Janome(python)でも採用
    • 辞書サイズがコンパクトになるので内包してもそこまで大きくならない

構築方法


OpenFstとか使うイメージでいたけど、上記で紹介されている通り、一時的な状態を作らずに、ソート済みの入力から直接FSTを構築する方法が提案されている。
詳しい手順も上記の資料(qiita)で紹介されいているので、省略。

コード

論文にできるだけ従って、入力・出力は文字列、出力は一つのみ、で実装。
問題がある部分は、状態の探索(member関数)で線形に等価な状態を探索している点や、状態の等価性判定(equal関数)が下記の方法でもよいのか怪しい点(Def.2-3?)、など。
他の実装は全然見てないので、間違ってたら後で修正する。

#include <iostream>
#include <vector>
#include <algorithm>
#include <list>
#include <cstdio>
#include <map>
#include <queue>

//入力文字列の最大長
#define MAX_WORD_SIZE 100

struct State {
  State* next[0x100];
  std::string output[0x100];
  std::string state_out;
  bool final;
  int _num;

  State(){
    clear();
  }
  State(State* s){
    _num = -1;
    state_out = s->state_out;
    final = s->final;
    for(size_t i=0; i<0x100; i++){
      next[i] = s->next[i];
      output[i] = s->output[i];
    }
  }
  ~State(){}
  void clear(){
    _num = -1;
    state_out = "";
    final = false;
    for(size_t i=0; i<0x100; i++){
      next[i] = (State*)0;
      output[i] = "";
    }
  }

  //状態の等価チェック
  // 与えられた状態sに対して、遷移状態と出力記号などが完全に一致するかを確認
  bool equal(State* s){
    if(final && s->final){
      if(state_out == s->state_out) return true;
      return false;
    }

    std::queue< std::pair< std::pair<State*,std::string>, std::pair<State*,std::string> > > que;
    for(size_t i=0; i<0x100; i++){
      que.push(std::make_pair(std::make_pair(next[i],""), std::make_pair(s->next[i],"")));
    }

    while(!que.empty()){
      std::pair<std::pair<State*,std::string>, std::pair<State*,std::string> > p = que.front(); que.pop();
      if(p.first.second != p.second.second) return false;
      if(p.first.first == NULL && p.second.first == NULL) continue;
      if(p.first.first == NULL || p.second.first == NULL) return false;
      if(p.first.first->state_out != p.second.first->state_out) return false;

      for(size_t i=0; i<0x100; i++){
        if(p.first.first->next[i] == NULL && p.second.first->next[i] == NULL) continue;
        que.push(std::make_pair(std::make_pair(p.first.first->next[i],p.first.second + (char)i), 
                                std::make_pair(p.second.first->next[i],p.second.second + (char)i)));
      }
    }
    return true;
  }
};

struct Dictionary {
  Dictionary(){}
  ~Dictionary(){
    for(std::list<State*>::iterator itr = states.begin();
        itr != states.end();
        ++itr){
      delete *itr;
    }
    states.clear();
    for(size_t i=0; i<MAX_WORD_SIZE; i++){
      if(TempState[i]){
        delete TempState[i];
        TempState[i] = NULL;
      }
    }
  }

  //辞書の状態に存在する状態かどうかチェック
  // [注意] 線形に全部等価チェックしているので非常に重い
  State* member(State* s){
    for(std::list<State*>::iterator itr = states.begin();
        itr != states.end();
        ++itr){
      State* p = *itr;
      if(s->equal(p)) return p;
    }
    return NULL;
  }

  void insert(State* s){
    states.push_back(s);
  }

  State* copy_state(State* s){
    return new State(s);
  }

  State* find_minimized(State* s){
    State* r = member(s);
    if(r == NULL){
      r = copy_state(s);
      insert(r);
    }
    return r;
  }
  void set_transition(State* s, char c, State* t){
    s->next[c] = t;
  }

  //辞書登録。inは辞書順ソートされている必要がある
  void create(const std::vector<std::string>& in, const std::vector<std::string>& out){
    for(size_t i=0; i<MAX_WORD_SIZE; i++) TempState[i] = new State;
    PreviousWord = "";
    TempState[0]->clear();

    std::string CurrentWord = "";
    std::string CurrentOutput = "";
    for(size_t t=0; t<in.size(); t++){
      CurrentWord = in[t];
      CurrentOutput = out[t];
      int i, j;
      i = 1;
      while(i<=CurrentWord.length() && i<=PreviousWord.length() && (CurrentWord[i-1] == PreviousWord[i-1])) i++;
      int PrefixLengthPlus1 = i;


      for(i=PreviousWord.size(); i>=PrefixLengthPlus1; i--){
        set_transition(TempState[i-1], PreviousWord[i-1], find_minimized(TempState[i]));
      }

      for(i=PrefixLengthPlus1; i<=CurrentWord.length(); i++){
        TempState[i]->clear();
        set_transition(TempState[i-1], CurrentWord[i-1], TempState[i]);
      }

      if(in[t] != PreviousWord){
        TempState[CurrentWord.length()]->final = true;
      }


      for(j=1; j<=PrefixLengthPlus1-1; j++){
        if(TempState[j-1] == NULL) continue;
        std::string outputStr = TempState[j-1]->output[CurrentWord[j-1]];
        std::string CommonPrefix = "";
        for(int k=0; k<std::min(outputStr.length(), CurrentOutput.length()); k++){
          if(outputStr[k] != CurrentOutput[k]) break;
          CommonPrefix += outputStr[k];
        }
        std::string WordSuffix = outputStr.substr(CommonPrefix.length());

        TempState[j-1]->output[CurrentWord[j-1]] = CommonPrefix;
        for(size_t c=0; c<0x100; c++){
          if(TempState[j]->next[c] != NULL){
            TempState[j]->output[c] = WordSuffix + TempState[j]->output[c];
          }
        }
        CurrentOutput = CurrentOutput.substr(CommonPrefix.length());
      }
      if(in[t] == PreviousWord){
        TempState[CurrentWord.length()]->state_out += CurrentOutput;
      }else{
        TempState[PrefixLengthPlus1-1]->output[CurrentWord[PrefixLengthPlus1-1]] = CurrentOutput;
      }
      PreviousWord = CurrentWord;
    }
    for(int i=CurrentWord.length(); i>=1; i--){
      TempState[i-1]->next[PreviousWord[i-1]] = find_minimized(TempState[i]);
      InitialState = find_minimized(TempState[0]);
    }
  }

  //検索
  std::string search(const std::string& query){
    std::string ret = "";
    State* p = InitialState;
    ret += p->state_out;
    for(int i=0; i<query.length(); i++){
      if(!p->next[query[i]]) return "";
      ret += p->output[query[i]];
      p = p->next[query[i]];
      ret += p->state_out;
    }
    if(p->final) return ret;
    return "";
  }


  //内部状態を出力
  void dump(){
    std::queue<State*> que;
    que.push(InitialState);
    int state_id = 0;
    while(!que.empty()){
      State* p = que.front(); que.pop();
      if(p->_num >= 0) continue;
      p->_num = state_id;
      for(int i=0; i<0x100; i++){
        if(p->next[i] == NULL) continue;
        que.push(p->next[i]);
      }
      state_id++;
    }

    std::cout << "size : " << states.size() << std::endl;
    std::map<int,int> memo;
    que.push(InitialState);
    while(!que.empty()){
      State* p = que.front(); que.pop();
      if(memo.count(p->_num) > 0) continue;
      memo[p->_num] = 1;

      std::cout << "node " << p->_num << " : [" << p->state_out << "] " << (p->final?"final":"") << std::endl;
      for(int i=0; i<0x100; i++){
        if(p->next[i] == NULL) continue;
        std::cout << "  " << (char)i << "->" << p->next[i]->_num << " [" << p->output[i] << "]" << std::endl;
        que.push(p->next[i]);
      }
    }
  }

private:
  std::list<State*> states;
  State* TempState[MAX_WORD_SIZE];
  std::string PreviousWord;
  State* InitialState;
};


int main(){
  Dictionary dic;

  std::vector<std::string> in, out;
  in.push_back("apr"); out.push_back("30");
  in.push_back("aug"); out.push_back("31");
  in.push_back("dec"); out.push_back("31");
  in.push_back("feb"); out.push_back("28");
  in.push_back("jan"); out.push_back("31");
  in.push_back("jul"); out.push_back("31");

  dic.create(in, out);

  dic.dump();

  std::cout << "-------" << std::endl;

  std::cout << "apr : " << dic.search("apr") << std::endl;
  std::cout << "aug : " << dic.search("aug") << std::endl;
  std::cout << "dec : " << dic.search("dec") << std::endl;
  std::cout << "feb : " << dic.search("feb") << std::endl;
  std::cout << "jan : " << dic.search("jan") << std::endl;
  std::cout << "jul : " << dic.search("jul") << std::endl;
  std::cout << "abc : " << dic.search("abc") << std::endl;

  return 0;
}

結果

size : 12
node 0 : [] 
  a->1 [3]
  d->2 [31]
  f->3 [28]
  j->4 [31]
node 1 : [] 
  p->5 [0]
  u->6 [1]
node 2 : [] 
  e->7 []
node 3 : [] 
  e->8 []
node 4 : [] 
  a->9 []
  u->10 []
node 5 : [] 
  r->11 []
node 6 : [] 
  g->11 []
node 7 : [] 
  c->11 []
node 8 : [] 
  b->11 []
node 9 : [] 
  n->11 []
node 10 : [] 
  l->11 []
node 11 : [] final
-------
apr : 30
aug : 31
dec : 31
feb : 28
jan : 31
jul : 31
abc : 

ランダムなアルファベット文字列でやってみても問題ないので一応大丈夫そう。
メモリ効率や高速化など実用レベルにするなら結構賢く書かないと大変そう。