Elman netを試す
はじめに
プロフェッショナルな「深層学習」本で紹介されているRNNの一種のElman netを書いてみる。
Recurrent Neural Network(RNN)とは
- 再帰型ニューラルネット
- ネットワーク内部に内部有向閉路を持つようなニューラルネットの総称
- Feedforwardの時は、入力層から出力層へ一方方向
- この構造のおかげで、時系列や言語モデル、系列ラベリングなど前の状態を持つような問題に対して考慮できる
- いろんな種類がある(以下はwikipediaから)
- Fully Recurrent network
- Hopfield network
- Boltzmann machine
- Simple recurrent network
- Elman net
- Jordan net
- Echo state network
- Long short term memory(LSTM) network
- Bi-directional RNN
- Hierarchical RNN
- Stochastic neural networks
- Fully Recurrent network
Simple recurrent network
- 言葉のとおり、シンプルな構造のRNN
- 代表的には「Elman net」と「Jordan net」がある
- 3層(入力層、隠れ層、出力層)のFeedforward Neural netについて、どこからどこへ閉路を作るかで異なる
- Elman net : 隠れ層から隠れ層への辺がある
- Jordan net : 出力層から隠れ層への辺がある
コード
「深層学習」本では、Elman netのBPTT法(Backpropagation through time、時間でネットワークを分けて普通に逆誤差伝播)での学習が紹介されていると思われるので、これを書いてみる。
#include <iostream> #include <vector> #include <cstdio> #include <algorithm> #include <cmath> static const double PI = 3.14159265358979323846264338; //xorshift // 注意: longではなくint(32bit)にすべき unsigned long xor128(){ static unsigned long x=123456789, y=362436069, z=521288629, w=88675123; unsigned long t; t=(x^(x<<11)); x=y; y=z; z=w; return w=(w^(w>>19))^(t^(t>>8)); } //[0,1)の一様乱数 double frand(){ return xor128()%1000000/static_cast<double>(1000000); } //正規乱数 double normal_rand(double mu, double sigma2){ double sigma = sqrt(sigma2); double u1 = frand(), u2 = frand(); double z1 = sqrt(-2*log(u1)) * cos(2*PI*u2); //double z2 = sqrt(-2*log(u1)) * sin(2*PI*u2); return mu + sigma*z1; } //隠れ層が1層の単純なRNN // BPTT(Backpropagation through time)法で学習 class SimpleRNN { double eps; //学習率 int Nin, Nhid, Nout; //各層のユニット数(バイアスを除く) std::vector< std::vector<double> > W; //時刻t-1での隠れ層から時刻tでの隠れ層へ std::vector< std::vector<double> > Win; //時刻tでの入力層から隠れ層へ std::vector< std::vector<double> > Wout; //時刻tでの隠れ層から出力層へ std::vector< std::vector<double> > u, v; //各時刻での隠れ層の入力, 出力層の入力 std::vector< std::vector<double> > z, y; //各時刻での隠れ層の出力, 出力層の出力 std::vector< std::vector<double> > delta, delta_out; //各時刻での各ユニットのδ値 public: SimpleRNN(int Nin, int Nhid, int Nout, double eps): eps(eps), Nin(Nin), Nhid(Nhid), Nout(Nout), W(Nhid+1, std::vector<double>(Nhid, 0)), Win(Nin+1, std::vector<double>(Nhid, 0)), Wout(Nhid+1, std::vector<double>(Nout, 0)) { for(int i=0; i<Nhid+1; i++){ for(int j=0; j<Nhid; j++){ W[i][j] = normal_rand(0.0, 0.1); } } for(int i=0; i<Nin+1; i++){ for(int j=0; j<Nhid; j++){ Win[i][j] = normal_rand(0.0, 0.1); } } for(int i=0; i<Nhid+1; i++){ for(int j=0; j<Nout; j++){ Wout[i][j] = normal_rand(0.0, 0.1); } } } std::vector<int> forward_propagation(const std::vector< std::vector<double> >& in){ int T = in.size(); u = std::vector< std::vector<double> >(T, std::vector<double>(Nhid, 0)); v = std::vector< std::vector<double> >(T, std::vector<double>(Nout, 0)); z = std::vector< std::vector<double> >(T, std::vector<double>(Nhid+1, 0)); y = std::vector< std::vector<double> >(T, std::vector<double>(Nout, 0)); //各時刻tでの値を求める std::vector<int> ret; for(int t=0; t<T; t++){ for(int i=0; i<Nhid; i++){ u[t][i] = 0.0; } //入力層の出力->隠れ層の入力 for(int i=0; i<Nin; i++){ for(int j=0; j<Nhid; j++){ u[t][j] += in[t][i] * Win[i][j]; } } for(int i=0; i<Nhid; i++){ //バイアス u[t][i] += 1.0 * Win[Nin][i]; } //時刻t-1での隠れ層の出力->時刻tでの隠れ層の入力 for(int i=0; i<Nhid+1; i++){ if(t!=0){ for(int j=0; j<Nhid; j++){ u[t][j] += z[t-1][i] * W[i][j]; } } if(t==0 && i==Nhid){ //バイアス for(int j=0; j<Nhid; j++){ u[t][j] += 1.0 * W[Nin][j]; } } } //時刻tでの隠れ層の出力 for(int i=0; i<Nhid; i++){ z[t][i] = 1.0 / (1.0 + exp(-u[t][i])); } z[t][Nhid] = 1.0; //時刻tでの出力層の入力 for(int i=0; i<Nhid+1; i++){ for(int j=0; j<Nout; j++){ v[t][j] += z[t][i] * Wout[i][j]; } } //時刻tでの出力層の出力 double Z = 0; double mx = -1.0; int mx_i = -1; for(int i=0; i<Nout; i++){ Z += exp(v[t][i]); } for(int i=0; i<Nout; i++){ y[t][i] = exp(v[t][i]) / Z; if(mx < y[t][i]){ mx = y[t][i]; mx_i = i; } } ret.push_back(mx_i); } return ret; } double back_propagation(const std::vector< std::vector<double> >& in, const std::vector<int>& out){ double err = 0.0; int T = in.size(); std::vector<int> res = forward_propagation(in); delta = std::vector< std::vector<double> >(T, std::vector<double>(Nhid, 0)); delta_out = std::vector< std::vector<double> >(T, std::vector<double>(Nout, 0)); //時刻の逆方向からdelta値と重みの更新を連続的にしていく for(int t=T-1; t>=0; t--){ //時刻tでの出力層のdelta値 for(int i=0; i<Nout; i++){ if(out[t] == i){ delta_out[t][i] = y[t][i] - 1.0; err += - 1.0 * log(y[t][i]); }else{ delta_out[t][i] = y[t][i] - 0.0; } } //時刻tでの隠れ層の出力層からのdelta値 for(int i=0; i<Nhid; i++){ for(int j=0; j<Nout; j++){ double fu = 1.0 / (1.0 + exp(-u[t][i])); double fdash = fu * (1.0 - fu); delta[t][i] += Wout[i][j] * delta_out[t][j] * fdash; } } //時刻tでの隠れ層の時刻t+1での隠れ層からのdelta値 for(int i=0; i<Nhid; i++){ for(int j=0; j<Nhid; j++){ if(t+1>=T) continue; //時刻T以降の場合はdelta=0として扱う double fu = 1.0 / (1.0 + exp(-u[t][i])); double fdash = fu * (1.0 - fu); delta[t][i] += W[i][j] * delta[t+1][j] * fdash; } } //delta値を使って各重みの微分を計算し、SGDで重みを更新 for(int j=0; j<Nhid+1; j++){ //隠れ層->出力層 for(int k=0; k<Nout; k++){ Wout[j][k] -= eps * (delta_out[t][k] * z[t][j]); } } for(int pj=0; pj<Nhid+1; pj++){ //隠れ層->隠れ層 for(int j=0; j<Nhid; j++){ if(t-1>=0){ W[pj][j] -= eps * (delta[t][j] * z[t-1][pj]); }else{ //時刻t<0の場合は、バイアスだけ1でそれ以外の出力は0として扱う if(pj==Nhid){ //バイアス W[pj][j] -= eps * (delta[t][j] * 1.0); } } } } for(int i=0; i<Nin+1; i++){ //入力層->隠れ層 for(int j=0; j<Nhid; j++){ if(i<Nin){ Win[i][j] -= eps * (delta[t][j] * in[t][i]); }else{ //バイアス Win[i][j] -= eps * (delta[t][j] * 1.0); } } } } return err; } }; int main(){ const int TERM = 500; //TERM個単位でerrを確認 const double finish_err = 0.01; //errがfinish_err未満になったら学習終了 int Nin, Nhid, Nout; double eps; int trainN, testN; std::cin >> Nin >> Nhid >> Nout; std::cin >> eps; /// 学習 //////////////////// std::cin >> trainN; std::vector< std::vector< std::vector<double> > > in; std::vector< std::vector<int> > out; for(int i=0; i<trainN; i++){ int t; std::cin >> t; std::vector< std::vector<double> > in_one; std::vector<int> out_one; for(int j=0; j<t; j++){ double val; int outval; std::cin >> outval; std::vector<double> v; for(int k=0; k<Nin; k++){ std::cin >> val; v.push_back(val); } in_one.push_back(v); out_one.push_back(outval); } in.push_back(in_one); out.push_back(out_one); } //データシャッフル用 std::vector<int> rnd; for(int i=0; i<trainN; i++){ rnd.push_back(i); } std::random_shuffle(rnd.begin(), rnd.end()); //学習ループ SimpleRNN rnn(Nin, Nhid, Nout, eps); int iter = 0; double err = 0.0; while(true){ //TERM個単位で確認 if(iter!=0 && iter%TERM == 0){ err /= TERM; std::cerr << "err = " << err << std::endl; if(err < finish_err) break; err = 0; } err += rnn.back_propagation(in[rnd[iter%trainN]], out[rnd[iter%trainN]]); iter++; } /// テスト //////////////////// std::cin >> testN; for(int i=0; i<testN; i++){ int t; std::cin >> t; std::vector< std::vector<double> > in_one; std::vector<int> out_one; for(int j=0; j<t; j++){ double val; int outval; std::cin >> outval; std::vector<double> v; for(int k=0; k<Nin; k++){ std::cin >> val; v.push_back(val); } in_one.push_back(v); out_one.push_back(outval); } //結果の出力 std::vector<int> res = rnn.forward_propagation(in_one); std::cout << "ref:"; for(int j=0; j<out_one.size(); j++){ std::cout << " " << out_one[j]; } std::cout << std::endl; std::cout << "out:"; for(int j=0; j<res.size(); j++){ std::cout << " " << res[j]; } std::cout << std::endl; std::cout << std::endl; } return 0; }
入力ファイルの形式
[入力層のユニット数] [隠れ層のユニット数] [出力層のユニット数] [学習率] [学習インスタンス数] [ケース1の数] [出力1] [次元1の値] [次元2の値] ... [出力2] [次元1の値] [次元2の値] ... [出力3] [次元1の値] [次元2の値] ... ... [テストインスタンス数] [ケース1の数] [出力1] [次元1の値] [次元2の値] ... [出力2] [次元1の値] [次元2の値] ... [出力3] [次元1の値] [次元2の値] ... ...
結果
おもちゃなケースで確認だけ。
http://d.hatena.ne.jp/jetbead/20111125/1322236626 でのサンプルを今回の形式に変換したもの。
- 入力
10 10 3 0.01 6 3 2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4 2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3 2 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 2 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4 2 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3 2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 2 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 4 2 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 6 3 2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4 2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3 2 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 2 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4 2 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3 2 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 2 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 4 2 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
- 出力
$ ./a.out < in err = 2.20794 err = 1.84924 err = 1.54379 err = 1.22335 err = 0.931081 err = 0.681177 err = 0.493641 err = 0.360733 err = 0.265542 err = 0.198631 err = 0.153081 err = 0.121373 err = 0.0987013 err = 0.0824887 err = 0.0702189 err = 0.0606353 err = 0.0533114 err = 0.0473444 err = 0.0423609 err = 0.0383958 err = 0.034996 err = 0.0320187 err = 0.0295968 err = 0.0274426 err = 0.0254882 err = 0.0238809 err = 0.0224107 err = 0.0210389 err = 0.0199064 err = 0.0188471 err = 0.0178354 err = 0.0170008 err = 0.0162053 err = 0.0154302 err = 0.0147935 err = 0.0141766 err = 0.0135647 err = 0.0130652 err = 0.0125743 err = 0.0120795 err = 0.0116788 err = 0.0112797 err = 0.0108714 err = 0.010544 err = 0.0102138 err = 0.00987132 ref: 2 2 2 out: 2 2 2 ref: 2 2 0 1 out: 2 2 0 1 ref: 2 2 2 out: 2 2 2 ref: 2 2 0 1 out: 2 2 0 1 ref: 2 2 2 out: 2 2 2 ref: 2 2 2 2 out: 2 2 2 2
ちゃんと収束して学習データはちゃんとラベルつけられている。
しかし、展開する時間tが長くなるほどネットワークが深く(deep)なってしまい、勾配消失(または発散)問題が起きてしまってうまく学習できない。
参考
- 岡谷, 機械学習プロフェッショナルシリーズ「深層学習」, 講談社
- http://www.slideshare.net/beam2d/pfi-seminar-20141030rnn
- http://qiita.com/icoxfog417/items/2791ee878deee0d0fd9c
- https://en.wikipedia.org/wiki/Types_of_artificial_neural_networks#Recurrent_neural_network
- https://github.com/mattya/RNN-colle/wiki/Tutorial_jp
- http://wbawakate.jp/wp-content/uploads/2015/03/RNN%E3%83%95%E3%82%9A%E3%83%AC%E3%82%BB%E3%82%99%E3%83%B3.pdf