Main Content

インクリメンタル学習中の条件付き学習の実行

この例では、モデルのパフォーマンスが不十分である場合にのみインクリメンタル学習用の単純ベイズ マルチクラス分類モデルに学習をさせる方法を示します。

柔軟なインクリメンタル学習ワークフローにより、学習が必要な場合にのみデータの入力バッチでインクリメンタル モデルに学習をさせることができます (インクリメンタル学習とはを参照)。たとえば、モデルのパフォーマンス メトリクスが十分なものである場合に効率を向上させるには、メトリクスが不十分になるまで入力バッチでの学習をスキップします。

データの読み込み

人の行動のデータ セットを読み込みます。データをランダムにシャッフルします。

load humanactivity
n = numel(actid);
rng(1) % For reproducibility
idx = randsample(n,n);
X = feat(idx,:);
Y = actid(idx);

データ セットの詳細については、コマンド ラインで Description を入力してください。

単純ベイズ分類モデルの学習

次を設定して、インクリメンタル学習用の単純ベイズ分類モデルを構成します。

  • 予測されるクラスの最大数: 5

  • 追跡されるパフォーマンス メトリクス: 最小コストも含む誤分類誤差率

  • メトリクス ウィンドウ サイズ: 1000

  • メトリクスのウォームアップ期間: 50

initobs = 50;
Mdl = incrementalClassificationNaiveBayes('MaxNumClasses',5,'MetricsWindowSize',1000,...
    'Metrics','classiferror','MetricsWarmupPeriod',initobs);

構成したモデルを最初の 50 個の観測値に当てはめます。

Mdl = fit(Mdl,X(1:initobs,:),Y(1:initobs))
Mdl = 
  incrementalClassificationNaiveBayes

                    IsWarm: 1
                   Metrics: [2x2 table]
                ClassNames: [1 2 3 4 5]
            ScoreTransform: 'none'
         DistributionNames: {1x60 cell}
    DistributionParameters: {5x60 cell}


haveTrainedAllClasses = numel(unique(Y(1:initobs))) == 5
haveTrainedAllClasses = logical
   1

MdlincrementalClassificationNaiveBayes モデル オブジェクトです。次のすべての条件が適用されるため、モデルはウォームです (IsWarm1)。

  • 初期学習データに予測されるすべてのクラスが含まれる (haveTrainedAllClassestrue)。

  • MdlMdl.MetricsWarmupPeriod 観測値に当てはめられている。

そのため、モデルは予測を生成するために準備されており、インクリメンタル学習関数はモデル内のパフォーマンス メトリクスを測定します。

条件付き学習を伴うインクリメンタル学習の実行

最新の 1000 個の観測値における誤分類誤差が 5% を超える場合にのみモデルに学習をさせるとします。

条件付き学習を行い、インクリメンタル学習を実行します。各反復で次の手順に従います。

  1. 100 個の観測値のチャンクを一度に処理することで、データ ストリームをシミュレートします。

  2. モデルと現在のデータ チャンクを updateMetrics に渡してモデルのパフォーマンスを更新します。入力モデルを出力モデルで上書きします。

  3. 誤分類誤差率と 2 番目のクラス μ21 における 1 番目の予測子の平均を保存し、学習中にそれらがどのように進化するかを確認します。

  4. 誤分類誤差率が 0.05 を超える場合にのみ、モデルをデータ チャンクに当てはめます。学習の実行時に入力モデルを出力モデルで上書きします。

  5. fit がモデルに学習させるタイミングを追跡します。

% Preallocation
numObsPerChunk = 100;
nchunk = floor((n - initobs)/numObsPerChunk);
mu21 = zeros(nchunk,1);
ce = array2table(nan(nchunk,2),'VariableNames',["Cumulative" "Window"]);
trained = false(nchunk,1);

% Incremental fitting
for j = 1:nchunk
    ibegin = min(n,numObsPerChunk*(j-1) + 1 + initobs);
    iend   = min(n,numObsPerChunk*j + initobs);
    idx = ibegin:iend;
    Mdl = updateMetrics(Mdl,X(idx,:),Y(idx));
    ce{j,:} = Mdl.Metrics{"ClassificationError",:};
    if ce{j,"Window"} > 0.05
        Mdl = fit(Mdl,X(idx,:),Y(idx));
        trained(j) = true;
    end    
    mu21(j) = Mdl.DistributionParameters{2,1}(1);
end

Mdl は、ストリーム内のすべてのデータで学習させた incrementalClassificationNaiveBayes モデル オブジェクトです。

モデルのパフォーマンスと μ21 が学習中にどのように進化するかを確認するには、それらを別々のタイルにプロットします。モデルの学習期間を特定します。

t = tiledlayout(2,1);
nexttile
plot(mu21)
hold on
plot(find(trained),mu21(trained),'r.')
ylabel('\mu_{21}')
legend('\mu_{21}','Training occurs','Location','best')
hold off
nexttile
plot(ce.Variables)
ylabel('Misclassification Error Rate')
legend(ce.Properties.VariableNames,'Location','best')
xlabel(t,'Iteration')

Figure contains 2 axes objects. Axes object 1 with ylabel \mu_{21} contains 2 objects of type line. One or more of the lines displays its values using only markers These objects represent \mu_{21}, Training occurs. Axes object 2 with ylabel Misclassification Error Rate contains 2 objects of type line. These objects represent Cumulative, Window.

μ21 のトレース プロットは、前の観測ウィンドウ内におけるモデルのパフォーマンスが最大で 0.05 である定数値の期間を示します。

参考

オブジェクト

関数

関連するトピック