XBWを試す

はじめに

XBWをWaveletMatrixを使って、試しに実装してみた。

XBWとは

コード

XBWでの表のソートはそのまま文字列同士のソートをしている。
チェックは、コード中にあるように、適当にkey文字列とkey文字列じゃないのを生成してtrieと結果が一緒になるかだけ。
WaveletMatrixの方で理解のためにいくつか関数を書いているけど、rank()ぐらいしか使っていないので、それ以外はあまりVerifyできていない。
(g++はバージョン「5.4.0」、オプション「-std=gnu++1y -O2」で実行してる)

#include <vector>
#include <map>
#include <cstdint>
#include <algorithm>
#include <iostream>
#include <queue>
#include <random>

//完備辞書(Fully Indexable Dictionary)
// 【使いまわすときの注意】
// - 全部set()したら、最後にfinalize()を呼ぶこと
// - select(x)の実装で、xが要素数よりも多い場合-1を返す実装にしている
// - 32bit/64bit書き換えは、BIT_SIZE,BLOCK_TYPE,popcount,整数リテラルのサフィックスなどを書き換えること
class FID {
  static const int BIT_SIZE = 64;
  using BLOCK_TYPE = uint64_t;
  int size;
  int block_size;
  std::vector<BLOCK_TYPE> blocks;
  std::vector<int> s_rank;
public:
  //for BIT_SIZE == 32
  /*
  BLOCK_TYPE popcount(BLOCK_TYPE x){
    x = ((x & 0xaaaaaaaa) >> 1) + (x & 0x55555555);
    x = ((x & 0xcccccccc) >> 2) + (x & 0x33333333);
    x = ((x & 0xf0f0f0f0) >> 4) + (x & 0x0f0f0f0f);
    x = ((x & 0xff00ff00) >> 8) + (x & 0x00ff00ff);
    x = ((x & 0xffff0000) >> 16) + (x & 0x0000ffff);
    return x;
  }
   */
  //__builtin_popcount()
  
  //for BIT_SIZE == 64
  BLOCK_TYPE popcount(BLOCK_TYPE x){
    x = ((x & 0xaaaaaaaaaaaaaaaaULL) >> 1) + (x & 0x5555555555555555ULL);
    x = ((x & 0xccccccccccccccccULL) >> 2) + (x & 0x3333333333333333ULL);
    x = ((x & 0xf0f0f0f0f0f0f0f0ULL) >> 4) + (x & 0x0f0f0f0f0f0f0f0fULL);
    x = ((x & 0xff00ff00ff00ff00ULL) >> 8) + (x & 0x00ff00ff00ff00ffULL);
    x = ((x & 0xffff0000ffff0000ULL) >> 16) + (x & 0x0000ffff0000ffffULL);
    x = ((x & 0xffffffff00000000ULL) >> 32) + (x & 0x00000000ffffffffULL);
    return x;
  }
  //__builtin_popcountll()
public:
  FID(int size):
  size(size),
  block_size(((size + BIT_SIZE - 1) / BIT_SIZE) + 1),
  blocks(block_size, 0),
  s_rank(block_size, 0){}
  
  void set(int i){
    blocks[i/BIT_SIZE] |= 1ULL << (i%BIT_SIZE);
  }
  
  void finalize(){
    s_rank[0] = 0;
    for(int i=1; i<block_size; i++){
      s_rank[i] = s_rank[i-1] + popcount(blocks[i-1]);
    }
  }
  
  bool access(int i){
    return (blocks[i/BIT_SIZE] >> (i%BIT_SIZE)) & 1ULL;
  }

  //iより前のビットが立っている個数
  int rank(int i){
    BLOCK_TYPE mask = (1ULL << (i%BIT_SIZE)) - 1;
    return s_rank[i/BIT_SIZE] + popcount(mask & blocks[i/BIT_SIZE]);
  }

  //x番目にビットが立っている位置
  int select(int x){
    if(rank((block_size-1) * BIT_SIZE) <= x) return -1; //注意
    int lb = 0, ub = block_size-1;
    while(ub-lb>1){
      int m = (lb+ub)/2;
      if(s_rank[m]<=x) lb = m;
      else ub = m;
    }
    int lbb = lb*BIT_SIZE, ubb = (lb+1)*BIT_SIZE;
    while(ubb-lbb>1){
      int m = (lbb+ubb)/2;
      if(rank(m)<=x) lbb = m;
      else ubb = m;
    }
    return lbb;
  }
};

