多腕バンディットとUCB1で遊ぶ

はじめに

ちょっと遊びで多腕バンディット問題で遊んでみた。

UCB1-tunedも書いてみたけどUCB1より最終的な儲けが低くてあれ?ってなった。どっか間違ってるか。。。
追記(2012/2/12):コメントをいただいて、修正しました。一応、報酬額がUCB1よりtunedの方が高くなっているので、一緒にのせてみます。

修正

コメント指摘をうけ、元論文( http://www.eecs.berkeley.edu/~pabbeel/cs287-fa09/readings/Auer+al-UCB.pdf )を確認してみました。
「K個の独立で、未知だがそれぞれ期待値がμiの一様分布に従う確率変数Xi」と定義されているみたいで、報酬期待値μ*もμiの最大と定義されているので、評価値の計算もこれを用いなければなりませんでした。
「報酬を表すi.i.d.な確率変数X_{i,t}の範囲は、UCB1の証明(Proof of Theorem1)に使っているChernoff-Hoeffding bound(Fact 1)が[0,1]としているものを使っている」のでそれにあわせなければならないのですね。


以下、「所持金」という形ではなく、「はずれれば0、当たれば+1としたときの報酬合計」を出力するように修正したコード。

コード

#include <iostream>
#include <vector>
#include <cmath>
#include <climits>
using namespace std;

//当たった場合の報酬
static const int BENEFIT = 1; //※修正

//乱数生成用
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  //return xor128()%1000000000/1000000000.0; //※念のためここも修正
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX);
}

//スロットマシン
struct bandit {
  double p; //当たる確率
  long X; //今までの報酬の合計
  long X2; //今までの報酬の2乗の合計
  int n; //これまでの総プレイ回数
  
  bandit():X(0),X2(0),n(0){
    //※修正
    while(true){
      p = frand();
      if(p<0.4) break;
    }
  }
  int play(){
    n++;
    if(p >= frand()){//当たった場合
      X += BENEFIT;
      X2 += BENEFIT * BENEFIT;
      return BENEFIT;
    }
    return 0; //外れた場合
  }
};

int main(){
  vector<bandit> bandits;
  //スロットの準備
  for(int i=0; i<10; i++){
    bandit tmp;
    bandits.push_back(tmp);
    cout << "BANDIT " << i << " -> p=" << tmp.p << endl;
  }
  cout << endl;
  
  int money = 0; //報酬額合計
  int count = 0; //ゲームカウント
  
  ////////////////////////
  //UCB1アルゴリズム
  ////////////////////////
  //1.初期化
  for(int i=0; i<bandits.size(); i++){
    money += bandits[i].play(); //プレイする
    
    //プレイ後の報酬額
    cout << count++ << " " << money << "(played:" << i << ")" << endl;
  }
  //2.繰り返し
  for(int t=bandits.size(); t<=10000; t++){    
    //すべてのbanditの中から評価値の一番高いものを選ぶ
    double eval = -1.0;
    int evali = -1;
    for(int i=0; i<bandits.size(); i++){
      double tmp = (double)(bandits[i].X)/(bandits[i].n) + sqrt(2*log((double)t)/bandits[i].n);
      if(eval < tmp){
        eval = tmp;
        evali = i;
      }
    }
    //そのマシンをプレイ
    money += bandits[evali].play();
    
    //プレイ後の報酬額
    cout << count++ << " " << money << "(played:" << evali << ")" << endl;
  }
  return 0;
}

結果

  • 各スロットの当たる確率
    • BANDIT5が一番当たる確率が高いみたい
BANDIT 0 -> p=0.106706
BANDIT 1 -> p=0.120232
BANDIT 2 -> p=0.166993
BANDIT 3 -> p=0.0320996
BANDIT 4 -> p=0.0920471
BANDIT 5 -> p=0.302981
BANDIT 6 -> p=0.276434
BANDIT 7 -> p=0.104422
BANDIT 8 -> p=0.149822
BANDIT 9 -> p=0.217559
  • 結果
    • 10000回ぐらいやった時の報酬額は2608
...
9988 2604(played:5)
9989 2604(played:5)
9990 2605(played:5)
9991 2605(played:5)
9992 2605(played:5)
9993 2605(played:5)
9994 2605(played:5)
9995 2606(played:5)
9996 2606(played:5)
9997 2606(played:5)
9998 2607(played:5)
9999 2608(played:5)
10000 2608(played:5)

コード2

  • ちゃんと情報を集めてないから自信ないけど、UCB1-tunedっぽいコード
#include <iostream>
#include <vector>
#include <cmath>
#include <climits>
using namespace std;

//当たった場合の報酬
static const int BENEFIT = 1;

//乱数生成用
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
// 注意: int_maxぐらいで割るべき
double frand(){
  return xor128()%ULONG_MAX/static_cast<double>(ULONG_MAX);
}

//スロットマシン
struct bandit {
  double p; //当たる確率
  long X; //今までの報酬の合計
  long X2; //今までの報酬の2乗の合計
  int n; //これまでの総プレイ回数
  
  bandit():X(0),X2(0),n(0){
    while(true){
      p = frand();
      if(p<0.4) break;
    }
  }
  int play(){
    n++;
    if(p >= frand()){//当たった場合
      X += BENEFIT;
      X2 += BENEFIT * BENEFIT;
      return BENEFIT;
    }
    return 0; //外れた場合
  }
};

