Main Content

ベイズ最適化の使用による分類器の当てはめの最適化

この例では、関数 fitcsvm および名前と値の引数 OptimizeHyperparameters を使用して SVM 分類を最適化する方法を示します。

データの生成

この分類は、混合ガウス モデルによる点の位置に作用します。モデルの詳細については、The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009) の 17 ページを参照してください。このモデルでは、平均 (1,0) および単位分散をもつ 2 次元の独立した正規分布になっている 10 個の基底点をはじめに "green" クラスについて生成します。また、平均 (0,1) と単位分散による 2 次元の独立した正規として分布される "red" クラスにも、10 個の基底点が生成されます。クラス (green と red) ごとに、次のように 100 個の無作為な点を生成します。

  1. 適切な色の基底点 m を一様にランダムに選択します。

  2. 平均 m と分散 I/5 (I は 2 行 2 列の単位行列) をもつ 2 次元正規分布を使用して、独立した無作為な点を生成します。最適化のアドバンテージをより明確に示すため、この例では I/50 という分散を使用します。

クラスごとに 10 個の基底点を生成します。

rng('default') % For reproducibility
grnpop = mvnrnd([1,0],eye(2),10);
redpop = mvnrnd([0,1],eye(2),10);

基底点を表示します。

plot(grnpop(:,1),grnpop(:,2),'go')
hold on
plot(redpop(:,1),redpop(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

赤の基底点の一部が緑の基底点の近くにあるため、位置のみによるデータ点の分類は難しいかもしれません。

各クラスについて 100 個ずつのデータ点を生成します。

redpts = zeros(100,2);
grnpts = redpts;
for i = 1:100
    grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);
    redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02);
end

データ点を表示します。

figure
plot(grnpts(:,1),grnpts(:,2),'go')
hold on
plot(redpts(:,1),redpts(:,2),'ro')
hold off

Figure contains an axes object. The axes object contains 2 objects of type line.

分類用のデータの準備

データを 1 つの行列に格納し、各点のクラスにラベルを付けるベクトル grp を作成します。1 は green クラスを示し、–1 は red クラスを示します。

cdata = [grnpts;redpts];
grp = ones(200,1);
grp(101:200) = -1;

交差検証の準備

交差検証用の分割を設定します。

c = cvpartition(200,'KFold',10);

この手順はオプションです。最適化に分割を指定する場合は、返されたモデルの実際の交差検証損失を計算できます。

当てはめの最適化

適切な当てはめ、つまり交差検証損失を最小化する最適なハイパーパラメーターをもつ当てはめを求めるには、ベイズ最適化を使用します。名前と値の引数 OptimizeHyperparameters を使用して最適化対象ハイパーパラメーターのリストを指定し、名前と値の引数 HyperparameterOptimizationOptions を使用して最適化オプションを指定します。

'OptimizeHyperparameters' として 'auto' を指定します。'auto' オプションを指定すると、一般的な最適化対象ハイパーパラメーターのセットが含まれます。fitcsvm は、BoxConstraintKernelScale の最適な値を求めます。再現性を得るために、ハイパーパラメーター最適化オプションを設定して交差検証分割 c を使用し、獲得関数 'expected-improvement-plus' を選択します。既定の獲得関数は実行時に決定されるので、結果が異なる場合があります。

opts = struct('CVPartition',c,'AcquisitionFunctionName','expected-improvement-plus');
Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ...
    'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|    1 | Best   |       0.345 |     0.25732 |       0.345 |       0.345 |      0.00474 |       306.44 |
