このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。
カスタム学習ループを使用したネットワークの学習
この例では、カスタム学習率スケジュールで手書きの数字を分類するネットワークに学習させる方法を示します。
ほとんどのタイプのニューラル ネットワークは、関数trainNetwork
と関数trainingOptions
を使用して学習させることができます。必要なオプション (たとえば、カスタム学習率スケジュール) が関数 trainingOptions
に用意されていない場合、自動微分のために dlarray
オブジェクトと dlnetwork
オブジェクトを使用して独自のカスタム学習ループを定義できます。関数 trainNetwork
を使用して事前学習済みの深層学習ネットワークに再学習させる方法を示す例については、事前学習済みのネットワークを使用した転移学習を参照してください。
深層ニューラル ネットワークの学習は最適化タスクです。ニューラル ネットワークを関数 と見なすことにより (ここで、 はネットワーク入力、 は学習可能なパラメーターのセット)、 を最適化して学習データに基づく損失値を最小化できます。たとえば、与えらえた入力 と対応するターゲット に対して、予測 と の間の誤差が最小になるように、学習可能なパラメーター を最適化します。
使用される損失関数は、タスクのタイプによって異なります。次に例を示します。
分類タスクの場合、予測とターゲットの間のクロス エントロピー誤差を最小化できます。
回帰タスクの場合、予測とターゲットの間の平均二乗誤差を最小化できます。
勾配降下法を使用して目的を最適化できます。勾配降下法では、学習可能なパラメーターについての損失の勾配を使用して最小化に向けてステップを実行することで、学習可能なパラメーター を反復的に更新することにより、損失 を最小化します。勾配降下法アルゴリズムは通常、 の形式の更新ステップのバリアントを使用して、学習可能なパラメーターを更新します。ここで、 は反復回数、 は学習率、 は勾配 (学習可能なパラメーターについての損失の微分) を表します。
この例では、"時間ベースの減衰" 学習率スケジュールで手書きの数字を分類するようにネットワークに学習させます。各反復で、ソルバーは によって与えられる学習率を使用します。ここで、t は反復回数、 は初期学習率、k は減衰です。
学習データの読み込み
関数 imageDatastore
を使用して数字データをイメージ データストアとして読み込み、イメージ データが格納されているフォルダーを指定します。
dataFolder = fullfile(toolboxdir("nnet"),"nndemos","nndatasets","DigitDataset"); imds = imageDatastore(dataFolder, ... IncludeSubfolders=true, .... LabelSource="foldernames");
データを学習セットと検証セットに分割します。関数 splitEachLabel
を使用して、データの 10% を検証用に残しておきます。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9,"randomize");
この例で使用されるネットワークには、サイズが 28 x 28 x 1 の入力イメージが必要です。学習イメージのサイズを自動的に変更するには、拡張イメージ データストアを使用します。学習イメージに対して実行する追加の拡張演算として、イメージを水平軸方向および垂直軸方向に最大 5 ピクセルだけランダムに平行移動させる演算を指定します。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。
inputSize = [28 28 1]; pixelRange = [-5 5]; imageAugmenter = imageDataAugmenter( ... RandXTranslation=pixelRange, ... RandYTranslation=pixelRange); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=imageAugmenter);
他のデータ拡張を実行せずに検証イメージのサイズを自動的に変更するには、追加の前処理演算を指定せずに拡張イメージ データストアを使用します。
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
学習データ内のクラスの数を決定します。
classes = categories(imdsTrain.Labels); numClasses = numel(classes);
ネットワークの定義
イメージ分類のネットワークを定義します。
イメージ入力用に、学習データと一致する入力サイズのイメージ入力層を指定します。
イメージ入力は正規化せず、入力層の
Normalization
オプションを"none"
に設定します。convolution-batchnorm-ReLU ブロックを 3 つ指定します。
Padding
オプションを"same"
に設定して、出力が同じサイズになるように畳み込み層に入力をパディングします。最初の畳み込み層には、サイズ 5 のフィルターを 20 個指定します。残りの畳み込み層には、サイズ 3 のフィルターを 20 個指定します。
分類用に、クラスの数と一致するサイズの全結合層を指定します。
出力を確率にマッピングするために、ソフトマックス層を含めます。
カスタム学習ループを使用してネットワークに学習させる場合、出力層は含めません。
layers = [ imageInputLayer(inputSize,Normalization="none") convolution2dLayer(5,20,Padding="same") batchNormalizationLayer reluLayer convolution2dLayer(3,20,Padding="same") batchNormalizationLayer reluLayer convolution2dLayer(3,20,Padding="same") batchNormalizationLayer reluLayer fullyConnectedLayer(numClasses) softmaxLayer];
層配列から dlnetwork
オブジェクトを作成します。
net = dlnetwork(layers)
net = dlnetwork with properties: Layers: [12×1 nnet.cnn.layer.Layer] Connections: [11×2 table] Learnables: [14×3 table] State: [6×3 table] InputNames: {'imageinput'} OutputNames: {'softmax'} Initialized: 1 View summary with summary.
モデル損失関数の定義
深層ニューラル ネットワークの学習は最適化タスクです。ニューラル ネットワークを関数 と見なすことにより (ここで、 はネットワーク入力、 は学習可能なパラメーターのセット)、 を最適化して学習データに基づく損失値を最小化できます。たとえば、与えらえた入力 と対応するターゲット に対して、予測 と の間の誤差が最小になるように、学習可能なパラメーター を最適化します。
例のモデル損失関数の節にリストされている関数 modelLoss
を作成します。この関数は、dlnetwork
オブジェクト、入力データのミニバッチと対応するターゲットを入力として受け取り、学習可能なパラメーターについての損失と損失の勾配、およびネットワークの状態を返します。
学習オプションの指定
ミニバッチ サイズを 128 として 10 エポック学習させます。
numEpochs = 10; miniBatchSize = 128;
SGDM 最適化のオプションを指定します。減衰 0.01 の初期学習率 0.01、およびモーメンタム 0.9 を指定します。
initialLearnRate = 0.01; decay = 0.01; momentum = 0.9;
モデルの学習
学習中にイメージのミニバッチを処理および管理するminibatchqueue
オブジェクトを作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、ラベルを one-hot 符号化変数に変換します。イメージ データを次元ラベル
"SSCB"
(spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。クラス ラベルは書式設定しません。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、各出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
mbq = minibatchqueue(augimdsTrain,... MiniBatchSize=miniBatchSize,... MiniBatchFcn=@preprocessMiniBatch,... MiniBatchFormat=["SSCB" ""]);
SGDM ソルバーの速度パラメーターを初期化します。
velocity = [];
学習の進行状況モニター用に合計反復回数を計算します。
numObservationsTrain = numel(imdsTrain.Files); numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize); numIterations = numEpochs * numIterationsPerEpoch;
TrainingProgressMonitor
オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor(Metrics="Loss",Info=["Epoch","LearnRate"],XLabel="Iteration");
カスタム学習ループを使用してネットワークに学習させます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各ミニバッチで次を行います。
関数
dlfeval
と関数modelLoss
を使用してモデルの損失、勾配、および状態を評価し、ネットワークの状態を更新。時間ベースの減衰学習率スケジュールの学習率を決定。
関数
sgdmupdate
を使用してネットワーク パラメーターを更新。学習の進行状況モニターで、損失、学習率、およびエポックの値を更新。
Stop プロパティが true の場合は停止。[停止] ボタンをクリックすると、
TrainingProgressMonitor
オブジェクトの Stop プロパティ値が true に変わります。
epoch = 0; iteration = 0; % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) && ~monitor.Stop iteration = iteration + 1; % Read mini-batch of data. [X,T] = next(mbq); % Evaluate the model gradients, state, and loss using dlfeval and the % modelLoss function and update the network state. [loss,gradients,state] = dlfeval(@modelLoss,net,X,T); net.State = state; % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the SGDM optimizer. [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum); % Update the training progress monitor. recordMetrics(monitor,iteration,Loss=loss); updateInfo(monitor,Epoch=epoch,LearnRate=learnRate); monitor.Progress = 100 * iteration/numIterations; end end
モデルのテスト
真のラベルをもつ検証セットで予測を比較し、モデルの分類精度をテストします。
学習後に新しいデータについて予測を行う際、ラベルは必要ありません。テスト データの予測子のみを含む minibatchqueue
オブジェクトを作成します。
テスト用のラベルを無視するには、ミニバッチ キューの出力数を 1 に設定します。
学習に使用されるサイズと同じミニバッチ サイズを指定します。
例の最後にリストされている関数
preprocessMiniBatchPredictors
を使用して予測子を前処理します。データストアの単一の出力では、ミニバッチの形式
"SSCB"
(spatial、spatial、channel、batch) を指定します。
numOutputs = 1; mbqTest = minibatchqueue(augimdsValidation,numOutputs, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=@preprocessMiniBatchPredictors, ... MiniBatchFormat="SSCB");
例の最後にリストされている関数 modelPredictions
を使用して、ミニバッチをループ処理し、イメージを分類します。
YTest = modelPredictions(net,mbqTest,classes);
分類精度を評価します。
TTest = imdsValidation.Labels; accuracy = mean(TTest == YTest)
accuracy = 0.9750
混同チャートで予測を可視化します。
figure confusionchart(TTest,YTest)
対角線上の大きな値は、対応するクラスに対する正確な予測を示しています。対角線外の大きな値は、対応するクラス間での強い混同を示しています。
サポート関数
モデル損失関数
関数 modelLoss
は、dlnetwork
オブジェクト net
、入力データのミニバッチ X
とそれに対応するターゲット T
を受け取り、net
内の学習可能パラメーターについての損失と損失の勾配、およびネットワークの状態を返します。勾配を自動的に計算するには、関数 dlgradient
を使用します。
function [loss,gradients,state] = modelLoss(net,X,T) % Forward data through network. [Y,state] = forward(net,X); % Calculate cross-entropy loss. loss = crossentropy(Y,T); % Calculate gradients of loss with respect to learnable parameters. gradients = dlgradient(loss,net.Learnables); end
モデル予測関数
関数 modelPredictions
は、dlnetwork
オブジェクト net
、入力データの minibatchqueue
オブジェクト mbq
、およびネットワーク クラスを受け取り、minibatchqueue
オブジェクトに含まれるすべてのデータを反復処理することによってモデル予測を計算します。この関数は、関数 onehotdecode
を使用して、スコアが最も高い予測されたクラスを見つけます。
function Y = modelPredictions(net,mbq,classes) Y = []; % Loop over mini-batches. while hasdata(mbq) X = next(mbq); % Make prediction. scores = predict(net,X); % Decode labels and append to output. labels = onehotdecode(scores,classes,1)'; Y = [Y; labels]; end end
ミニ バッチ前処理関数
関数 preprocessMiniBatch
は、次の手順を使用して予測子とラベルのミニバッチを前処理します。
関数
preprocessMiniBatchPredictors
を使用してイメージを前処理します。入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結します。
カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。
function [X,T] = preprocessMiniBatch(dataX,dataT) % Preprocess predictors. X = preprocessMiniBatchPredictors(dataX); % Extract label data from cell and concatenate. T = cat(2,dataT{1:end}); % One-hot encode labels. T = onehotencode(T,1); end
ミニバッチ予測子前処理関数
関数 preprocessMiniBatchPredictors
は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。グレースケール入力では、4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。
function X = preprocessMiniBatchPredictors(dataX) % Concatenate. X = cat(4,dataX{1:end}); end
参考
trainingProgressMonitor
| dlarray
| dlgradient
| dlfeval
| dlnetwork
| forward
| adamupdate
| predict
| minibatchqueue
| onehotencode
| onehotdecode