FOBOSを試す

はじめに

FOBOSという最近のオンライン学習の方法を知ったので試してみた。

FOBOSとは

  • 劣勾配法という手法では、微分不可能な点を持つ目的関数でも学習可能だった
  • しかし正則化(特にL1正則化)を組み合わせてもうまく機能してくれなかった
  • 「勾配による重みの更新」と「正則化のための重みの更新」を分けることで正則化を機能させることができて、収束の証明もできた

使用したデータ

結果

#学習
$ svm-train -t 0 a9a.txt a9a.model
#予測
$ svm-predict a9a.t a9a.model out
Accuracy = 84.9764% (13835/16281) (classification)」
  • svm_FOBOS
    • 線形カーネル&適当にパラメータを指定した場合
    • perl svm_fobos.pl 学習データ 学習率 正則化の強さ < テストデータ」
# 正則化をかけない場合
$ perl svm_fobos.pl a9a.txt 0.06 0.0 < a9a.t
training..........finish!
Result : 0.835452367790676 (13602/16281)
Zero element : 0.0487804878048781 (6/123)

# 正則化をかけた場合
perl svm_fobos.pl a9a.txt 0.06 0.005 < a9a.t
training..........finish!
Result : 0.830907192432897 (13528/16281)
Zero element : 0.650406504065041 (80/123)
  • 確かに、ほとんど精度を落とさず正則化によって少ない素性で表現できてるみたい!
    • 43個ほどの素性だけで分類できてる
  • パラメータの選び方がわからない
    • 適当に試してよさそうなパラメータを選んだ

用いたコード

#! /usr/bin/perl
# Usage : perl svm_fobos.pl train_file param_eta param_lambda_hat < test_file
use strict;
use warnings;

#学習ファイル名
my $train_file = shift;
#パラメータ
my $param_eta = shift;
my $param_lambda_hat = shift;

#訓練回数(学習データの個数*$loop個の学習を行う)
my $loop = 10;

#重みベクトル
my $w = {};

## 学習データの読み込み
my @x_list;
my @t_list;
open IN, $train_file;
while(<IN>){
    chomp;
    my @list = split(/\s+/, $_);
    push(@t_list, $list[0]);
    my $hash;
    for(my $i=1; $i<@list; $i++){
	my ($a, $b) = split(/:/,$list[$i]);
	$hash->{$a} = $b;
	$w->{$a} = 0;
    }
    push(@x_list, $hash);
}

## 訓練
$|=1; #printのオートフラッシュ有効
print "training";
while($loop--){
    print ".";
    for(my $i = 0; $i < @x_list; $i++){
	train($w, $x_list[$i], $t_list[$i], $param_eta, $param_lambda_hat);
    }
}
print "finish!\n";

## 推定
my $num = 0;
my $success = 0;
while(<>){
    chomp;
    my @list = split(/\s+/, $_);
    my $hash;
    for(my $i=1; $i<@list; $i++){
	my ($a, $b) = split(/:/,$list[$i]);
	$hash->{$a} = $b;
    }

    my $t = predict($w, $hash);
    $num++;
    if($t * $list[0] > 0){
	$success++;
    }
}

## 結果の出力
print "Result : ",($success/$num)," (",$success,"/",$num,")\n";

#重みが0の要素の割合
my $elem_num = 0;
my $zero_num = 0;
foreach my $f(keys %$w){
    $elem_num++;
    if($w->{$f} == 0){
	$zero_num++;
    }
}
print "Zero element : ",($zero_num/$elem_num)," (",$zero_num,"/",$elem_num,")\n";

##################################################
#予測
sub predict {
    my $w = shift;
    my $x = shift;
    
    my $y = 0;
    foreach my $f (keys %$x){
	if($w->{$f}){
	    $y += ($w->{$f} * $x->{$f});
	}
    }
    return $y;
}

#学習
sub train {
    my $w = shift;
    my $x = shift;
    my $t = shift;
    my $eta = shift;
    my $lambda_hat = shift;

    #損失による勾配
    my $y = predict($w, $x);
    if($y*$t < 1){
	foreach my $f(keys %$x){
	    $w->{$f} += $eta * $t * $x->{$f};
	}

	
	#正則化
	foreach my $f(keys %$w){
	    if($w->{$f} > 0){
		if($w->{$f} - $lambda_hat > 0){
		    $w->{$f} -= $lambda_hat;
		}else{
		    $w->{$f} = 0;
		}
	    }else{
		if($w->{$f} + $lambda_hat > 0){
		    $w->{$f} = 0;
		}else{
		    $w->{$f} += $lambda_hat;
		}
	    }
	}
    }
}