Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

カスタム学習ループを使用したネットワークの学習

この例では、カスタム学習率スケジュールで手書きの数字を分類するネットワークに学習させる方法を示します。

必要なオプション (たとえば、カスタム学習率スケジュール) が関数 trainingOptions に用意されていない場合、自動微分を使用して独自のカスタム学習ループを定義できます。

この例では、"時間ベースの減衰" 学習率スケジュールで手書きの数字を分類するようにネットワークに学習させます。各反復で、ソルバーは ρt=ρ01+k t によって与えられる学習率を使用します。ここで、t は反復回数、ρ0 は初期学習率、k は減衰です。

学習データの読み込み

関数 imageDatastore を使用して数字データをイメージ データストアとして読み込み、イメージ データが格納されているフォルダーを指定します。

dataFolder = fullfile(toolboxdir('nnet'),'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder, ...
    'IncludeSubfolders',true, ....
    'LabelSource','foldernames');

データを学習セットと検証セットに分割します。関数 splitEachLabel を使用して、データの 10% を検証用に残しておきます。

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,'randomize');

この例で使用されるネットワークには、サイズが 28 x 28 x 1 の入力イメージが必要です。学習イメージのサイズを自動的に変更するには、拡張イメージ データストアを使用します。学習イメージに対して実行する追加の拡張演算として、イメージを水平軸方向および垂直軸方向に最大 5 ピクセルだけランダムに平行移動させる演算を指定します。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。

inputSize = [28 28 1];
pixelRange = [-5 5];
imageAugmenter = imageDataAugmenter( ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,'DataAugmentation',imageAugmenter);

他のデータ拡張を実行せずに検証イメージのサイズを自動的に変更するには、追加の前処理演算を指定せずに拡張イメージ データストアを使用します。

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

学習データ内のクラスの数を決定します。

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

ネットワークの定義

イメージ分類のネットワークを定義します。

layers = [
    imageInputLayer(inputSize,'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding','same','Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding','same','Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')
    softmaxLayer('Name','softmax')];
lgraph = layerGraph(layers);

層グラフから dlnetwork オブジェクトを作成します。

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}

モデル勾配関数の定義

例の最後にリストされている関数 modelGradients を作成します。この関数は、dlnetwork オブジェクト、入力データのミニバッチとそれに対応するラベルを受け取り、ネットワークにおける学習可能なパラメーターについての損失の勾配、および対応する損失を返します。

学習オプションの指定

ミニバッチ サイズを 128 として 10 エポック学習させます。

numEpochs = 10;
miniBatchSize = 128;

SGDM 最適化のオプションを指定します。減衰 0.01 の初期学習率 0.01、およびモーメンタム 0.9 を指定します。

initialLearnRate = 0.01;
decay = 0.01;
momentum = 0.9;

モデルの学習

学習中にイメージのミニバッチを処理および管理するminibatchqueueオブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、ラベルを one-hot 符号化変数に変換します。

  • イメージ データを次元ラベル 'SSCB' (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。書式をクラス ラベルに追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(augimdsTrain,...
    'MiniBatchSize',miniBatchSize,...
    'MiniBatchFcn',@preprocessMiniBatch,...
    'MiniBatchFormat',{'SSCB',''});

学習の進行状況プロットを初期化します。

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

SGDM ソルバーの速度パラメーターを初期化します。

velocity = [];

カスタム学習ループを使用してネットワークに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各ミニバッチで次を行います。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配、状態、および損失を評価し、ネットワークの状態を更新。

  • 時間ベースの減衰学習率スケジュールの学習率を決定。

  • 関数 sgdmupdate を使用してネットワーク パラメーターを更新。

  • 学習の進行状況を表示。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    % Shuffle data.
    shuffle(mbq);
    
    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        [dlX, dlY] = next(mbq);
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function and update the network state.
        [gradients,state,loss] = dlfeval(@modelGradients,dlnet,dlX,dlY);
        dlnet.State = state;
        
        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);
        
        % Update the network parameters using the SGDM optimizer.
        [dlnet,velocity] = sgdmupdate(dlnet,gradients,velocity,learnRate,momentum);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

モデルのテスト

真のラベルをもつ検証セットで予測を比較し、モデルの分類精度をテストします。

学習後に新しいデータについて予測を行う際、ラベルは必要ありません。テスト データの予測子のみを含む minibatchqueue オブジェクトを作成します。

  • テスト用のラベルを無視するには、ミニバッチ キューの出力数を 1 に設定します。

  • 学習に使用されるサイズと同じミニバッチ サイズを指定します。

  • 例の最後にリストされている関数 preprocessMiniBatchPredictors を使用して予測子を前処理します。

  • データストアの単一の出力では、ミニバッチの形式 'SSCB' (spatial、spatial、channel、batch) を指定します。

numOutputs = 1;
mbqTest = minibatchqueue(augimdsValidation,numOutputs, ...
    'MiniBatchSize',miniBatchSize, ...
    'MiniBatchFcn',@preprocessMiniBatchPredictors, ...
    'MiniBatchFormat','SSCB');

例の最後にリストされている関数 modelPredictions を使用して、ミニバッチをループ処理し、イメージを分類します。

predictions = modelPredictions(dlnet,mbqTest,classes);

分類精度を評価します。

YTest = imdsValidation.Labels;
accuracy = mean(predictions == YTest)
accuracy = 0.9530

モデル勾配関数

関数 modelGradients は、dlnetwork オブジェクト dlnet、入力データ dlX のミニバッチとそれに対応するラベル Y を受け取り、dlnet における学習可能なパラメーターについての損失の勾配、ネットワークの状態、および損失を返します。勾配を自動的に計算するには、関数 dlgradient を使用します。

function [gradients,state,loss] = modelGradients(dlnet,dlX,Y)

[dlYPred,state] = forward(dlnet,dlX);

loss = crossentropy(dlYPred,Y);
gradients = dlgradient(loss,dlnet.Learnables);

loss = double(gather(extractdata(loss)));

end

モデル予測関数

関数 modelPredictions は、dlnetwork オブジェクト dlnet、入力データ mbqminibatchqueue、ネットワーク クラスを受け取り、minibatchqueue オブジェクトに含まれるすべてのデータを反復処理することによってモデル予測を計算します。この関数は、関数 onehotdecode を使用して、スコアが最も高い予測されたクラスを見つけます。

function predictions = modelPredictions(dlnet,mbq,classes)

predictions = [];

while hasdata(mbq)
    
    dlXTest = next(mbq);
    dlYPred = predict(dlnet,dlXTest);
    
    YPred = onehotdecode(dlYPred,classes,1)';
    
    predictions = [predictions; YPred];
end

end

ミニ バッチ前処理関数

関数 preprocessMiniBatch は、次の手順を使用して予測子とラベルのミニバッチを前処理します。

  1. 関数 preprocessMiniBatchPredictors を使用してイメージを前処理します。

  2. 入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結します。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,Y] = preprocessMiniBatch(XCell,YCell)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。グレースケール入力では、4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

end

参考

| | | | | | | | |

関連するトピック