ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

bayesopt を使用した交差検証済み SVM 分類器の最適化

この例では、関数 bayesopt を使用して SVM 分類を最適化する方法を示します。この分類は、混合ガウス モデルによる点の位置に作用します。モデルの詳細については、The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009) の 17 ページを参照してください。このモデルでは、平均 (1,0) および単位分散をもつ 2 次元の独立した正規分布になっている 10 個の基底点をはじめに "green" クラスについて生成します。また、平均 (0,1) と単位分散による 2 次元の独立した正規として分布される "red" クラスにも、10 個の基底点が生成されます。クラス (green と red) ごとに、次のように 100 個の無作為な点を生成します。

  1. 適切な色の基底点 m を一様にランダムに選択します。

  2. 平均 m と分散 I/5 (I は 2 行 2 列の単位行列) をもつ 2 次元正規分布を使用して、独立した無作為な点を生成します。最適化のアドバンテージをより明確に示すため、この例では I/50 という分散を使用します。

100 個の緑の点と 100 個の赤の点を生成した後で、fitcsvm を使用してこれらを分類します。次に、bayesopt を使用して、生成される SVM モデルのパラメーターを交差検証に関して最適化します。

点と分類器の生成

クラスごとに 10 個の基底点を生成します。

rng default
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

基底点を表示します。

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

赤の基底点の一部が緑の基底点の近くにあるため、位置のみによるデータ点の分類は難しいかもしれません。

各クラスについて 100 個ずつのデータ点を生成します。

redpts = zeros(100,2);grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

データ点を表示します。

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

分類用のデータの準備

データを 1 つの行列に格納し、各点のクラスにラベルを付けるベクトル grp を作成します。

cdata = [grnpts;redpts];
grp = ones(200,1);
% Green label 1, red label -1
grp(101:200) = -1;

交差検証の準備

交差検証用の分割を設定します。これにより、各ステップで最適化に使用される学習セットとテスト セットとが決まります。

c = cvpartition(200,'KFold',10);

ベイズ最適化用の変数の準備

入力として z = [rbf_sigma,boxconstraint] を受け入れ z の交差検証損失値を返す関数を設定します。z の成分を 1e-51e5 の範囲で正の変数に対数変換します。どの値が適切であるかがわからないので、広い範囲を選択します。

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log');
box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

目的関数

この関数ハンドルは、[sigma,box] というパラメーターで交差検証損失を計算します。詳細については、kfoldLossを参照してください。

bayesopt は変数 z を 1 行のテーブルとして目的関数に渡します。

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...
    'KernelFunction','rbf','BoxConstraint',z.box,...
    'KernelScale',z.sigma));

分類器の最適化

bayesopt を使用して最良のパラメーター [sigma,box] を求めます。再現性を得るために、'expected-improvement-plus' の獲得関数を選択します。既定の獲得関数は実行時に決定されるので、結果が異なる場合があります。

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...
    'AcquisitionFunctionName','expected-improvement-plus')

