RangeCoderを試す

はじめに

RangeCoderのメモ。
中途半端な理解で適当に書き写そうとしたらひどいことになったので、まとめておく。。。

RangeCoderとは

以下、桁上りありの場合について下記サイトを参考に作成してみた。
http://www.geocities.jp/m_hiroi/light/pyalgo36.html

他にも「桁上りなし」や「適応型」などがあるよう。
http://www.geocities.jp/m_hiroi/light/pyalgo37.html

アルゴリズム

動作原理

  • ある範囲に対し、シンボルの出現確率で区分けして、該当するシンボルの範囲を次の範囲として、シンボルを割り当て続ける
  • 最終的に、範囲の左端の数値さえわかれば、同じ処理でシンボルを復元できる
範囲の拡大
  • かなり大きい範囲を用意しておかないとすぐに範囲が狭まってしまってシンボルの出現確率で区分けするに十分な整数がなくなってしまう
  • そこで、ある程度の大きさの範囲からスタートし、ある程度の大きさ以下になったら全体をx倍するような処理を繰り返すことで、これに対処する


実装方法
  • 上記は多倍長などを用いなくても、普通の整数型を使って、都度、上位byteを出力してしまって、範囲を表す変数を常にある程度の大きさに収まるようにすることで実装できる
  • 範囲の最大値max_range、拡大する範囲の閾値min_range、範囲の左端をlow、範囲の幅をrangeとすると、上記で言っているのは「lowをmax_range以下に保つように実装する」ということ
  • ここでは、以下のように定める
    • max_rangeのバイト数をrange_byteとする
    • min_rangeのバイト数はrange_byte-1
    • 拡大率は0x100(256)倍
拡大処理に伴う問題
  • 拡大処理を行うときに上位byteの出力に伴う問題がいくつかある


  • 図で、low=0x123456のときrangeが小さくなりすぎて拡大したとする(lowとrangeをそれぞれ0x100倍)
  • lowは0x12345600となるが、max_range以下に収まるようにしたいので、上位byteの0x12を出力してしまって、下位の0x345600を次のlowとする
  • 問題は「出力済みの桁の繰り上げ(桁上り)が発生しうる」こと
桁上り
  • 図で、単純に上位byteの0x12を出力してしまうと、次の範囲処理でmax_range=0x1000000を超えた場合、実は出力すべきは0x13だった、、、ということが起きうる
  • また、拡大処理した際に出力byteが0xffだとさらに出力をさかのぼって桁上りを処理しなければならない
    • 出力済みが「0x12, 0xff, 0xff, 0xff」で桁上りが発生したら「0x13,0x00,0x00,0x00」にしなければならない
  • これの解決方法の一つとして、実際に出力はせずに「バッファにためる」ことで確定するまで出力を保留する方法がある
    • 出力候補の先頭byteをbuff、後続の0xffの個数をcntとする
  • 一つ目の「次の範囲処理max_rangeを超える場合」は、超えた場合にはrangeがmax_range以下であるため、桁上り分は+1しかならないので、buff++をして、lowはmax_range以下の部分だけにマスクすればよい
  • このとき、cntが1個以上の場合は「0x12, 0xff, ..., 0xff」を+1すると「0x13, 0x00, ..., 0x00」のような状態に変化するので、最後の「0x00」以外は出力してしまって問題ない
  • 二つ目の「拡大処理した際」は、出力候補のlowの上位bit部分が0xffなら桁上りがまだあり得るので出力はせずにcnt+1だけ、0xff未満ならbuffとcnt個分の0xffを出力してbuffにその値をいれてあげればよい
終了処理
  • 残っているbuff,cnt,lowをすべて出力して終了
シンボルの出現頻度テーブル
  • デコード処理ではシンボルの出現確率で区分けするために出現頻度テーブルも保存しておく
  • 注意として、出現頻度の合計値はmin_range以下である必要がある
    • そうでないとrangeを区分に分割するときに整数に割り当てられない場合ができてしまう
  • 実装上は、頻度値をshort型などで持たせることで、頻度合計値をmin_range以下にしたり、テーブル保存時の容量削減をすると良いよう
デコード処理
  • 出現頻度テーブルと出力した数値を頭から読み込んでいけばよい
  • エンコード処理と同様に、rangeが小さくなりすぎたら拡大しながら読み込んでいく

コード

いくつかのデータで復元できているので、おそらく大丈夫。

#include <iostream>
#include <vector>

class RangeCoder {
  const int64_t range_byte = 6; //max_rangeのバイト数
  const int size = 0x10000; //出現するデータの種類数の最大値(入力シンボルをuint16_tにしているため)
  