|    2 | Best   |       0.115 |     0.19793 |       0.115 |     0.12678 |       430.31 |       1.4864 |
|    3 | Accept |        0.52 |     0.30034 |       0.115 |      0.1152 |     0.028415 |     0.014369 |
|    4 | Accept |        0.61 |     0.42574 |       0.115 |     0.11504 |       133.94 |    0.0031427 |
|    5 | Accept |        0.34 |     0.67052 |       0.115 |     0.11504 |     0.010993 |       5.7742 |
|    6 | Best   |       0.085 |     0.31819 |       0.085 |    0.085039 |       885.63 |      0.68403 |
|    7 | Accept |       0.105 |     0.28752 |       0.085 |    0.085428 |       0.3057 |      0.58118 |
|    8 | Accept |        0.21 |     0.30194 |       0.085 |     0.09566 |      0.16044 |      0.91824 |
|    9 | Accept |       0.085 |      0.2937 |       0.085 |     0.08725 |       972.19 |      0.46259 |
|   10 | Accept |         0.1 |     0.61698 |       0.085 |    0.090952 |       990.29 |        0.491 |
|   11 | Best   |        0.08 |      0.3093 |        0.08 |    0.079362 |       2.5195 |        0.291 |
|   12 | Accept |        0.09 |     0.35855 |        0.08 |     0.08402 |       14.338 |      0.44386 |
|   13 | Accept |         0.1 |     0.31417 |        0.08 |     0.08508 |    0.0022577 |      0.23803 |
|   14 | Accept |        0.11 |     0.27436 |        0.08 |    0.087378 |       0.2115 |      0.32109 |
|   15 | Best   |        0.07 |     0.29307 |        0.07 |    0.081507 |        910.2 |      0.25218 |
|   16 | Best   |       0.065 |     0.37195 |       0.065 |    0.072457 |       953.22 |      0.26253 |
|   17 | Accept |       0.075 |     0.51771 |       0.065 |    0.072554 |       998.74 |      0.23087 |
|   18 | Accept |       0.295 |     0.29709 |       0.065 |    0.072647 |       996.18 |       44.626 |
|   19 | Accept |        0.07 |     0.25634 |       0.065 |     0.06946 |       985.37 |      0.27389 |
|   20 | Accept |       0.165 |     0.25769 |       0.065 |    0.071622 |     0.065103 |      0.13679 |
|=====================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | BoxConstraint|  KernelScale |
|      | result |             | runtime     | (observed)  | (estim.)    |              |              |
|=====================================================================================================|
|   21 | Accept |       0.345 |     0.24945 |       0.065 |    0.071764 |        971.7 |       999.01 |
|   22 | Accept |        0.61 |     0.26694 |       0.065 |    0.071967 |    0.0010168 |    0.0010005 |
|   23 | Accept |       0.345 |     0.25464 |       0.065 |    0.071959 |    0.0011459 |       995.89 |
|   24 | Accept |        0.35 |     0.25431 |       0.065 |    0.071863 |    0.0010003 |       40.628 |
|   25 | Accept |        0.24 |      0.4673 |       0.065 |    0.072124 |       996.55 |       10.423 |
|   26 | Accept |        0.61 |     0.78243 |       0.065 |    0.072067 |       994.71 |    0.0010063 |
|   27 | Accept |        0.47 |     0.39932 |       0.065 |     0.07218 |       993.69 |     0.029723 |
|   28 | Accept |         0.3 |     0.24834 |       0.065 |    0.072291 |       993.15 |       170.01 |
|   29 | Accept |        0.16 |     0.58967 |       0.065 |    0.072103 |       992.81 |       3.8594 |
|   30 | Accept |       0.365 |     0.43087 |       0.065 |    0.072112 |    0.0010017 |     0.044287 |

Figure contains an axes object. The axes object with title Min objective vs. Number of function evaluations contains 2 objects of type line. These objects represent Min observed objective, Estimated min objective.

Figure contains an axes object. The axes object with title Objective function model contains 5 objects of type line, surface, contour. These objects represent Observed points, Model mean, Next point, Model minimum feasible.

__________________________________________________________
Optimization completed.
MaxObjectiveEvaluations of 30 reached.
Total function evaluations: 30
Total elapsed time: 50.761 seconds
Total objective function evaluation time: 10.8637

Best observed feasible point:
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

Observed objective function value = 0.065
Estimated objective function value = 0.073726
Function evaluation time = 0.37195

Best estimated feasible point (according to models):
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

Estimated objective function value = 0.072112
Estimated function evaluation time = 0.3405
Mdl = 
  ClassificationSVM
                         ResponseName: 'Y'
                CategoricalPredictors: []
                           ClassNames: [-1 1]
                       ScoreTransform: 'none'
                      NumObservations: 200
    HyperparameterOptimizationResults: [1x1 BayesianOptimization]
                                Alpha: [77x1 double]
                                 Bias: -0.2352
                     KernelParameters: [1x1 struct]
                       BoxConstraints: [200x1 double]
                      ConvergenceInfo: [1x1 struct]
                      IsSupportVector: [200x1 logical]
                               Solver: 'SMO'


  Properties, Methods

fitcsvm は、最適な推定実行可能点を使用する ClassificationSVM モデル オブジェクトを返します。最適な推定実行可能点は、ベイズ最適化プロセスの基となるガウス過程モデルに基づいて交差検証損失の信頼限界の上限を最小化するハイパーパラメーターのセットです。