int main(){
  vector<bandit> bandits;
  //スロットの準備
  for(int i=0; i<10; i++){
    bandit tmp;
    bandits.push_back(tmp);
    cout << "BANDIT " << i << " -> p=" << tmp.p << endl;
  }
  cout << endl;
  
  int money = 0; //報酬額合計
  int count = 0; //ゲームカウント
  
  ////////////////////////
  //UCB1-TURNEDアルゴリズム
  ////////////////////////
  //1.初期化
  for(int i=0; i<bandits.size(); i++){
    money += bandits[i].play(); //プレイする
    
    //プレイ後の報酬合計値
    cout << count++ << " " << money << "(played:" << i << ")" << endl;
  }
  //2.繰り返し
  for(int t=bandits.size(); t<=10000; t++){
    if(money<=0) break;
    
    //すべてのbanditの中から評価値の一番高いものを選ぶ
    double eval = -1.0;
    int evali = -1;
    for(int i=0; i<bandits.size(); i++){
      double aveX = (double)(bandits[i].X)/(bandits[i].n);
      double Vjs = (double)(bandits[i].X2)/(bandits[i].n) - aveX * aveX + sqrt(2*log((double)t)/bandits[i].n);
      double tmp = (double)(bandits[i].X)/(bandits[i].n) + sqrt(log((double)t)/bandits[i].n * min(0.25, Vjs));
      if(eval < tmp){
        eval = tmp;
        evali = i;
      }
    }
    //そのマシンをプレイ
    money += bandits[evali].play();
    
    //プレイ後の報酬合計値
    cout << count++ << " " << money << "(played:" << evali << ")" << endl;
  }
  return 0;
}

結果2

  • 結果
    • 10000回ぐらいやった時の報酬額は2992
...
9988 2988(played:5)
9989 2988(played:5)
9990 2989(played:5)
9991 2989(played:5)
9992 2989(played:5)
9993 2989(played:5)
9994 2989(played:5)
9995 2990(played:5)
9996 2990(played:5)
9997 2990(played:5)
9998 2991(played:5)
9999 2992(played:5)
10000 2992(played:5)

以下、修正前のコードと結果。間違っているので注意。

コード

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

//最初に持っているお金
static const int INIT_MONEY = 1000;
//当たった場合の報酬
static const int BENEFIT = 3;

//乱数生成用
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
double frand(){
  return xor128()%1000000000/1000000000.0;
}

//スロットマシン
struct bandit {
  double p; //当たる確率
  long X; //今までの報酬の合計
  long X2; //今までの報酬の2乗の合計
  int n; //これまでの総プレイ回数
  
  bandit():X(0),X2(0),n(0){
    p = frand();
    while(p>0.4) p -= 0.4; //大きくなりすぎないようにちょっと調整
  }
  int play(){
    n++;
    if(p >= frand()){//当たった場合
      X += BENEFIT;
      X2 += BENEFIT * BENEFIT;
      return BENEFIT;
    }
    return 0; //外れた場合
  }
};

int main(){
  vector<bandit> bandits;
  //スロットの準備
  for(int i=0; i<10; i++){
    bandit tmp;
    bandits.push_back(tmp);
    cout << "BANDIT " << i << " -> p=" << tmp.p << endl;
  }
  cout << endl;
  
  int money = INIT_MONEY; //所持金
  int count = 0; //ゲームカウント
  
  ////////////////////////
  //UCB1アルゴリズム
  ////////////////////////
  //1.初期化
  for(int i=0; i<bandits.size(); i++){
    money--; //コインを一枚入れて
    money += bandits[i].play(); //プレイする
    
    //プレイ後のお金
    cout << count++ << " " << money << "(played:" << i << ")" << endl;
  }
  //2.繰り返し
  for(int t=bandits.size(); t<=10000; t++){
    if(money<=0) break;
    
    //すべてのbanditの中から評価値の一番高いものを選ぶ
    double eval = -1.0;
    int evali = -1;
    for(int i=0; i<bandits.size(); i++){
      double tmp = (double)(bandits[i].X)/(bandits[i].n) + sqrt(2*log((double)t)/bandits[i].n);
      if(eval < tmp){
        eval = tmp;
        evali = i;
      }
    }
    //そのマシンをプレイ
    money--;
    money += bandits[evali].play();
    
    //プレイ後のお金
    cout << count++ << " " << money << "(played:" << evali << ")" << endl;
  }
  return 0;
}

結果

  • 各スロットの当たる確率
    • BANDIT9が一番当たる確率が高いみたい
BANDIT 0 -> p=0.301688
BANDIT 1 -> p=0.0582991
BANDIT 2 -> p=0.100873
BANDIT 3 -> p=0.233119
BANDIT 4 -> p=0.116392
BANDIT 5 -> p=0.37727
BANDIT 6 -> p=0.199949
BANDIT 7 -> p=0.31723
BANDIT 8 -> p=0.137867
BANDIT 9 -> p=0.395339
  • 結果
    • 最初は結構バラバラに選ばれるが、最終的に当たる確率が高いスロットの評価値が高くなって選ばれやすくなった
    • 1000円スタートで10000回やったら3818円まで増えた
0 1002(played:0)
1 1001(played:1)
2 1000(played:2)
3 1002(played:3)
4 1001(played:4)
5 1003(played:5)
6 1005(played:6)
7 1004(played:7)
8 1003(played:8)
9 1002(played:9)
10 1001(played:0)
11 1000(played:3)
12 999(played:5)
13 998(played:6)
...
9987 3816(played:9)
9988 3818(played:9)
9989 3817(played:9)
9990 3819(played:9)
9991 3818(played:9)
9992 3817(played:9)
9993 3816(played:9)
9994 3815(played:9)
9995 3814(played:9)
9996 3813(played:9)
9997 3815(played:9)
9998 3817(played:9)
9999 3816(played:9)
10000 3818(played:9)