  const int64_t max_range = 0x1LL << (8 * range_byte); //rangeの最大幅
  const int64_t min_range = 0x1LL << (8 * (range_byte-1)); //rangeを拡大する幅の閾値
  const int64_t mask = max_range - 1; //max_range分のマスク(0xfff...fff)
  const int64_t shift = 8 * (range_byte-1); //出力分計算用(バッファ処理用「0xff00...00」生成と1byteずつの出力するため用)

  std::vector<int32_t> count, count_sum; //頻度分布、頻度累積分布
  int32_t sum; //頻度合計値

  int32_t orig_size; //入力シンボル数保存用
  int pos; // 出力位置保存用

  void init(){
    for(int i=0; i<size; i++){
      count[i] = 0;
      count_sum[i] = 0;
    }
    pos = 0;
  }

  //頻度分布、頻度累積分布の作成
  void make_count(const std::vector<uint16_t>& in){
    init();

    for(uint16_t ch : in){
      int chi = ch & 0xffff;
      count[chi]++;
    }

    //頻度値を16bitに抑えるため、頻度値全体を1/2^n倍する
    int32_t max_count = 0;
    for(int i=0; i<size; i++){
      max_count = std::max(max_count, count[i]);
    }
    if(max_count > 0xffff){
      int n = 0;
      while(max_count > 0xffff){
        max_count >>= 1;
        n++;
      }
      for(int i=0; i<size; i++){
        if(count[i] > 0){
          count[i] >>= n;
          count[i] |= 1;
        }
      }
    }
    //頻度合計値はmin_range以下にしなければいけないので、
    //抑えるための処理(1/2を繰り返す)
    //注意: 無限ループに入りうる
    while(true){
      int32_t c_sum = 0;
      for(int i=0; i<size; i++){
        c_sum += count[i];
      }
      if(c_sum < min_range) break;
      for(int i=0; i<size; i++){
        if(count[i] > 0){
          count[i] >>= 1;
          if(count[i] == 0) count[i] = 1;
        }
      }
    }

    //頻度累積分布の作成
    sum = 0;
    for(int i=0; i<size; i++){
      count_sum[i+1] = sum + count[i];
      sum += count[i];
    }
  }

  //1byte分出力
  void putc(std::vector<uint8_t>& ret, uint8_t c){
    ret.push_back(c);
  }

  //1byte分読込
  int getc(const std::vector<uint8_t>& in){
    if(pos >= in.size()) return 0;
    return in[pos++] & 0xff;
  }

  //ヘッダ情報の出力
  void save_header(std::vector<uint8_t>& ret, int32_t orig_size){
    //シンボル数
    putc(ret, (orig_size>>24) & 0xff);
    putc(ret, (orig_size>>16) & 0xff);
    putc(ret, (orig_size>>8) & 0xff);
    putc(ret, orig_size & 0xff);

    //頻度分布
    int32_t num = 0; //シンボルのユニーク数
    for(int i=0; i<size; i++) if(count[i] > 0) num++;
    putc(ret, (num>>24) & 0xff);
    putc(ret, (num>>16) & 0xff);
    putc(ret, (num>>8) & 0xff);
    putc(ret, num & 0xff);

    for(int i=0; i<size; i++){
      if(count[i] > 0){
        //シンボル番号
        putc(ret, (i>>8) & 0xff);
        putc(ret, i & 0xff);
        //シンボルの出現頻度
        putc(ret, (count[i]>>8) & 0xff);
        putc(ret, count[i] & 0xff);
      }
    }
  }

  //ヘッダ情報の読込
  void load_header(const std::vector<uint8_t>& in){
    init();

    //シンボル数
    orig_size = getc(in); orig_size <<= 8;
    orig_size |= getc(in); orig_size <<= 8;
    orig_size |= getc(in); orig_size <<= 8;
    orig_size |= getc(in);

    //頻度分布
    int32_t num = getc(in); num <<= 8;
    num |= getc(in); num <<= 8;
    num |= getc(in); num <<= 8;
    num |= getc(in);

    for(int i=0; i<num; i++){
      //シンボル番号
      int32_t id = getc(in); id <<= 8;
      id |= getc(in);
      //シンボルの出現頻度
      int32_t cnt = getc(in); cnt <<= 8;
      cnt |= getc(in);

      count[id] = cnt;
    }

    //頻度累積分布の作成
    sum = 0;
    for(int i=0; i<size; i++){
      count_sum[i+1] = sum + count[i];
      sum += count[i];
    }
  }

  //デコード用範囲から該当するシンボルの探索
  int32_t search_code(int32_t val){
    int32_t i = 0;
    int32_t j = size-1;
    while(i < j){
      int32_t k = (i + j) / 2;
      if(count_sum[k+1] <= val){
        i = k + 1;
      }else{
        j = k;
      }
    }
    return i;
  }

public:
  RangeCoder():count(size+1, 0), count_sum(size+1, 0){}

