Label Propagationを試す
はじめに
前から気になっていた、label propagationを試してみる。
マルチエージェントとかの合意問題っぽい感じ?
Learning fron Labeled and Unlabeled Data with Label Propagation
- http://lvk.cs.msu.su/~bruzz/articles/classification/zhu02learning.pdf
- 教師あり学習ではラベル付きデータが必要だが、そういうデータが豊富に使えるとは限らない
- 少量のラベルデータと、大量のラベルのついていないデータ、というような場合がしばしばあったりする
- そこで、ラベルのついているデータの周辺のデータにその情報を伝搬させることで、ラベルのついていないデータのラベルを推定する
- 論文では行列Wをユークリッド距離に基づく重み行列としていたが、以下では単純な隣接行列で試してみた
コード
#include <iostream> #include <vector> //行列計算用 class Matrix { int m, n; std::vector<double> val; void splice(int i, int j){ int N = val.size(); if(i+(j-1)<N){ std::vector<double> vl(N-j, 0.0); int vl_idx = 0, val_idx = 0; while(val_idx<N){ if(val_idx<i||i+j-1<val_idx){ vl[vl_idx] = val[val_idx]; vl_idx++; } val_idx++; } val = vl; } } Matrix clone(){ Matrix ret(m, n); for(int i=0; i<m; i++){ for(int j=0; j<n; j++){ ret.setVal(i, j, val[n*i+j]); } } return ret; } public: Matrix():m(0),n(0){} Matrix(int m, int n):m(m),n(n){ if(m>0 && n>0){ for(int i=0; i<m*n; i++){ val.push_back(0.0); } } } int getM() const { return m; } int getN() const { return n; } double getVal(int i, int j) const { return val[n*i+j]; } void setVal(int i, int j, double x){ val[n*i+j] = x; } bool isSquare(){ return n==m; } Matrix add(const Matrix& mat){ if(m == mat.m && n == mat.n){ Matrix ret(m, n); for(int i=0; i<m; i++){ for(int j=0; j<n; j++){ ret.setVal(i, j, val[n*i+j] + mat.getVal(i,j)); } } return ret; } return Matrix(-1, -1); } Matrix operator+(const Matrix& mat){ return add(mat); } Matrix sub(const Matrix& mat){ if(m == mat.m && n == mat.n){ Matrix ret(m, n); for(int i=0; i<m; i++){ for(int j=0; j<n; j++){ ret.setVal(i, j, val[n*i+j] - mat.getVal(i,j)); } } return ret; } return Matrix(-1, -1); } Matrix operator-(const Matrix& mat){ return sub(mat); } Matrix prod(const Matrix& mat){ if(n == mat.m){ Matrix ret(m, mat.n); for(int i=0; i<m; i++){ for(int j=0; j<mat.n; j++){ double d = 0.0; for(int k=0; k<n; k++){ d += val[n*i+k] * mat.getVal(k,j); } ret.setVal(i, j, d); } } return ret; } return Matrix(-1, -1); } Matrix operator*(const Matrix& mat){ return prod(mat); } void time(double x){ for(int i=0; i<m; i++){ for(int j=0; j<n; j++){ val[n*i+j] *= x; } } } Matrix transpose(){ Matrix ret(n, m); for(int i=0; i<m; i++){ for(int j=0; j<n; j++){ ret.setVal(j, i, val[n*i+j]); } } return ret; } double cofactor(int i, int j){ Matrix mat = clone(); mat.splice(i*mat.n, mat.m); mat.m -= 1; for(int k=mat.m-1; k>=0; k--){ mat.splice(k*mat.n+j, 1); } mat.n -= 1; return mat.det() * ( ((i+j+2)%2==0) ? 1 : -1); } double det(){ if(isSquare()){ if(m == 2){ return val[0]*val[3]-val[1]*val[2]; }else{ double d = 0; for(int k=0; k<n; k++){ d += val[k] * cofactor(0, k); } return d; } }else{ return 0.0; } } Matrix cofactorMatrix(){ Matrix mat(m, n); for(int i=0; i<m; i++){ for(int j=0; j<n; j++){ mat.setVal(j, i, cofactor(i, j)); } } return mat; } Matrix inverse(){ if(isSquare()){ double d = det(); if(d != 0){ Matrix mat; if(m>2){ mat = cofactorMatrix(); } else { mat = Matrix(2, 2); mat.setVal(0, 0, val[3]); mat.setVal(0, 1, -val[1]); mat.setVal(1, 0, -val[2]); mat.setVal(1, 1, val[0]); } mat.time(1 / d); return mat; } }else{ return Matrix(-1, -1); } return Matrix(-1, -1); } }; std::ostream& operator<<(std::ostream& os, const Matrix& mat){ for(int i=0; i<mat.getM(); i++){ for(int j=0; j<mat.getN(); j++){ os << mat.getVal(i,j) << " "; } os << std::endl; } return os; } //Label Propagation class LabelPropagation { Matrix W, D; Matrix Y, nowY; int labeled; public: LabelPropagation(Matrix W, Matrix Y, int labeled): W(W), D(W.getM(),W.getN()), Y(Y), nowY(Y), labeled(labeled) { for(int i=0; i<W.getM(); i++){ double sum = 0.0; for(int j=0; j<W.getN(); j++){ sum += W.getVal(i, j); } D.setVal(i, i, sum); } } //繰り返し計算で求める void compute_iteration(int iter){ Matrix Dinv(W.getM(), W.getN()); for(int i=0; i<W.getM(); i++){ Dinv.setVal(i, i, 1.0 / D.getVal(i, i)); } for(int i=0; i<iter; i++){ std::cout << i << ":\t"; nowY = Dinv * (W * nowY); for(int i=0; i<labeled; i++){ nowY.setVal(0, i, Y.getVal(0, i)); } std::cout << nowY.transpose(); } } //Fixed Pointを行列計算で求める void compute_fixedpoint(){ int unlabeled = W.getM()-labeled; Matrix T(W.getM(), W.getN()); Matrix Tbar(W.getM(), W.getN()); Matrix I(unlabeled, unlabeled); Matrix Tuu(unlabeled, unlabeled); Matrix Tul(unlabeled, labeled); Matrix Yl(labeled, 1); //make T T = W; //make Tbar for(int i=0; i<T.getM(); i++){ double sum = 0.0; for(int k=0; k<T.getN(); k++){ sum += T.getVal(i, k); } for(int j=0; j<T.getN(); j++){ Tbar.setVal(i, j, T.getVal(i, j)/sum); } } //make Tuu for(int i=0; i<unlabeled; i++){ for(int j=0; j<unlabeled; j++){ Tuu.setVal(i, j, Tbar.getVal(labeled + i, labeled + j)); } } //make Tul for(int i=0; i<unlabeled; i++){ for(int j=0; j<labeled; j++){ Tul.setVal(i, j, Tbar.getVal(labeled + i, j)); } } //make I for(int i=0; i<I.getM(); i++){ I.setVal(i, i, 1.0); } //make Yl for(int i=0; i<labeled; i++){ Yl.setVal(i, 0, Y.getVal(i, 0)); } //fixed point Matrix Yu = (I-Tuu).inverse() * Tul * Yl; std::cout << "Labeled: " << Yl.transpose(); std::cout << "Unlabeled: " << Yu.transpose(); } }; int main(){ int M, labeled; double v; //Input std::cin >> M >> labeled; Matrix W(M, M), Y(M, 1); for(int i=0; i<M; i++){ for(int j=0; j<M; j++){ std::cin >> v; W.setVal(i,j,v); } } for(int i=0; i<labeled; i++){ std::cin >> v; Y.setVal(i, 0, v); } //Label Propagation // Y := {y1, ..., y_labeled, 0, ..., 0} LabelPropagation lp(W, Y, labeled); std::cout << "Iteration:" << std::endl; lp.compute_iteration(100); std::cout << "Fixed Point:" << std::endl; lp.compute_fixedpoint(); return 0; }
結果
入力形式は、以下。
ノード数 ラベル数 隣接行列 ラベルデータ
ノード番号は、0〜(ラベル数-1)までがラベル付きで、残りがラベルなしを想定。
どっちもうまくできてるっぽい。
Test 1
$ cat test1.txt 10 6 0 1 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 1 1 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 1 0 0 0 1 1 1 0 0 1 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 0 0 -1 -1 -1 1 1 1 $ ./a.out < test1.txt Iteration: 0: -1 -1 -1 1 1 1 0.333333 -1 -0.333333 1 1: -1 -1 -1 1 1 1 0.333333 -1 -0.333333 1 2: -1 -1 -1 1 1 1 0.333333 -1 -0.333333 1 ... 97: -1 -1 -1 1 1 1 0.333333 -1 -0.333333 1 98: -1 -1 -1 1 1 1 0.333333 -1 -0.333333 1 99: -1 -1 -1 1 1 1 0.333333 -1 -0.333333 1 Fixed Point: Labeled: -1 -1 -1 1 1 1 Unlabeled: 0.333333 -1 -0.333333 1
Test 2
$ cat test2.txt 5 2 0 0 1 0 0 0 0 0 0 1 1 0 0 1 0 0 0 1 0 1 0 1 0 1 0 -1 1 $ ./a.out < test2.txt Iteration: 0: -1 1 -0.5 0 0.5 1: -1 1 -0.5 0 0.5 2: -1 1 -0.5 0 0.5 ... 97: -1 1 -0.5 0 0.5 98: -1 1 -0.5 0 0.5 99: -1 1 -0.5 0 0.5 Fixed Point: Labeled: -1 1 Unlabeled: -0.5 0 0.5