//ウェーブレット行列(Wavelet Matrix)
// 【使いまわすときの注意】
// - 全部set()したら、最後にfinalize()を呼ぶこと
// - 32bit/64bit書き換えは、BIT_SIZE,VAL_TYPEなどを書き換えること
class WaveletMatrix {
  static const int BIT_SIZE = 8;
  using VAL_TYPE = uint8_t;
  int size;
  std::vector<VAL_TYPE> v;
  std::vector<FID> matrix;
  std::vector<int> sep;

  struct mytuple {
    int b, s, e;
    mytuple(int b, int s, int e):b(b),s(s),e(e){}
    bool operator<(const mytuple& x) const {
      return e-s < x.e-x.s;
    }
  };
public:
  WaveletMatrix(int size):
  size(size),
  v(size, 0),
  matrix(BIT_SIZE, FID(size)),
  sep(BIT_SIZE, 0){}

  void set(int i, VAL_TYPE val){
    v[i] = val;
  }

  void finalize(){
    std::vector<VAL_TYPE> w(v.size(), 0);
    for(int b=BIT_SIZE-1; b>=0; b--){
      for(int i=0; i<size; i++){
        if((v[i] >> b) & 1ULL) matrix[b].set(i);
        else sep[b]++;
      }
      int b1=0, b2=sep[b];
      for(int i=0; i<size; i++){
        if((v[i] >> b) & 1ULL) w[b2++] = v[i];
        else w[b1++] = v[i];
      }
      for(int i=0; i<size; i++){
        v[i] = w[i];
      }
      matrix[b].finalize();
    }
  }

  //元の配列のi番目の要素
  VAL_TYPE access(int i){
    VAL_TYPE ret = 0;
    for(int b=BIT_SIZE-1; b>=0; b--){
      if(matrix[b].access(i)){
        i = sep[b] + matrix[b].rank(i);
        ret = (ret << 1) + 1ULL;
      }else{
        i = i - matrix[b].rank(i);
        ret = (ret << 1);
      }
    }
    return ret;
  }

  //[0,i)の範囲にxが何個存在するか
  int rank(int i, VAL_TYPE x){
    int lb = 0, ub = i;
    for(int b=BIT_SIZE-1; b>=0; b--){
      if((x >> b) & 1ULL){
        lb = matrix[b].rank(lb);
        ub = matrix[b].rank(ub);
        lb += sep[b];
        ub += sep[b];
      }else{
        lb = lb - matrix[b].rank(lb);
        ub = ub - matrix[b].rank(ub);                
      }
    }
    return ub - lb;
  }

  //i番目(0-index)のxが出現する位置
  int select(int i, VAL_TYPE x){
    int lb = 0, ub = size;
    while(ub-lb>1){
      int m = (lb+ub)/2;
      if(rank(m, x)<=i) lb = m;
      else ub = m;
    }
    return lb;
  }

  //[s,e)の範囲を切り出してソートしたときのn番目(0-index)の要素
  VAL_TYPE quantile(int s, int e, int n){
    for(int b=BIT_SIZE-1; b>=0; b--){
      int zn = (e - s) - (matrix[b].rank(e) - matrix[b].rank(s));
      if(zn <= n){
        s = matrix[b].rank(s);
        e = matrix[b].rank(e);
        s += sep[b];
        e += sep[b];
        n = n - zn;
      }else{
        s = s - matrix[b].rank(s);
        e = e - matrix[b].rank(e);                
      }      
    }
    return v[s];
  }