  std::vector<uint8_t> encode(const std::vector<uint16_t>& in){
    std::vector<uint8_t> ret;

    make_count(in);
    save_header(ret, in.size());

    //桁上りバッファ用
    // buff: 出力候補
    // cnt: buff以降に続く出力候補0xffの数
    // => [0x12, 0xff, 0xff, ..., 0xff]のような情報
    int64_t buff = 0, cnt = 0;
    //範囲計算用
    int64_t low = 0, range = max_range; //下限と範囲

    
    for(int i=0; i<in.size(); i++){
      //入力シンボル
      int32_t ch = in[i] & 0xffff;
      //該当範囲の計算
      int64_t temp = range / sum;
      low += count_sum[ch] * temp;
      range = count[ch] * temp;

      //桁上りの処理
      //該当範囲がmax_rangeを超えてしまった場合
      if(low >= max_range){
        buff++;
        low &= mask;
        //もしcntが0より大きければ、[0x12,0xff,0xff...,0xff]のような状態なので、
        //buffが+1されたことで[0x13,0x00,0x00...,0x00]のようになり、最後の0x00以外は出力してよい
        if(cnt > 0){
          putc(ret, buff & 0xff); //buffの出力
          for(int j=0; j<cnt-1; j++) putc(ret, 0x00); //0x00をcnt-1個分出力、最後の0x00はbuffとする
          buff = 0x00;
          cnt = 0x00;
        }
      }
      //範囲が小さくなったら全体をを拡大(256倍)
      while(range < min_range){
        //拡大することでmax_rangeを超える上位1バイト分が、
        // - 0xffより小さければ上位8bitは0xffではないので、
        //   バッファを出力して、その上位8bit分をbuffにいれる(rangeは範囲内なので0xfe以下なら絶対に0xffまでしかいかない)
        // - 0xffの場合は、まだ桁上りがあるかもしれないので、0xffを増やす意味でcntを+1する
        if(low < (0xffLL << shift)){ //low < 0xff000...000
          putc(ret, buff & 0xff);
          for(int j=0; j<cnt; j++) putc(ret, 0xff);
          buff = (low >> shift) & 0xff;
          cnt = 0;
        }else{
          cnt++;
        }
        //全体を256倍に拡大する
        low = (low << 8) & mask;
        range <<= 8;
      }
    }
    //最後に残っている情報(buff, cnt, low)をすべて出力
    int32_t ch = 0xff;
    if(low >= max_range){
      buff++;
      ch = 0;
    }
    putc(ret, buff & 0xff);
    for(int j=0; j<cnt; j++) putc(ret, ch & 0xff);
    for(int j=shift; j>=0; j-=8){
      putc(ret, (low >> j) & 0xff);
    }

    return ret;
  }

  std::vector<uint16_t> decode(const std::vector<uint8_t>& in){
    std::vector<uint16_t> ret;

    load_header(in);
    getc(in);

    int64_t range = max_range;
    int64_t low = 0;
    for(int i=shift; i>=0; i-=8){
      low |= getc(in);
      if(i>0) low <<= 8;
    }

    while(ret.size() < orig_size){
      int64_t temp = range / sum;
      int32_t ch = search_code(low / temp);
      low -= temp * count_sum[ch];
      range = temp * count[ch];
      
      while(range < min_range){
        range <<= 8;
        low = ((low << 8) + getc(in)) & mask;
      }
      
      ret.push_back( ch & 0xffff );
    }
    return ret;
  }
};

int main(){
  RangeCoder encoder, decoder;

  //入力シンボル列
  std::vector<uint16_t> v;
  for(int i=0; i<10; i++){
    v.push_back(4);
    v.push_back(3);
    v.push_back(2);
    v.push_back(2);
    v.push_back(1);
    v.push_back(1);
    v.push_back(1);
    v.push_back(1);
  }

  //エンコード
  std::cout << "encoding" << std::endl;
  std::vector<uint8_t> encode_code = encoder.encode(v);

  std::cout << v.size() * 16 << " => " << encode_code.size() * 8 << std::endl;
  
  //デコード
  std::cout << "decoding" << std::endl;
  std::vector<uint16_t> decode_code = decoder.decode(encode_code);

  //結果の確認
  for(int i=0; i<decode_code.size(); i++){
    std::cout << i << "\t" << v[i] << "\t" << decode_code[i] << std::endl;
  }

  return 0;
}

結果

encoding
1280 => 384
decoding
0       4       4
1       3       3
2       2       2
3       2       2
4       1       1
5       1       1
6       1       1
7       1       1
8       4       4
...

1280bitのシンボル列を384bitに圧縮できているよう。
上記、書き換え不可能な出力ではなく、入れた後で書き換え可能な出力用vectorに入れているので、素直に桁上げしにいく処理にした方がコードは短くなりそう。