Elman netを試す

はじめに

プロフェッショナルな「深層学習」本で紹介されているRNNの一種のElman netを書いてみる。

Recurrent Neural Network(RNN)とは

  • 再帰ニューラルネット
  • ネットワーク内部に内部有向閉路を持つようなニューラルネットの総称
    • Feedforwardの時は、入力層から出力層へ一方方向
  • この構造のおかげで、時系列や言語モデル、系列ラベリングなど前の状態を持つような問題に対して考慮できる
  • いろんな種類がある(以下はwikipediaから)
    • Fully Recurrent network
      • Hopfield network
      • Boltzmann machine
    • Simple recurrent network
      • Elman net
      • Jordan net
    • Echo state network
    • Long short term memory(LSTM) network
    • Bi-directional RNN
    • Hierarchical RNN
    • Stochastic neural networks
Simple recurrent network
  • 言葉のとおり、シンプルな構造のRNN
  • 代表的には「Elman net」と「Jordan net」がある
  • 3層(入力層、隠れ層、出力層)のFeedforward Neural netについて、どこからどこへ閉路を作るかで異なる
    • Elman net : 隠れ層から隠れ層への辺がある
    • Jordan net : 出力層から隠れ層への辺がある

コード

「深層学習」本では、Elman netのBPTT法(Backpropagation through time、時間でネットワークを分けて普通に逆誤差伝播)での学習が紹介されていると思われるので、これを書いてみる。

#include <iostream>
#include <vector>
#include <cstdio>
#include <algorithm>
#include <cmath>

static const double PI = 3.14159265358979323846264338;

//xorshift
// 注意: longではなくint(32bit)にすべき
unsigned long xor128(){
  static unsigned long x=123456789, y=362436069, z=521288629, w=88675123;
  unsigned long t;
  t=(x^(x<<11));
  x=y; y=z; z=w;
  return w=(w^(w>>19))^(t^(t>>8));
}
//[0,1)の一様乱数
double frand(){
  return xor128()%1000000/static_cast<double>(1000000); 
}
//正規乱数
double normal_rand(double mu, double sigma2){
  double sigma = sqrt(sigma2);
  double u1 = frand(), u2 = frand();
  double z1 = sqrt(-2*log(u1)) * cos(2*PI*u2);
  //double z2 = sqrt(-2*log(u1)) * sin(2*PI*u2);
  return mu + sigma*z1;
}


//隠れ層が1層の単純なRNN
// BPTT(Backpropagation through time)法で学習
class SimpleRNN {
  double eps; //学習率
  int Nin, Nhid, Nout; //各層のユニット数(バイアスを除く)
  std::vector< std::vector<double> > W; //時刻t-1での隠れ層から時刻tでの隠れ層へ
  std::vector< std::vector<double> > Win; //時刻tでの入力層から隠れ層へ
  std::vector< std::vector<double> > Wout; //時刻tでの隠れ層から出力層へ

  std::vector< std::vector<double> > u, v; //各時刻での隠れ層の入力, 出力層の入力
  std::vector< std::vector<double> > z, y; //各時刻での隠れ層の出力, 出力層の出力
  std::vector< std::vector<double> > delta, delta_out; //各時刻での各ユニットのδ値

public:
  SimpleRNN(int Nin, int Nhid, int Nout, double eps):
    eps(eps),
    Nin(Nin),
    Nhid(Nhid),
    Nout(Nout),
    W(Nhid+1, std::vector<double>(Nhid, 0)),
    Win(Nin+1, std::vector<double>(Nhid, 0)),
    Wout(Nhid+1, std::vector<double>(Nout, 0))
  {
    for(int i=0; i<Nhid+1; i++){
      for(int j=0; j<Nhid; j++){
        W[i][j] = normal_rand(0.0, 0.1);
      }
    }
    for(int i=0; i<Nin+1; i++){
      for(int j=0; j<Nhid; j++){
        Win[i][j] = normal_rand(0.0, 0.1);
      }
    }
    for(int i=0; i<Nhid+1; i++){
      for(int j=0; j<Nout; j++){
        Wout[i][j] = normal_rand(0.0, 0.1);
      }
    }
  }