  //[s,e)の範囲で出現回数が多い数値順に、その数値と出現回数のTop-K
  std::vector<std::pair<VAL_TYPE,int>> top_k(int s, int e, int k){
    std::vector<std::pair<VAL_TYPE,int>> ret;
    std::priority_queue<mytuple> que;
    que.push(mytuple(BIT_SIZE-1,s,e));
    while(!que.empty()){
      mytuple q = que.top(); que.pop();
      int b = q.b, st = q.s, en = q.e;
      if(b < 0){
        ret.push_back(std::make_pair(v[st], en-st));
        if((int)ret.size() >= k) break;
      }else{
        int os = matrix[b].rank(st) + sep[b];
        int oe = matrix[b].rank(en) + sep[b];
        int zs = st - matrix[b].rank(st);
        int ze = en - matrix[b].rank(en);
        if(ze-zs > 0) que.push(mytuple(b-1,zs,ze));
        if(oe-os > 0) que.push(mytuple(b-1,os,oe));
      }
    }
    return ret;
  }

  //[s,e)の範囲でx<=c<yを満たすような数値cの合計出現数
  int rangefreq(int s, int e, VAL_TYPE x, VAL_TYPE y){
    int ret = 0;
    std::queue<std::pair<mytuple,VAL_TYPE>> que;
    que.push(std::make_pair(mytuple(BIT_SIZE-1,s,e),0));
    while(!que.empty()){
      std::pair<mytuple,VAL_TYPE> q = que.front(); que.pop();
      int b = q.first.b, st = q.first.s, en = q.first.e;
      VAL_TYPE mn = q.second;
      VAL_TYPE mx = q.second | ((b>=0)?0:((-1ULL) >> (BIT_SIZE - 1 - b)));
      if(x <= mn && mx < y){
        ret += en-st;
      }
      else if(mx < x || y <= mn){
        continue;
      }
      else {
        if(b < 0) continue;
        int os = matrix[b].rank(st) + sep[b];
        int oe = matrix[b].rank(en) + sep[b];
        int zs = st - matrix[b].rank(st);
        int ze = en - matrix[b].rank(en);
        if(ze-zs > 0) que.push(std::make_pair(mytuple(b-1,zs,ze), q.second));
        if(oe-os > 0) que.push(std::make_pair(mytuple(b-1,os,oe), q.second | (1ULL << b)));
      }
    }
    return ret;
  }
};

//XBW
class XBW {
  using VAL_TYPE = uint8_t;
  const char LAST_CHAR = (char)(0xff);
  
  struct Trie {
    bool flg;
    std::string rpp;
    std::map<char,Trie> next;
    Trie(){ flg = false; }
    void insert(const std::string &str){
      Trie *r = this;
      for(size_t i=0; i<str.length(); i++){
        r = &(r->next[str[i]]);
      }
      r->flg = true;
    }
    bool find(const std::string &str){
      Trie *r = this;
      for(size_t i=0; i<str.length(); i++){
        if(r->next.count(str[i]) == 0) return false;
        r = &(r->next[str[i]]);
      }
      return r->flg;
    }
  };
  struct ST {
    std::string children;
    std::string rpp;
    ST(std::string children, std::string rpp):children(children),rpp(rpp){}
    bool operator<(const ST& x) const { return rpp < x.rpp; }
  };

  Trie root;
  int xbw_size;
  std::string xbw_str;

  WaveletMatrix wm;
  FID fid;
  std::map<char,int> C;

  void build(std::vector<ST>& v){
    //XBWのサイズと文字の出現数のカウント
    std::map<char,int> cnt;
    xbw_size = 0;
    for(size_t i=0; i<v.size(); i++){
      xbw_size += v[i].children.length();
      for(size_t j=0; j<v[i].children.length(); j++){
        cnt[v[i].children[j]]++;
      }
    }

    //構築
    wm = WaveletMatrix(xbw_size);
    fid = FID(xbw_size);
    int idx = 0;
    for(size_t i=0; i<v.size(); i++){
      fid.set(idx);
      for(size_t j=0; j<v[i].children.length(); j++){
        wm.set(idx, (VAL_TYPE)(v[i].children[j]));
        idx++;
      }
    }
    wm.finalize();
    fid.finalize();

    C[(char)(0)] = 1;
    for(int i=1; i<256; i++){
      C[(char)(i)] = C[(char)(i-1)] + cnt[(char)(i-1)];
    }

    //trieの削除
    //root.next.clear();
  }

