takminの書きっぱなし備忘録 @はてなブログ

主にコンピュータビジョンなど技術について、たまに自分自身のことや思いついたことなど

OpenCVのCvDTreeの使用にあたっての注意点

OpenCVのMachine Learning関係はマニュアル見てもネットで探しても貧弱な情報しかないので、とりあえず自分が発見したことをさらしていく。

CvDTreeクラスは二分木を作ってくれるクラスで、CART(Classification and Regression Trees)モデルと言われているものの実装。詳しくは、ここらへんを参考

http://cwoweb2.bai.ne.jp/~jgb11101/files/CART.pdf
http://ibisforest.org/index.php?CART

また、以下はCvDTreeクラスについて一通りドキュメント読んだんだけど、実装がうまくいかない、という人を想定してます。まだ読んでない人は以下を読むこと。
http://www.opencv.jp/opencv-1.0.0/document/opencvref_ml_dtree.html

ここでは、例えば以下のような学習元データと教師データを与えることを想定。また、回帰分析ではなく、判別分析を想定しています。

  • 学習元データ
0.4,-1,3.2
0.5,-0.5,1.5
2,-1.5,2.1
4,-1.1,1.21
2.4,-0.6,2.1
0.3,-2.1,1.03
2.4,-1.2,2.3
1.4,-0.3,3
1.2,-0.8,2.6
0.8,-0.21,1.2
 -0.2,0.3,0.3
 -0.1,1.2,0.7
 -0.5,3.2,-0.3
 -1,1.1,-1.1
 -2,0.7,0
 -1.5,0.3,0.4
 -0.7,2.8,0.2
 -0.92,1.2,-0.5
 -0.12,1.1,-0.4
 -1.1,2,0.5
  • 教師データ
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0

ここで、各行が一つのサンプルデータで、それが20サンプルあるという意味。

これを学習させる場合、例えばこんなコードになる。

CvMat* data = readInputData(filename)  // 学習データ読み込み
CvMat* res = readReaponseData(filename2)  // 教師データ読み込み

CvDTree* weakClassifier = new CvDTree();

CvMat* var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
cvSet( var_type, cvScalarAll(CV_VAR_NUMERICAL) );
var_type->data.ptr[data->cols] = CV_VAR_CATEGORICAL;

CvDTreeParams param = CvDTreeParams( 1, // max depth
                                 2, // min sample count
                                 0, // regression accuracy: N/A here
                                 false, // compute surrogate split, as we have missing data
                                 2, // max number of categories (use sub-optimal algorithm for larger numbers)
                                 0, // the number of cross-validation folds
                                 false, // use 1SE rule => smaller tree
                                 false, // throw away the pruned tree branches
                                 0 // the array of priors, the bigger p_weight, the more attention
                                 );

weakClassifier->train(data, CV_ROW_SAMPLE, res, 0, 0, var_type, 0, param);

ここで、注意すべきは2点。

・必ずカテゴリデータか、連続データかを指定するためのvar_typeデータ行列を用意すること。その際は学習データの列数(ここでは3)+1個の行と1つの列を用意。
・var_typeについて、入力値が連続値(浮動小数点で扱うもの)の場合はCV_VAR_NUMERICAL、離散値(ラベル。整数で表すもの)の場合はCV_VAR_CATEGORICALを選ぶ。ここでは、入力値が連続値で、教師信号が離散値なので、CV_VAR_NUMERICALを全体に対してセットした後、var_typeの最後の要素にCV_VAR_CATEGORICALをセットする。

======================
2009/04/13 追記
var_typeに関するソースおよび記述が間違っていたので修正しました。