ベイズ最適化プロセスは、目的関数のガウス過程モデルを内部に保持します。目的関数は、分類の場合は交差検証済み誤分類率です。各反復において、最適化プロセスによってガウス過程モデルが更新され、そのモデルを使用して新しいハイパーパラメーターのセットが求められます。反復表示の各行には、新しいハイパーパラメーターのセットと次の列の値が表示されます。

  • Objective — 新しいハイパーパラメーターのセットにおいて計算された目的関数値。

  • Objective runtime — 目的関数の評価時間。

  • Eval resultAcceptBest または Error として指定される結果レポート。Accept は目的関数が有限値を返すことを示し、Error は目的関数が有限の実数スカラーではない値を返すことを示します。Best は、目的関数が以前に計算された目的関数値より小さい有限値を返すことを示します。

  • BestSoFar(observed) — それまでに計算された最小の目的関数値。この値は、現在の反復の目的関数値 (現在の反復における Eval result の値が Best である場合)、または前回の Best 反復の値です。

  • BestSoFar(estim.) — 各反復で、更新されたガウス過程モデルを使用して、それまでに試行されたすべてのハイパーパラメーターのセットにおける目的関数値の信頼限界の上限が推定されます。次に、信頼限界の上限が最小になる点が選択されます。BestSoFar(estim.) の値は、最小点において関数predictObjectiveによって返される目的関数値です。

反復表示の下のプロットは、BestSoFar(observed)BestSoFar(estim.) の値をそれぞれ青と緑で示しています。

返されるオブジェクト Mdl は、最適な推定実行可能点、つまり、最終的なガウス過程モデルに基づく最後の反復で BestSoFar(estim.) の値を生成するハイパーパラメーターのセットを使用します。

HyperparameterOptimizationResults プロパティから、または関数 bestPoint を使用して、最適な点を取得できます。

Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       985.37          0.27389  

CriterionValue = 0.0888
iteration = 19

既定では、関数 bestPoint は基準 'min-visited-upper-confidence-interval' を使用します。この基準では、19 番目の反復から取得されたハイパーパラメーターが最適な点として選択されます。CriterionValue は、最終的なガウス過程モデルによって計算された交差検証損失の上限です。分割 c を使用して実際の交差検証損失を計算します。

L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,'KernelFunction','rbf', ...
    'BoxConstraint',x.BoxConstraint,'KernelScale',x.KernelScale))
L_MinEstimated = 0.0700

実際の交差検証損失は、推定値に近くなっています。最適化の結果を示すプロットの下に Estimated objective function value が表示されます。

また、HyperparameterOptimizationResults プロパティから、または Criterion として 'min-observed' を指定して、最適な観測実行可能点 (つまり、反復表示内の最後の Best 点) を抽出できます。

Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

[x_observed,CriterionValue_observed,iteration_observed] = bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×2 table
    BoxConstraint    KernelScale
    _____________    ___________

       953.22          0.26253  

CriterionValue_observed = 0.0650
iteration_observed = 16

基準 'min-observed' では、16 番目の反復から取得されたハイパーパラメーターが最適な点として選択されます。CriterionValue_observed は、選択されたハイパーパラメーターを使用して計算された実際の交差検証損失です。詳細については、bestPoint の名前と値の引数Criterionを参照してください。

最適化された分類器を可視化します。

d = 0.02;
[x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ...
    min(cdata(:,2)):d:max(cdata(:,2)));
xGrid = [x1Grid(:),x2Grid(:)];
[~,scores] = predict(Mdl,xGrid);

figure
h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*');
hold on
h(3) = plot(cdata(Mdl.IsSupportVector,1), ...
    cdata(Mdl.IsSupportVector,2),'ko');
contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');
legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');

Figure contains an axes object. The axes object contains 4 objects of type line, contour. These objects represent -1, +1, Support Vectors.

新しいデータにおける精度の評価

新しい検定データ点を生成して分類します。

grnobj = gmdistribution(grnpop,.2*eye(2));
redobj = gmdistribution(redpop,.2*eye(2));

newData = random(grnobj,10);
newData = [newData;random(redobj,10)];
grpData = ones(20,1); % green = 1
grpData(11:20) = -1; % red = -1

v = predict(Mdl,newData);

検定データ セットで誤分類率を計算します。

L_Test = loss(Mdl,newData,grpData)
L_Test = 0.3500

正しく分類された新しいデータ点を判別します。正しく分類された点は赤い四角形で囲まれ、正しく分類されていない点は黒い四角形で囲まれています。

h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**');

mydiff = (v == grpData); % Classified correctly

for ii = mydiff % Plot red squares around correct pts
    h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12);
end

for ii = not(mydiff) % Plot black squares around incorrect pts
    h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12);
end
legend(h,{'-1 (training)','+1 (training)','Support Vectors', ...
    '-1 (classified)','+1 (classified)', ...
    'Correctly Classified','Misclassified'}, ...
    'Location','Southeast');
hold off

Figure contains an axes object. The axes object contains 8 objects of type line, contour. These objects represent -1 (training), +1 (training), Support Vectors, -1 (classified), +1 (classified), Correctly Classified, Misclassified.

参考

|

関連するトピック