XBWを試す
はじめに
XBWをWaveletMatrixを使って、試しに実装してみた。
XBWとは
- 効率よくTrie木を表現する方法
- Burrows-Wheeler Transform(BWT)の(木への)拡張
- 詳しい解説や作り方は以下のページや「高速文字列解析の世界」などを参照
コード
XBWでの表のソートはそのまま文字列同士のソートをしている。
チェックは、コード中にあるように、適当にkey文字列とkey文字列じゃないのを生成してtrieと結果が一緒になるかだけ。
WaveletMatrixの方で理解のためにいくつか関数を書いているけど、rank()ぐらいしか使っていないので、それ以外はあまりVerifyできていない。
(g++はバージョン「5.4.0」、オプション「-std=gnu++1y -O2」で実行してる)
#include <vector> #include <map> #include <cstdint> #include <algorithm> #include <iostream> #include <queue> #include <random> //完備辞書(Fully Indexable Dictionary) // 【使いまわすときの注意】 // - 全部set()したら、最後にfinalize()を呼ぶこと // - select(x)の実装で、xが要素数よりも多い場合-1を返す実装にしている // - 32bit/64bit書き換えは、BIT_SIZE,BLOCK_TYPE,popcount,整数リテラルのサフィックスなどを書き換えること class FID { static const int BIT_SIZE = 64; using BLOCK_TYPE = uint64_t; int size; int block_size; std::vector<BLOCK_TYPE> blocks; std::vector<int> s_rank; public: //for BIT_SIZE == 32 /* BLOCK_TYPE popcount(BLOCK_TYPE x){ x = ((x & 0xaaaaaaaa) >> 1) + (x & 0x55555555); x = ((x & 0xcccccccc) >> 2) + (x & 0x33333333); x = ((x & 0xf0f0f0f0) >> 4) + (x & 0x0f0f0f0f); x = ((x & 0xff00ff00) >> 8) + (x & 0x00ff00ff); x = ((x & 0xffff0000) >> 16) + (x & 0x0000ffff); return x; } */ //__builtin_popcount() //for BIT_SIZE == 64 BLOCK_TYPE popcount(BLOCK_TYPE x){ x = ((x & 0xaaaaaaaaaaaaaaaaULL) >> 1) + (x & 0x5555555555555555ULL); x = ((x & 0xccccccccccccccccULL) >> 2) + (x & 0x3333333333333333ULL); x = ((x & 0xf0f0f0f0f0f0f0f0ULL) >> 4) + (x & 0x0f0f0f0f0f0f0f0fULL); x = ((x & 0xff00ff00ff00ff00ULL) >> 8) + (x & 0x00ff00ff00ff00ffULL); x = ((x & 0xffff0000ffff0000ULL) >> 16) + (x & 0x0000ffff0000ffffULL); x = ((x & 0xffffffff00000000ULL) >> 32) + (x & 0x00000000ffffffffULL); return x; } //__builtin_popcountll() public: FID(int size): size(size), block_size(((size + BIT_SIZE - 1) / BIT_SIZE) + 1), blocks(block_size, 0), s_rank(block_size, 0){} void set(int i){ blocks[i/BIT_SIZE] |= 1ULL << (i%BIT_SIZE); } void finalize(){ s_rank[0] = 0; for(int i=1; i<block_size; i++){ s_rank[i] = s_rank[i-1] + popcount(blocks[i-1]); } } bool access(int i){ return (blocks[i/BIT_SIZE] >> (i%BIT_SIZE)) & 1ULL; } //iより前のビットが立っている個数 int rank(int i){ BLOCK_TYPE mask = (1ULL << (i%BIT_SIZE)) - 1; return s_rank[i/BIT_SIZE] + popcount(mask & blocks[i/BIT_SIZE]); } //x番目にビットが立っている位置 int select(int x){ if(rank((block_size-1) * BIT_SIZE) <= x) return -1; //注意 int lb = 0, ub = block_size-1; while(ub-lb>1){ int m = (lb+ub)/2; if(s_rank[m]<=x) lb = m; else ub = m; } int lbb = lb*BIT_SIZE, ubb = (lb+1)*BIT_SIZE; while(ubb-lbb>1){ int m = (lbb+ubb)/2; if(rank(m)<=x) lbb = m; else ubb = m; } return lbb; } }; //ウェーブレット行列(Wavelet Matrix) // 【使いまわすときの注意】 // - 全部set()したら、最後にfinalize()を呼ぶこと // - 32bit/64bit書き換えは、BIT_SIZE,VAL_TYPEなどを書き換えること class WaveletMatrix { static const int BIT_SIZE = 8; using VAL_TYPE = uint8_t; int size; std::vector<VAL_TYPE> v; std::vector<FID> matrix; std::vector<int> sep; struct mytuple { int b, s, e; mytuple(int b, int s, int e):b(b),s(s),e(e){} bool operator<(const mytuple& x) const { return e-s < x.e-x.s; } }; public: WaveletMatrix(int size): size(size), v(size, 0), matrix(BIT_SIZE, FID(size)), sep(BIT_SIZE, 0){} void set(int i, VAL_TYPE val){ v[i] = val; } void finalize(){ std::vector<VAL_TYPE> w(v.size(), 0); for(int b=BIT_SIZE-1; b>=0; b--){ for(int i=0; i<size; i++){ if((v[i] >> b) & 1ULL) matrix[b].set(i); else sep[b]++; } int b1=0, b2=sep[b]; for(int i=0; i<size; i++){ if((v[i] >> b) & 1ULL) w[b2++] = v[i]; else w[b1++] = v[i]; } for(int i=0; i<size; i++){ v[i] = w[i]; } matrix[b].finalize(); } } //元の配列のi番目の要素 VAL_TYPE access(int i){ VAL_TYPE ret = 0; for(int b=BIT_SIZE-1; b>=0; b--){ if(matrix[b].access(i)){ i = sep[b] + matrix[b].rank(i); ret = (ret << 1) + 1ULL; }else{ i = i - matrix[b].rank(i); ret = (ret << 1); } } return ret; } //[0,i)の範囲にxが何個存在するか int rank(int i, VAL_TYPE x){ int lb = 0, ub = i; for(int b=BIT_SIZE-1; b>=0; b--){ if((x >> b) & 1ULL){ lb = matrix[b].rank(lb); ub = matrix[b].rank(ub); lb += sep[b]; ub += sep[b]; }else{ lb = lb - matrix[b].rank(lb); ub = ub - matrix[b].rank(ub); } } return ub - lb; } //i番目(0-index)のxが出現する位置 int select(int i, VAL_TYPE x){ int lb = 0, ub = size; while(ub-lb>1){ int m = (lb+ub)/2; if(rank(m, x)<=i) lb = m; else ub = m; } return lb; } //[s,e)の範囲を切り出してソートしたときのn番目(0-index)の要素 VAL_TYPE quantile(int s, int e, int n){ for(int b=BIT_SIZE-1; b>=0; b--){ int zn = (e - s) - (matrix[b].rank(e) - matrix[b].rank(s)); if(zn <= n){ s = matrix[b].rank(s); e = matrix[b].rank(e); s += sep[b]; e += sep[b]; n = n - zn; }else{ s = s - matrix[b].rank(s); e = e - matrix[b].rank(e); } } return v[s]; } //[s,e)の範囲で出現回数が多い数値順に、その数値と出現回数のTop-K std::vector<std::pair<VAL_TYPE,int>> top_k(int s, int e, int k){ std::vector<std::pair<VAL_TYPE,int>> ret; std::priority_queue<mytuple> que; que.push(mytuple(BIT_SIZE-1,s,e)); while(!que.empty()){ mytuple q = que.top(); que.pop(); int b = q.b, st = q.s, en = q.e; if(b < 0){ ret.push_back(std::make_pair(v[st], en-st)); if((int)ret.size() >= k) break; }else{ int os = matrix[b].rank(st) + sep[b]; int oe = matrix[b].rank(en) + sep[b]; int zs = st - matrix[b].rank(st); int ze = en - matrix[b].rank(en); if(ze-zs > 0) que.push(mytuple(b-1,zs,ze)); if(oe-os > 0) que.push(mytuple(b-1,os,oe)); } } return ret; } //[s,e)の範囲でx<=c<yを満たすような数値cの合計出現数 int rangefreq(int s, int e, VAL_TYPE x, VAL_TYPE y){ int ret = 0; std::queue<std::pair<mytuple,VAL_TYPE>> que; que.push(std::make_pair(mytuple(BIT_SIZE-1,s,e),0)); while(!que.empty()){ std::pair<mytuple,VAL_TYPE> q = que.front(); que.pop(); int b = q.first.b, st = q.first.s, en = q.first.e; VAL_TYPE mn = q.second; VAL_TYPE mx = q.second | ((b>=0)?0:((-1ULL) >> (BIT_SIZE - 1 - b))); if(x <= mn && mx < y){ ret += en-st; } else if(mx < x || y <= mn){ continue; } else { if(b < 0) continue; int os = matrix[b].rank(st) + sep[b]; int oe = matrix[b].rank(en) + sep[b]; int zs = st - matrix[b].rank(st); int ze = en - matrix[b].rank(en); if(ze-zs > 0) que.push(std::make_pair(mytuple(b-1,zs,ze), q.second)); if(oe-os > 0) que.push(std::make_pair(mytuple(b-1,os,oe), q.second | (1ULL << b))); } } return ret; } }; //XBW class XBW { using VAL_TYPE = uint8_t; const char LAST_CHAR = (char)(0xff); struct Trie { bool flg; std::string rpp; std::map<char,Trie> next; Trie(){ flg = false; } void insert(const std::string &str){ Trie *r = this; for(size_t i=0; i<str.length(); i++){ r = &(r->next[str[i]]); } r->flg = true; } bool find(const std::string &str){ Trie *r = this; for(size_t i=0; i<str.length(); i++){ if(r->next.count(str[i]) == 0) return false; r = &(r->next[str[i]]); } return r->flg; } }; struct ST { std::string children; std::string rpp; ST(std::string children, std::string rpp):children(children),rpp(rpp){} bool operator<(const ST& x) const { return rpp < x.rpp; } }; Trie root; int xbw_size; std::string xbw_str; WaveletMatrix wm; FID fid; std::map<char,int> C; void build(std::vector<ST>& v){ //XBWのサイズと文字の出現数のカウント std::map<char,int> cnt; xbw_size = 0; for(size_t i=0; i<v.size(); i++){ xbw_size += v[i].children.length(); for(size_t j=0; j<v[i].children.length(); j++){ cnt[v[i].children[j]]++; } } //構築 wm = WaveletMatrix(xbw_size); fid = FID(xbw_size); int idx = 0; for(size_t i=0; i<v.size(); i++){ fid.set(idx); for(size_t j=0; j<v[i].children.length(); j++){ wm.set(idx, (VAL_TYPE)(v[i].children[j])); idx++; } } wm.finalize(); fid.finalize(); C[(char)(0)] = 1; for(int i=1; i<256; i++){ C[(char)(i)] = C[(char)(i-1)] + cnt[(char)(i-1)]; } //trieの削除 //root.next.clear(); } int rank(int i, VAL_TYPE x){ int pos = fid.select(i); return wm.rank(((pos<0)?xbw_size:pos),x); } public: XBW():wm(1),fid(1){} void add(const std::string& key){ root.insert(key); } void finalize(){ //chilren, reverse prefix pathの表の作成 std::vector<ST> v; std::queue<Trie*> que; que.push(&root); while(!que.empty()){ Trie *r = que.front(); que.pop(); std::string children; std::string rpp = r->rpp; for(std::map<char,Trie>::iterator it=(r->next).begin(); it!=(r->next).end(); ++it){ children += it->first; (it->second).rpp = rpp + it->first; que.push(&(it->second)); } if(r->flg){ children += LAST_CHAR; } std::reverse(rpp.begin(), rpp.end()); v.push_back(ST(children, rpp)); } std::sort(v.begin(), v.end()); //XBW文字列の作成 for(size_t i=0; i<v.size(); i++){ if(i>0) xbw_str += "|"; for(size_t j=0; j<v[i].children.length(); j++){ if(v[i].children[j] == LAST_CHAR) xbw_str += "__LAST__"; //表示の都合上 else xbw_str += v[i].children[j]; } } build(v); } std::string get_xbw_string(){ return xbw_str; } bool trie_find(const std::string& key){ return root.find(key); } bool find(const std::string& key){ int r = 0; for(size_t i=0; i<key.length(); i++){ if(rank(r+1,(VAL_TYPE)(key[i])) - rank(r,(VAL_TYPE)(key[i])) == 0) return false; r = C[key[i]] + rank(r,(VAL_TYPE)(key[i])); } if(rank(r+1,(VAL_TYPE)(LAST_CHAR)) - rank(r,(VAL_TYPE)(LAST_CHAR)) == 0) return false; return true; } }; int main(){ std::mt19937 rnd{ std::random_device()() }; std::map<std::string,int> keys, no_keys; int max_size = 500000; //keyとno_keysの要素数 int max_len = 40; //文字列の長さの最大値 int turn = 0; while(keys.size() < max_size || no_keys.size() < max_size){ //generate key string int len = rnd() % max_len + 1; std::string key = ""; for(int j=0; j<len; j++){ key += (char)(' ' + rnd()%95); } if(keys.count(key) != 0 || no_keys.count(key) != 0) continue; if(turn == 0 && keys.size() < max_size){ keys[key] = 1; turn = 1 - turn; } else if(turn == 1 && no_keys.size() < max_size){ no_keys[key] = 1; turn = 1 - turn; } } std::cout << "key generated..." << std::endl; XBW xbw; for(const auto& x : keys){ xbw.add(x.first); } xbw.finalize(); std::cout << "XBW built..." << std::endl; //std::cout << "XBW = " << xbw.get_xbw_string() << std::endl; bool error = false; for(const auto& x : keys){ if(xbw.trie_find(x.first) != xbw.find(x.first)){ std::cout << "error : " << x.first << std::endl; error = true; } } for(const auto& x : no_keys){ if(xbw.trie_find(x.first) != xbw.find(x.first)){ std::cout << "error : " << x.first << std::endl; error = true; } } if(!error) std::cout << "no error" << std::endl; return 0; }
確認のために、解説ページのTrieからXBWを出力してみる。
main()の内容を以下のように変更すると確認できる。
int main(){ std::vector<std::string> v{"to","tea","ten","i","in","inn","we"}; XBW xbw; for(const auto& x : v){ xbw.add(x); } xbw.finalize(); std::cout << xbw.get_xbw_string() << std::endl; return 0; }
結果。
itw|__LAST__|an|__LAST__|n__LAST__|__LAST__|n__LAST__|__LAST__|__LAST__|eo|e