  std::vector<int> forward_propagation(const std::vector< std::vector<double> >& in){
    int T = in.size();
    
    u = std::vector< std::vector<double> >(T, std::vector<double>(Nhid, 0));
    v = std::vector< std::vector<double> >(T, std::vector<double>(Nout, 0));
    z = std::vector< std::vector<double> >(T, std::vector<double>(Nhid+1, 0));
    y = std::vector< std::vector<double> >(T, std::vector<double>(Nout, 0));

    //各時刻tでの値を求める
    std::vector<int> ret;

    for(int t=0; t<T; t++){
      for(int i=0; i<Nhid; i++){
        u[t][i] = 0.0;
      }
      //入力層の出力->隠れ層の入力
      for(int i=0; i<Nin; i++){
        for(int j=0; j<Nhid; j++){
          u[t][j] += in[t][i] * Win[i][j];
        }
      }
      for(int i=0; i<Nhid; i++){ //バイアス
        u[t][i] += 1.0 * Win[Nin][i];
      }
      //時刻t-1での隠れ層の出力->時刻tでの隠れ層の入力
      for(int i=0; i<Nhid+1; i++){
        if(t!=0){
          for(int j=0; j<Nhid; j++){
            u[t][j] += z[t-1][i] * W[i][j];
          }
        }
        if(t==0 && i==Nhid){ //バイアス
          for(int j=0; j<Nhid; j++){
            u[t][j] += 1.0 * W[Nin][j];
          }
        }
      }
      //時刻tでの隠れ層の出力
      for(int i=0; i<Nhid; i++){
        z[t][i] = 1.0 / (1.0 + exp(-u[t][i]));
      }
      z[t][Nhid] = 1.0;
      //時刻tでの出力層の入力
      for(int i=0; i<Nhid+1; i++){
        for(int j=0; j<Nout; j++){
          v[t][j] += z[t][i] * Wout[i][j];
        }
      }
      //時刻tでの出力層の出力
      double Z = 0;
      double mx = -1.0;
      int mx_i = -1;
      for(int i=0; i<Nout; i++){
        Z += exp(v[t][i]);
      }
      for(int i=0; i<Nout; i++){
        y[t][i] = exp(v[t][i]) / Z;

        if(mx < y[t][i]){
          mx = y[t][i];
          mx_i = i;
        }
      }
      ret.push_back(mx_i);
    }
    return ret;
  }

  double back_propagation(const std::vector< std::vector<double> >& in, const std::vector<int>& out){
    double err = 0.0;
    int T = in.size();
    std::vector<int> res = forward_propagation(in);

    delta = std::vector< std::vector<double> >(T, std::vector<double>(Nhid, 0));
    delta_out = std::vector< std::vector<double> >(T, std::vector<double>(Nout, 0));

    //時刻の逆方向からdelta値と重みの更新を連続的にしていく
    for(int t=T-1; t>=0; t--){
      //時刻tでの出力層のdelta値
      for(int i=0; i<Nout; i++){
        if(out[t] == i){
          delta_out[t][i] = y[t][i] - 1.0;
          err += - 1.0 * log(y[t][i]);
        }else{
          delta_out[t][i] = y[t][i] - 0.0;
        }
      }
      //時刻tでの隠れ層の出力層からのdelta値
      for(int i=0; i<Nhid; i++){
        for(int j=0; j<Nout; j++){
          double fu = 1.0 / (1.0 + exp(-u[t][i]));
          double fdash = fu * (1.0 - fu);
          delta[t][i] += Wout[i][j] * delta_out[t][j] * fdash;
        }
      }
      //時刻tでの隠れ層の時刻t+1での隠れ層からのdelta値
      for(int i=0; i<Nhid; i++){
        for(int j=0; j<Nhid; j++){
          if(t+1>=T) continue; //時刻T以降の場合はdelta=0として扱う

          double fu = 1.0 / (1.0 + exp(-u[t][i]));
          double fdash = fu * (1.0 - fu);
          delta[t][i] += W[i][j] * delta[t+1][j] * fdash;
        }
      }
      //delta値を使って各重みの微分を計算し、SGDで重みを更新
      for(int j=0; j<Nhid+1; j++){ //隠れ層->出力層
        for(int k=0; k<Nout; k++){
          Wout[j][k] -= eps * (delta_out[t][k] * z[t][j]);
        }
      }
      for(int pj=0; pj<Nhid+1; pj++){ //隠れ層->隠れ層
        for(int j=0; j<Nhid; j++){
          if(t-1>=0){
            W[pj][j] -= eps * (delta[t][j] * z[t-1][pj]);
          }else{ //時刻t<0の場合は、バイアスだけ1でそれ以外の出力は0として扱う
            if(pj==Nhid){ //バイアス
              W[pj][j] -= eps * (delta[t][j] * 1.0);
            }
          }
        }
      }
      for(int i=0; i<Nin+1; i++){ //入力層->隠れ層
        for(int j=0; j<Nhid; j++){
          if(i<Nin){
            Win[i][j] -= eps * (delta[t][j] * in[t][i]);
          }else{ //バイアス
            Win[i][j] -= eps * (delta[t][j] * 1.0);
          }
        }
      }

    }
    return err;
  }

};