  int rank(int i, VAL_TYPE x){
    int pos = fid.select(i);
    return wm.rank(((pos<0)?xbw_size:pos),x);
  }  
public:
  XBW():wm(1),fid(1){}

  void add(const std::string& key){
    root.insert(key);
  }

  void finalize(){
    //chilren, reverse prefix pathの表の作成
    std::vector<ST> v;
    std::queue<Trie*> que;
    que.push(&root);
    while(!que.empty()){
      Trie *r = que.front(); que.pop();
      std::string children;
      std::string rpp = r->rpp;
      for(std::map<char,Trie>::iterator it=(r->next).begin(); it!=(r->next).end(); ++it){
        children += it->first;
        (it->second).rpp = rpp + it->first;
        que.push(&(it->second));
      }
      if(r->flg){
        children += LAST_CHAR;
      }
      std::reverse(rpp.begin(), rpp.end());
      v.push_back(ST(children, rpp));
    }

    std::sort(v.begin(), v.end());

    //XBW文字列の作成
    for(size_t i=0; i<v.size(); i++){
      if(i>0) xbw_str += "|";
      for(size_t j=0; j<v[i].children.length(); j++){
        if(v[i].children[j] == LAST_CHAR) xbw_str += "__LAST__"; //表示の都合上
        else xbw_str += v[i].children[j];
      }
    }

    build(v);
  }

  std::string get_xbw_string(){
    return xbw_str;
  }
  
  bool trie_find(const std::string& key){
    return root.find(key);
  }
  
  bool find(const std::string& key){
    int r = 0;
    for(size_t i=0; i<key.length(); i++){
      if(rank(r+1,(VAL_TYPE)(key[i])) - rank(r,(VAL_TYPE)(key[i])) == 0) return false;
      r = C[key[i]] + rank(r,(VAL_TYPE)(key[i]));
    }
    if(rank(r+1,(VAL_TYPE)(LAST_CHAR)) - rank(r,(VAL_TYPE)(LAST_CHAR)) == 0) return false;
    return true;
  }
};

int main(){
  std::mt19937 rnd{ std::random_device()() };
  std::map<std::string,int> keys, no_keys;
  int max_size = 500000; //keyとno_keysの要素数
  int max_len = 40; //文字列の長さの最大値
  
  int turn = 0;
  while(keys.size() < max_size || no_keys.size() < max_size){
    //generate key string

    int len = rnd() % max_len + 1;
    std::string key = "";
    for(int j=0; j<len; j++){
      key += (char)(' ' + rnd()%95);
    }

    if(keys.count(key) != 0 || no_keys.count(key) != 0) continue;
    
    if(turn == 0 && keys.size() < max_size){
      keys[key] = 1;
      turn = 1 - turn;
    }
    else if(turn == 1 && no_keys.size() < max_size){
      no_keys[key] = 1;
      turn = 1 - turn;
    }
  }
  std::cout << "key generated..." << std::endl;
  
  XBW xbw;
  for(const auto& x : keys){
    xbw.add(x.first);
  }
  xbw.finalize();
  std::cout << "XBW built..." << std::endl;
  
  //std::cout << "XBW = " << xbw.get_xbw_string() << std::endl;
  
  bool error = false;
  for(const auto& x : keys){
    if(xbw.trie_find(x.first) != xbw.find(x.first)){
      std::cout << "error : " << x.first << std::endl;
      error = true;
    }
  }
  for(const auto& x : no_keys){
    if(xbw.trie_find(x.first) != xbw.find(x.first)){
      std::cout << "error : " << x.first << std::endl;
      error = true;
    }
  }
  if(!error) std::cout << "no error" << std::endl;
  
  return 0;
}


確認のために、解説ページのTrieからXBWを出力してみる。
main()の内容を以下のように変更すると確認できる。

int main(){
  std::vector<std::string> v{"to","tea","ten","i","in","inn","we"};
  XBW xbw;
  for(const auto& x : v){
    xbw.add(x);
  }
  xbw.finalize();
  std::cout << xbw.get_xbw_string() << std::endl;
  return 0;
}

結果。

itw|__LAST__|an|__LAST__|n__LAST__|__LAST__|n__LAST__|__LAST__|__LAST__|eo|e