ウェーブレット木を試す

はじめに

巨大な文字列でも高速にクエリ処理できる噂の木を、挙動を確認するため作ってみた。

コード

  • アルファベット(a〜z)の文字列を扱う場合
  • 完備辞書の操作が愚直、ビット列がvector
  • 本を参考にしたけど、2か所間違ってる?
#include <iostream>
#include <vector>
#include <queue>
#include <cmath>

//top_kのためのタプル
struct ST {
  int t;
  size_t st, en;
  ST(int t, size_t st, size_t en):t(t),st(st),en(en){}
};
bool operator<(const ST& a, const ST& b){
  return (a.en-a.st) < (b.en-b.st);
}


//アルファベット([a-z]+)用のウェーブレット木
class AlphabetWaveletTree {
  //2分木なので、配列で各節点のデータを持つようにする
  //t番目の子は、2*t+1番目と2*t+2番目になるようにする

  //各節点ごとの完備辞書(ビット列)
  std::vector< std::vector<int> > B;
  //文字列復元のための各節点vごとの(インデクス)ポインタpv
  std::vector<int> pv;

  //葉ノードの開始インデクス
  int N; //2^x - 1

  ///// 完備辞書のための操作(効率の悪い実装) /////
  // B[t][i]を返す
  int access(size_t t, size_t i){
    return B[t][i];
  }
  // B[t][0,i)中のb\in{0,1}の数を返す
  size_t rank(size_t t, int b, size_t i){
    size_t ret = 0;
    for(size_t ii = 0; ii<i; ii++){
      if(B[t][ii] == b) ret++;
    }
    return ret;
  }
  // B[t]中で先頭からみてi+1番目に出現したb\in{0,1}の位置を返す
  size_t select(size_t t, int b, size_t i){
    size_t cnt = 0;
    if(t >= N) return 0;

    for(size_t ii = 0; ii<B[t].size(); ii++){
      if(B[t][ii] == b){
        cnt++;
        if(cnt == i+1) return ii;
      }
    }
    return -1;
  }
  ////////////////////////////////////////////////


  //根か否か
  bool isroot(int t){
    if(t == 0) return true;
    return false;
  }

  //葉か否か
  bool isleaf(int t){
    if(t >= N) return true;
    return false;
  }

  //cのd番目のビット
  int bit(char c, int d){
    return ((c-'a')>>(static_cast<int>(log2(N+1)+0.5)-d)) & 1; //'a'を2進数の00000として計算する
  }

  //cの葉を返す
  int getleaf(char c){
    int t = 0;
    int d = 1;

    while(!isleaf(t)){
      int b = bit(c, d);
      d++;
      t = child(t, b);
    }
    return t;
  }

  //文字を返す
  char getchar(int t){
    return 'a'+(t-N);
  }


  //節点tの親
  int parent(int t){
    return (t-1)/2;
  }

  //節点tのbの側の子
  int child(int t, int b){
    if(b == 0){
      return 2 * t + 1;
    }else{
      return 2 * t + 2;
    }
  }

  //ウェーブレット木の構築
  void build(const std::string& T){
    for(size_t i=0; i<T.length(); i++){
      char c = T[i];
      int t = 0; //WT.root
      int d = 1;

      while(!isleaf(t)){
        int b = bit(c, d); //cのd番目のビット
        B[t].push_back(b);
        t = child(t, b);
        d++;
      }
    }
  }

public:
  //textは[a-z]+な文字列
  AlphabetWaveletTree(const std::string& text):
    N(31),
    B(31, std::vector<int>()), //31以降のインデクスは葉ノード
    pv(31)
  {
    build(text);    
  }

  //dump
  void dump(){
    for(size_t i=0; i<N; i++){
      std::cout << i << ":";
      for(size_t j=0; j<B[i].size(); j++){
        std::cout << B[i][j];
      }
      std::cout << std::endl;
    }
  }

  //文字列の復元
  std::string reconst(){
    std::string ret = "";
    
    for(size_t i=0; i<N; i++){ pv[i] = 0; }
    
    while(pv[0] < B[0].size()){
      int t = 0;
      int c = 0, cp = static_cast<int>(log2(N+1)+0.5)-1;
      while(!isleaf(t)){
        int b = B[t][pv[t]];
        c += b << cp; cp--;
        pv[t]++;
        t = child(t, b);
      }
      ret += 'a' + c;
    }
    return ret;
  }

  ///// 文字列操作 /////
  
  //文字T[i]の復元
  char access(size_t i){
    int t = 0;
    size_t p = i;
    int c = 0, cp = static_cast<int>(log2(N+1)+0.5)-1;

    while(!isleaf(t)){
      int b = access(t, p);
      c += b << cp; cp--;
      p = rank(t, b, p); //b<-access(Bt, p)っぽい?
      t = child(t, b);
    }
    return 'a' + c;
  }