int main(){
  const int TERM = 500; //TERM個単位でerrを確認
  const double finish_err = 0.01; //errがfinish_err未満になったら学習終了

  int Nin, Nhid, Nout;
  double eps;
  int trainN, testN;

  std::cin >> Nin >> Nhid >> Nout;
  std::cin >> eps;
  
  /// 学習 ////////////////////
  std::cin >> trainN;
  std::vector< std::vector< std::vector<double> > > in;
  std::vector< std::vector<int> > out;
  for(int i=0; i<trainN; i++){
    int t;
    std::cin >> t;

    std::vector< std::vector<double> > in_one;
    std::vector<int> out_one;
    for(int j=0; j<t; j++){
      double val;
      int outval;
      std::cin >> outval;
      std::vector<double> v;
      for(int k=0; k<Nin; k++){
        std::cin >> val;
        v.push_back(val);
      }
      in_one.push_back(v);
      out_one.push_back(outval);
    }

    in.push_back(in_one);
    out.push_back(out_one);
  }
  //データシャッフル用
  std::vector<int> rnd;
  for(int i=0; i<trainN; i++){
    rnd.push_back(i);
  }
  std::random_shuffle(rnd.begin(), rnd.end());

  //学習ループ
  SimpleRNN rnn(Nin, Nhid, Nout, eps);
  int iter = 0;
  double err = 0.0;
  while(true){
    //TERM個単位で確認
    if(iter!=0 && iter%TERM == 0){
      err /= TERM;
      std::cerr << "err = " << err << std::endl;
      if(err < finish_err) break;
      err = 0;
    }

    err += rnn.back_propagation(in[rnd[iter%trainN]], out[rnd[iter%trainN]]);
    iter++;
  }

  /// テスト ////////////////////
  std::cin >> testN;
  for(int i=0; i<testN; i++){
    int t;
    std::cin >> t;

    std::vector< std::vector<double> > in_one;
    std::vector<int> out_one;
    for(int j=0; j<t; j++){
      double val;
      int outval;
      std::cin >> outval;
      std::vector<double> v;
      for(int k=0; k<Nin; k++){
        std::cin >> val;
        v.push_back(val);
      }
      in_one.push_back(v);
      out_one.push_back(outval);
    }

    //結果の出力
    std::vector<int> res = rnn.forward_propagation(in_one);
    std::cout << "ref:";
    for(int j=0; j<out_one.size(); j++){
      std::cout << " " << out_one[j];
    }
    std::cout << std::endl;

    std::cout << "out:";
    for(int j=0; j<res.size(); j++){
      std::cout << " " << res[j];
    }
    std::cout << std::endl;

    std::cout << std::endl;
  }

  return 0;
}

入力ファイルの形式

[入力層のユニット数]  [隠れ層のユニット数]  [出力層のユニット数]
[学習率]
[学習インスタンス数]
[ケース1の数]
[出力1]  [次元1の値]  [次元2の値] ...
[出力2]  [次元1の値]  [次元2の値] ...
[出力3]  [次元1の値]  [次元2の値] ...
...
[テストインスタンス数]
[ケース1の数]
[出力1]  [次元1の値]  [次元2の値] ...
[出力2]  [次元1の値]  [次元2の値] ...
[出力3]  [次元1の値]  [次元2の値] ...
...

結果

おもちゃなケースで確認だけ。
http://d.hatena.ne.jp/jetbead/20111125/1322236626 でのサンプルを今回の形式に変換したもの。

  • 入力
10 10 3
0.01
6
3
2  1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4
2  1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
1  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3
2  0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
2  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4
2  0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
0  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
1  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3
2  1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
2  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
4
2  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
6
3
2  1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4
2  1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
0  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
1  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3
2  0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
2  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4
2  0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0
0  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
1  0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3
2  1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
2  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
4
2  0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2  0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
  • 出力
$ ./a.out < in
err = 2.20794
err = 1.84924
err = 1.54379
err = 1.22335
err = 0.931081
err = 0.681177
err = 0.493641
err = 0.360733
err = 0.265542
err = 0.198631
err = 0.153081
err = 0.121373
err = 0.0987013
err = 0.0824887
err = 0.0702189
err = 0.0606353
err = 0.0533114
err = 0.0473444
err = 0.0423609
err = 0.0383958
err = 0.034996
err = 0.0320187
err = 0.0295968
err = 0.0274426
err = 0.0254882
err = 0.0238809
err = 0.0224107
err = 0.0210389
err = 0.0199064
err = 0.0188471
err = 0.0178354
err = 0.0170008
err = 0.0162053
err = 0.0154302
err = 0.0147935
err = 0.0141766
err = 0.0135647
err = 0.0130652
err = 0.0125743
err = 0.0120795
err = 0.0116788
err = 0.0112797
err = 0.0108714
err = 0.010544
err = 0.0102138
err = 0.00987132
ref: 2 2 2
out: 2 2 2

ref: 2 2 0 1
out: 2 2 0 1

ref: 2 2 2
out: 2 2 2

ref: 2 2 0 1
out: 2 2 0 1

ref: 2 2 2
out: 2 2 2

ref: 2 2 2 2
out: 2 2 2 2

ちゃんと収束して学習データはちゃんとラベルつけられている。
しかし、展開する時間tが長くなるほどネットワークが深く(deep)なってしまい、勾配消失(または発散)問題が起きてしまってうまく学習できない。