ベイズ最適化の使用による分類器の当てはめの最適化
この例では、関数 fitcsvm
および名前と値の引数 OptimizeHyperparameters
を使用して SVM 分類を最適化する方法を示します。
データの生成
この分類は、混合ガウス モデルによる点の位置に作用します。モデルの詳細については、The Elements of Statistical Learning, Hastie, Tibshirani, and Friedman (2009) の 17 ページを参照してください。このモデルでは、平均 (1,0) および単位分散をもつ 2 次元の独立した正規分布になっている 10 個の基底点をはじめに "green" クラスについて生成します。また、平均 (0,1) と単位分散による 2 次元の独立した正規として分布される "red" クラスにも、10 個の基底点が生成されます。クラス (green と red) ごとに、次のように 100 個の無作為な点を生成します。
適切な色の基底点 m を一様にランダムに選択します。
平均 m と分散 I/5 (I は 2 行 2 列の単位行列) をもつ 2 次元正規分布を使用して、独立した無作為な点を生成します。最適化のアドバンテージをより明確に示すため、この例では I/50 という分散を使用します。
クラスごとに 10 個の基底点を生成します。
rng('default') % For reproducibility grnpop = mvnrnd([1,0],eye(2),10); redpop = mvnrnd([0,1],eye(2),10);
基底点を表示します。
plot(grnpop(:,1),grnpop(:,2),'go') hold on plot(redpop(:,1),redpop(:,2),'ro') hold off
赤の基底点の一部が緑の基底点の近くにあるため、位置のみによるデータ点の分類は難しいかもしれません。
各クラスについて 100 個ずつのデータ点を生成します。
redpts = zeros(100,2); grnpts = redpts; for i = 1:100 grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02); redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02); end
データ点を表示します。
figure plot(grnpts(:,1),grnpts(:,2),'go') hold on plot(redpts(:,1),redpts(:,2),'ro') hold off
分類用のデータの準備
データを 1 つの行列に格納し、各点のクラスにラベルを付けるベクトル grp
を作成します。1 は green クラスを示し、–1 は red クラスを示します。
cdata = [grnpts;redpts]; grp = ones(200,1); grp(101:200) = -1;
交差検証の準備
交差検証用の分割を設定します。
c = cvpartition(200,'KFold',10);
この手順はオプションです。最適化に分割を指定する場合は、返されたモデルの実際の交差検証損失を計算できます。
当てはめの最適化
適切な当てはめ、つまり交差検証損失を最小化する最適なハイパーパラメーターをもつ当てはめを求めるには、ベイズ最適化を使用します。名前と値の引数 OptimizeHyperparameters
を使用して最適化対象ハイパーパラメーターのリストを指定し、名前と値の引数 HyperparameterOptimizationOptions
を使用して最適化オプションを指定します。
'OptimizeHyperparameters'
として 'auto'
を指定します。'auto'
オプションを指定すると、一般的な最適化対象ハイパーパラメーターのセットが含まれます。fitcsvm
は、BoxConstraint
、KernelScale
、および Standardize
の最適な値を求めます。再現性を得るために、ハイパーパラメーター最適化オプションを設定して交差検証分割 c
を使用し、獲得関数 'expected-improvement-plus'
を選択します。既定の獲得関数は実行時に決定されるので、結果が異なる場合があります。
opts = struct('CVPartition',c,'AcquisitionFunctionName', ... 'expected-improvement-plus'); Mdl = fitcsvm(cdata,grp,'KernelFunction','rbf', ... 'OptimizeHyperparameters','auto','HyperparameterOptimizationOptions',opts)
|====================================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | Standardize | | | result | | runtime | (observed) | (estim.) | | | | |====================================================================================================================| | 1 | Best | 0.195 | 0.21831 | 0.195 | 0.195 | 193.54 | 0.069073 | false | | 2 | Accept | 0.345 | 0.10288 | 0.195 | 0.20398 | 43.991 | 277.86 | false | | 3 | Accept | 0.365 | 0.085025 | 0.195 | 0.20784 | 0.0056595 | 0.042141 | false | | 4 | Accept | 0.61 | 0.17863 | 0.195 | 0.31714 | 49.333 | 0.0010514 | true | | 5 | Best | 0.1 | 0.30419 | 0.1 | 0.10005 | 996.27 | 1.3081 | false | | 6 | Accept | 0.13 | 0.069174 | 0.1 | 0.10003 | 25.398 | 1.7076 | false | | 7 | Best | 0.085 | 0.1168 | 0.085 | 0.08521 | 930.3 | 0.66262 | false | | 8 | Accept | 0.35 | 0.066595 | 0.085 | 0.085172 | 0.012972 | 983.4 | true | | 9 | Best | 0.075 | 0.091629 | 0.075 | 0.077959 | 871.26 | 0.40617 | false | | 10 | Accept | 0.08 | 0.12545 | 0.075 | 0.077975 | 974.28 | 0.45314 | false | | 11 | Accept | 0.235 | 0.30216 | 0.075 | 0.077907 | 920.57 | 6.482 | true | | 12 | Accept | 0.305 | 0.070665 | 0.075 | 0.077922 | 0.0010077 | 1.0212 | true | | 13 | Best | 0.07 | 0.080775 | 0.07 | 0.073603 | 991.16 | 0.37801 | false | | 14 | Accept | 0.075 | 0.078256 | 0.07 | 0.073191 | 989.88 | 0.24951 | false | | 15 | Accept | 0.245 | 0.09407 | 0.07 | 0.073276 | 988.76 | 9.1309 | false | | 16 | Accept | 0.07 | 0.0795 | 0.07 | 0.071416 | 957.65 | 0.31271 | false | | 17 | Accept | 0.35 | 0.11798 | 0.07 | 0.071421 | 0.0010579 | 33.692 | true | | 18 | Accept | 0.085 | 0.05857 | 0.07 | 0.071274 | 48.536 | 0.32107 | false | | 19 | Accept | 0.07 | 0.082979 | 0.07 | 0.070587 | 742.56 | 0.30798 | false | | 20 | Accept | 0.61 | 0.19356 | 0.07 | 0.070796 | 865.48 | 0.0010165 | false | |====================================================================================================================| | Iter | Eval | Objective | Objective | BestSoFar | BestSoFar | BoxConstraint| KernelScale | Standardize | | | result | | runtime | (observed) | (estim.) | | | | |====================================================================================================================| | 21 | Accept | 0.1 | 0.085428 | 0.07 | 0.070715 | 970.87 | 0.14635 | true | | 22 | Accept | 0.095 | 0.12121 | 0.07 | 0.07087 | 914.88 | 0.46353 | true | | 23 | Accept | 0.07 | 0.14119 | 0.07 | 0.070473 | 982.01 | 0.2792 | false | | 24 | Accept | 0.51 | 0.51006 | 0.07 | 0.070515 | 0.0010005 | 0.014749 | true | | 25 | Accept | 0.345 | 0.16526 | 0.07 | 0.070533 | 0.0010063 | 972.18 | false | | 26 | Accept | 0.315 | 0.17117 | 0.07 | 0.07057 | 947.71 | 152.95 | true | | 27 | Accept | 0.35 | 0.36783 | 0.07 | 0.070605 | 0.0010028 | 43.62 | false | | 28 | Accept | 0.61 | 0.10346 | 0.07 | 0.070598 | 0.0010405 | 0.0010258 | false | | 29 | Accept | 0.555 | 0.07333 | 0.07 | 0.070173 | 993.56 | 0.010502 | true | | 30 | Accept | 0.07 | 0.099019 | 0.07 | 0.070158 | 965.73 | 0.25363 | true | __________________________________________________________ Optimization completed. MaxObjectiveEvaluations of 30 reached. Total function evaluations: 30 Total elapsed time: 16.8267 seconds Total objective function evaluation time: 4.3552 Best observed feasible point: BoxConstraint KernelScale Standardize _____________ ___________ ___________ 991.16 0.37801 false Observed objective function value = 0.07 Estimated objective function value = 0.072292 Function evaluation time = 0.080775 Best estimated feasible point (according to models): BoxConstraint KernelScale Standardize _____________ ___________ ___________ 957.65 0.31271 false Estimated objective function value = 0.070158 Estimated function evaluation time = 0.092681
Mdl = ClassificationSVM ResponseName: 'Y' CategoricalPredictors: [] ClassNames: [-1 1] ScoreTransform: 'none' NumObservations: 200 HyperparameterOptimizationResults: [1x1 BayesianOptimization] Alpha: [66x1 double] Bias: -0.0910 KernelParameters: [1x1 struct] BoxConstraints: [200x1 double] ConvergenceInfo: [1x1 struct] IsSupportVector: [200x1 logical] Solver: 'SMO'
fitcsvm
は、最適な推定実行可能点を使用する ClassificationSVM
モデル オブジェクトを返します。最適な推定実行可能点は、ベイズ最適化プロセスの基となるガウス過程モデルに基づいて交差検証損失の信頼限界の上限を最小化するハイパーパラメーターのセットです。
ベイズ最適化プロセスは、目的関数のガウス過程モデルを内部に保持します。目的関数は、分類の場合は交差検証済み誤分類率です。各反復において、最適化プロセスによってガウス過程モデルが更新され、そのモデルを使用して新しいハイパーパラメーターのセットが求められます。反復表示の各行には、新しいハイパーパラメーターのセットと次の列の値が表示されます。
Objective
— 新しいハイパーパラメーターのセットにおいて計算された目的関数値。Objective runtime
— 目的関数の評価時間。Eval result
—Accept
、Best
またはError
として指定される結果レポート。Accept
は目的関数が有限値を返すことを示し、Error
は目的関数が有限の実数スカラーではない値を返すことを示します。Best
は、目的関数が以前に計算された目的関数値より小さい有限値を返すことを示します。BestSoFar(observed)
— それまでに計算された最小の目的関数値。この値は、現在の反復の目的関数値 (現在の反復におけるEval result
の値がBest
である場合)、または前回のBest
反復の値です。BestSoFar(estim.)
— 各反復で、更新されたガウス過程モデルを使用して、それまでに試行されたすべてのハイパーパラメーターのセットにおける目的関数値の信頼限界の上限が推定されます。次に、信頼限界の上限が最小になる点が選択されます。BestSoFar(estim.)
の値は、最小点において関数predictObjective
によって返される目的関数値です。
反復表示の下のプロットは、BestSoFar(observed)
と BestSoFar(estim.)
の値をそれぞれ青と緑で示しています。
返されるオブジェクト Mdl
は、最適な推定実行可能点、つまり、最終的なガウス過程モデルに基づく最後の反復で BestSoFar(estim.)
の値を生成するハイパーパラメーターのセットを使用します。
HyperparameterOptimizationResults
プロパティから、または関数 bestPoint
を使用して、最適な点を取得できます。
Mdl.HyperparameterOptimizationResults.XAtMinEstimatedObjective
ans=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
957.65 0.31271 false
[x,CriterionValue,iteration] = bestPoint(Mdl.HyperparameterOptimizationResults)
x=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
957.65 0.31271 false
CriterionValue = 0.0724
iteration = 16
既定では、関数 bestPoint
は基準 'min-visited-upper-confidence-interval'
を使用します。この基準では、16 番目の反復から取得されたハイパーパラメーターが最適な点として選択されます。CriterionValue
は、最終的なガウス過程モデルによって計算された交差検証損失の上限です。分割 c
を使用して実際の交差検証損失を計算します。
L_MinEstimated = kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c, ... 'KernelFunction','rbf','BoxConstraint',x.BoxConstraint, ... 'KernelScale',x.KernelScale,'Standardize',x.Standardize=='true'))
L_MinEstimated = 0.0700
実際の交差検証損失は、推定値に近くなっています。最適化の結果を示すプロットの下に Estimated objective function value
が表示されます。
また、HyperparameterOptimizationResults
プロパティから、または Criterion
として 'min-observed'
を指定して、最適な観測実行可能点 (つまり、反復表示内の最後の Best
点) を抽出できます。
Mdl.HyperparameterOptimizationResults.XAtMinObjective
ans=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
991.16 0.37801 false
[x_observed,CriterionValue_observed,iteration_observed] = ... bestPoint(Mdl.HyperparameterOptimizationResults,'Criterion','min-observed')
x_observed=1×3 table
BoxConstraint KernelScale Standardize
_____________ ___________ ___________
991.16 0.37801 false
CriterionValue_observed = 0.0700
iteration_observed = 13
基準 'min-observed'
では、13 番目の反復から取得されたハイパーパラメーターが最適な点として選択されます。CriterionValue_observed
は、選択されたハイパーパラメーターを使用して計算された実際の交差検証損失です。詳細については、bestPoint
の名前と値の引数Criterionを参照してください。
最適化された分類器を可視化します。
d = 0.02; [x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)), ... min(cdata(:,2)):d:max(cdata(:,2))); xGrid = [x1Grid(:),x2Grid(:)]; [~,scores] = predict(Mdl,xGrid); figure h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3) = plot(cdata(Mdl.IsSupportVector,1), ... cdata(Mdl.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h,{'-1','+1','Support Vectors'},'Location','Southeast');
新しいデータにおける精度の評価
新しいテスト データ点を生成して分類します。
grnobj = gmdistribution(grnpop,.2*eye(2)); redobj = gmdistribution(redpop,.2*eye(2)); newData = random(grnobj,10); newData = [newData;random(redobj,10)]; grpData = ones(20,1); % green = 1 grpData(11:20) = -1; % red = -1 v = predict(Mdl,newData);
テスト データ セットで誤分類率を計算します。
L_Test = loss(Mdl,newData,grpData)
L_Test = 0.2000
正しく分類された新しいデータ点を判別します。正しく分類された点は赤い四角形で囲まれ、正しく分類されていない点は黒い四角形で囲まれています。
h(4:5) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); mydiff = (v == grpData); % Classified correctly for ii = mydiff % Plot red squares around correct pts h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12); end for ii = not(mydiff) % Plot black squares around incorrect pts h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12); end legend(h,{'-1 (training)','+1 (training)','Support Vectors', ... '-1 (classified)','+1 (classified)', ... 'Correctly Classified','Misclassified'}, ... 'Location','Southeast'); hold off