|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |        0.61 |     0.37034 |        0.61 |        0.61 |   0.00013375 |        13929 |
|    2 | Best   |       0.345 |     0.37734 |       0.345 |       0.345 |        24526 |        1.936 |
|    3 | Accept |        0.61 |     0.38481 |       0.345 |       0.345 |    0.0026459 |   0.00084929 |
|    4 | Accept |       0.345 |     0.43188 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 |
|    5 | Accept |       0.345 |     0.26215 |       0.345 |       0.345 |       9135.2 |       571.87 |
|    6 | Accept |       0.345 |      0.2454 |       0.345 |       0.345 |        99701 |        10223 |
|    7 | Best   |       0.295 |     0.30424 |       0.295 |       0.295 |       455.88 |       9957.4 |
|    8 | Best   |        0.24 |      4.2242 |        0.24 |        0.24 |        31.56 |        99389 |
|    9 | Accept |        0.24 |      5.2582 |        0.24 |        0.24 |       10.451 |        64429 |
|   10 | Accept |        0.35 |     0.22117 |        0.24 |        0.24 |       17.331 |   1.0264e-05 |
|   11 | Best   |        0.23 |      3.3915 |        0.23 |        0.23 |       16.005 |        90155 |
|   12 | Best   |         0.1 |     0.44594 |         0.1 |         0.1 |      0.36562 |        80878 |
|   13 | Accept |       0.115 |     0.37155 |         0.1 |         0.1 |       0.1793 |        68459 |
|   14 | Accept |       0.105 |     0.27767 |         0.1 |         0.1 |       0.2267 |        95421 |
|   15 | Best   |       0.095 |     0.24329 |       0.095 |       0.095 |      0.28999 |    0.0058227 |
|   16 | Best   |       0.075 |     0.38443 |       0.075 |       0.075 |      0.30554 |       8.9017 |
|   17 | Accept |       0.085 |      0.3546 |       0.075 |       0.075 |      0.41122 |       4.4476 |
|   18 | Accept |       0.085 |     0.21907 |       0.075 |       0.075 |      0.25565 |       7.8038 |
|   19 | Accept |       0.075 |     0.21965 |       0.075 |       0.075 |      0.32869 |       18.076 |
|   20 | Accept |       0.085 |     0.27585 |       0.075 |       0.075 |      0.32442 |       5.2118 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |         0.3 |     0.24937 |       0.075 |       0.075 |       1.3592 |    0.0098067 |
|   22 | Accept |        0.12 |     0.25383 |       0.075 |       0.075 |      0.17515 |   0.00070913 |
|   23 | Accept |       0.175 |     0.21193 |       0.075 |       0.075 |       0.1252 |     0.010749 |
|   24 | Accept |       0.105 |     0.21899 |       0.075 |       0.075 |       1.1664 |        31.13 |
|   25 | Accept |         0.1 |     0.24818 |       0.075 |       0.075 |      0.57465 |       2013.8 |
|   26 | Accept |        0.12 |     0.17319 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 |
|   27 | Accept |        0.12 |     0.23357 |       0.075 |       0.075 |      0.42956 |   0.00027218 |
|   28 | Accept |       0.095 |     0.14202 |       0.075 |       0.075 |       0.4806 |       13.452 |
|   29 | Accept |       0.105 |     0.25958 |       0.075 |       0.075 |      0.19755 |       943.87 |
|   30 | Accept |       0.205 |     0.26937 |       0.075 |       0.075 |       3.5051 |       93.492 |

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 87.0198 seconds.
Total objective function evaluation time: 20.5233

Best observed feasible point:
     sigma      box  
    _______    ______

    0.30554    8.9017

Observed objective function value = 0.075
Estimated objective function value = 0.075
Function evaluation time = 0.38443

Best estimated feasible point (according to models):
     sigma      box  
    _______    ______

    0.32869    18.076

Estimated objective function value = 0.075
Estimated function evaluation time = 0.2349
results = 
  BayesianOptimization with properties:

                      ObjectiveFcn: [function_handle]
              VariableDescriptions: [1x2 optimizableVariable]
                           Options: [1x1 struct]
                      MinObjective: 0.0750
                   XAtMinObjective: [1x2 table]
             MinEstimatedObjective: 0.0750
          XAtMinEstimatedObjective: [1x2 table]
           NumObjectiveEvaluations: 30
                  TotalElapsedTime: 87.0198
                         NextPoint: [1x2 table]
                            XTrace: [30x2 table]
                    ObjectiveTrace: [30x1 double]
                  ConstraintsTrace: []
                     UserDataTrace: {30x1 cell}
      ObjectiveEvaluationTimeTrace: [30x1 double]
                IterationTimeTrace: [30x1 double]
                        ErrorTrace: [30x1 double]
                  FeasibilityTrace: [30x1 logical]
       FeasibilityProbabilityTrace: [30x1 double]
               IndexOfMinimumTrace: [30x1 double]
             ObjectiveMinimumTrace: [30x1 double]
    EstimatedObjectiveMinimumTrace: [30x1 double]

この結果を使用して、最適化された新しい SVM 分類器に学習をさせます。

z(1) = results.XAtMinObjective.sigma;
z(2) = results.XAtMinObjective.box;
SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...
    'KernelScale',z(1),'BoxConstraint',z(2));

分類境界線をプロットします。サポート ベクター分類器を可視化するため、グリッド全体でスコアを予測します。

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(SVMModel,xGrid);

h = nan(3,1); % Preallocation
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
axis equal
hold off

新しいデータにおける精度の評価

新しいデータ点を生成して分類します。

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1);
grpData(11:20) = -1; % red = -1

v = predict(SVMModel,newData);

g = nan(7,1);
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors'},'Location','Southeast');
axis equal
hold off

正しく分類された新しいデータ点を確認します。正しく分類された点は赤い丸で囲まれ、正しく分類されていない点は黒い丸で囲まれています。

mydiff = (v == grpData); % Classified correctly
figure;
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**');
h(5) = plot(cdata(SVMModel.IsSupportVector,1),...
    cdata(SVMModel.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...
    '+1 (classified)','Support Vectors','Correctly Classified',...
    'Misclassified'},'Location','Southeast');
hold off

参考

|

関連するトピック