  //T[0,i)中の文字cの出現回数を返す
  int rank_c(size_t i, char c){
    int t = 0;
    size_t p = i;
    int d = 1;

    while(!isleaf(t)){
      int b = bit(c, d); //cのd番目のビット
      d++;
      p = rank(t, b, p);
      t = child(t, b);
    }

    return p;
  }

  //T中の(i+1)番目のcの出現位置を返す
  size_t select_c(size_t i, char c){
    int t = getleaf(c);
    size_t p = i;
    int d = static_cast<int>(log2(N+1)+0.5); //lg c

    while(!isroot(t)){
      int b = bit(c, d); //cのd番目のビット
      t = parent(t); //親ノードに移動するのが先っぽい?
      p = select(t, b, p);
      d--;
    }
    
    return p;
  }

  //T[s,e)で辞書順が(r+1)番目に小さいものを返す
  char quantile(size_t s, size_t e, size_t r){
    int t = 0;
    size_t st = s;
    size_t en = e;
    size_t remain = r;
    
    while(!isleaf(t)){
      size_t zn = rank(t, 0, en) - rank(t, 0, st);
      int b;
      if(remain < zn){
        b = 0;
      }else{
        b = 1;
        remain = remain - zn;
      }
      st = rank(t, b, st);
      en = rank(t, b, en);
      t = child(t, b);
    }

    return getchar(t);
  }

  //T[s,e)中で出現回数が多い文字順にその頻度とともにk個返す
  std::vector< std::pair<char,int> > top_k(size_t s, size_t e, size_t k){
    std::vector< std::pair<char,int> > result;
    std::priority_queue<ST> que;
    que.push(ST(0, s, e));

    while(!que.empty()){
      ST q = que.top(); que.pop();

      int t = q.t;
      size_t st = q.st;
      size_t en = q.en;

      if(isleaf(t)){
        result.push_back(std::make_pair<char,int>(getchar(t), en-st));
        if(result.size() == k){
          break;
        }
      }else{
        size_t zst = rank(t, 0, st);
        size_t zen = rank(t, 0, en);
        size_t ost = st - zst;
        size_t oen = en - zen;
        
        if(zen-zst > 0){
          que.push(ST(child(t,0), zst, zen));
        }
        if(oen-ost > 0){
          que.push(ST(child(t,1), ost, oen));
        }
      }
    }
    return result;
  }

};


int main(){

  std::string text = "aabcdeeeeeefgghijklmnoooopqrstuvwxyz";
  AlphabetWaveletTree wt(text);

  //木の構造を出力
  wt.dump();

  //元の文
  std::cout << "text   :" << text << std::endl;
  //WaveletTreeから復元した文字列
  std::cout << "reconst:" << wt.reconst() << std::endl;

  //WaveletTreeから1文字ずつ復元した文字列
  std::cout << "access :";
  for(size_t i=0; i<text.length(); i++){
    std::cout << wt.access(i);
  }
  std::cout << std::endl;

  //文字cの出現頻度
  std::cout << "rank_c:" << wt.rank_c(text.length(), 'o') << std::endl;

  //文字cがi番目に出現する位置(インデクス)
  std::cout << "select_c:" << wt.select_c(0, 'e') << std::endl;

  //辞書順で最小のもの
  std::cout << "first alphabet:" << wt.quantile(0, text.length(), 0) << std::endl;
  //辞書順で最大のもの
  std::cout << "last alphabet :" << wt.quantile(0, text.length(), text.length()-1) << std::endl;

  //頻度が多いものTopK
  int K = 5;
  std::vector< std::pair<char,int> > top_k = wt.top_k(0, text.length(), K);
  std::cout << "top_k:" << std::endl;
  for(size_t i=0; i<K; i++){
    std::cout << "[" << i+1 << "] " << top_k[i].first << " " << top_k[i].second << std::endl;
  }

  return 0;
}

結果

$ ./a.out
0:000000000000000000000000001111111111
1:00000000000000011111111111
2:0000000011
3:000001111111111
4:00001111111
5:00001111
6:00
7:00011
8:0000000111
9:0011
10:0011111
11:0011
12:0011
13:00
14:
15:001
16:01
17:0000001
18:001
19:01
20:01
21:01
22:00001
23:01
24:01
25:01
26:01
27:01
28:
29:
30:
text   :aabcdeeeeeefgghijklmnoooopqrstuvwxyz
reconst:aabcdeeeeeefgghijklmnoooopqrstuvwxyz
access :aabcdeeeeeefgghijklmnoooopqrstuvwxyz
rank_c:4
select_c:5
first alphabet:a
last alphabet :z
top_k:
[1] e 6
[2] o 4
[3] a 2
[4] g 2
[5] h 1