モデル関数を使用した予測の実行
この例では、データをミニバッチに分割することにより、モデル関数を使用して予測を行う方法を示します。
データセットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、データをミニバッチに分割して予測を行います。SeriesNetwork
または DAGNetwork
オブジェクトで予測を行う場合、関数 predict
は入力データをミニバッチに自動的に分割します。モデル関数では、データをミニバッチに手動で分割しなければなりません。
モデル関数の作成とパラメーターの読み込み
MAT ファイル digitsMIMO.mat
からモデル パラメーターを読み込みます。この MAT ファイルは、parameters
という名前の構造体にモデル パラメーター、state
という名前の構造体にモデルの状態、classNames
にクラス名を格納しています。
s = load("digitsMIMO.mat");
parameters = s.parameters;
state = s.state;
classNames = s.classNames;
例の最後にリストされているモデル関数 model
は、与えられたモデルのパラメーターと状態に基づいてモデルを定義します。
予測用データの読み込み
予測用の数字データを読み込みます。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); numObservations = numel(imds.Files);
予測の実行
テスト データのミニバッチをループ処理して、カスタム予測ループを使って予測を行います。
minibatchqueue
を使用して、イメージのミニバッチを処理および管理します。ミニバッチ サイズとして 128 を指定します。イメージ データストアの読み取りサイズ プロパティをミニバッチ サイズに設定します。
各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、データをバッチに連結し、イメージを正規化。イメージを次元
'SSCB'
(spatial、spatial、channel、batch) で書式設定。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で予測を実行。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
miniBatchSize = 128; imds.ReadSize = miniBatchSize; mbq = minibatchqueue(imds,... "MiniBatchSize",miniBatchSize,... "MiniBatchFcn", @preprocessMiniBatch,... "MiniBatchFormat","SSCB");
データのミニバッチをループ処理し、関数 predict
を使用して予測を行います。関数 onehotdecode
を使用して、クラス ラベルを決定します。予測クラス ラベルを保存します。
doTraining = false; Y1Predictions = []; Y2Predictions = []; % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. dlX = next(mbq); % Make predictions using the predict function. [dlY1Pred,dlY2Pred] = model(parameters,dlX,doTraining,state); % Determine corresponding classes. Y1PredBatch = onehotdecode(dlY1Pred,classNames,1); Y1Predictions = [Y1Predictions Y1PredBatch]; Y2PredBatch = extractdata(dlY2Pred); Y2Predictions = [Y2Predictions Y2PredBatch]; end
一部のイメージと、その予測を表示します。
idx = randperm(numObservations,9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); imshow(I) hold on sz = size(I,1); offset = sz/2; thetaPred = Y2Predictions(idx(i)); plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--') hold off label = string(Y1Predictions(idx(i))); title("Label: " + label) end
モデル関数
関数 model
は、モデル パラメーター parameters
、入力データ dlX
、モデルが学習と予測のどちらの出力を返すべきかを指定するフラグ doTraining
、およびネットワークの状態 state
を受け取ります。ネットワークはラベルの予測、角度の予測、および更新されたネットワークの状態を出力します。
function [dlY1,dlY2,state] = model(parameters,dlX,doTraining,state) % Convolution weights = parameters.conv1.Weights; bias = parameters.conv1.Bias; dlY = dlconv(dlX,weights,bias,'Padding','same'); % Batch normalization, ReLU offset = parameters.batchnorm1.Offset; scale = parameters.batchnorm1.Scale; trainedMean = state.batchnorm1.TrainedMean; trainedVariance = state.batchnorm1.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm1.TrainedMean = trainedMean; state.batchnorm1.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution, batch normalization (Skip connection) weights = parameters.convSkip.Weights; bias = parameters.convSkip.Bias; dlYSkip = dlconv(dlY,weights,bias,'Stride',2); offset = parameters.batchnormSkip.Offset; scale = parameters.batchnormSkip.Scale; trainedMean = state.batchnormSkip.TrainedMean; trainedVariance = state.batchnormSkip.TrainedVariance; if doTraining [dlYSkip,trainedMean,trainedVariance] = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance); % Update state state.batchnormSkip.TrainedMean = trainedMean; state.batchnormSkip.TrainedVariance = trainedVariance; else dlYSkip = batchnorm(dlYSkip,offset,scale,trainedMean,trainedVariance); end % Convolution weights = parameters.conv2.Weights; bias = parameters.conv2.Bias; dlY = dlconv(dlY,weights,bias,'Padding','same','Stride',2); % Batch normalization, ReLU offset = parameters.batchnorm2.Offset; scale = parameters.batchnorm2.Scale; trainedMean = state.batchnorm2.TrainedMean; trainedVariance = state.batchnorm2.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm2.TrainedMean = trainedMean; state.batchnorm2.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end dlY = relu(dlY); % Convolution weights = parameters.conv3.Weights; bias = parameters.conv3.Bias; dlY = dlconv(dlY,weights,bias,'Padding','same'); % Batch normalization offset = parameters.batchnorm3.Offset; scale = parameters.batchnorm3.Scale; trainedMean = state.batchnorm3.TrainedMean; trainedVariance = state.batchnorm3.TrainedVariance; if doTraining [dlY,trainedMean,trainedVariance] = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); % Update state state.batchnorm3.TrainedMean = trainedMean; state.batchnorm3.TrainedVariance = trainedVariance; else dlY = batchnorm(dlY,offset,scale,trainedMean,trainedVariance); end % Addition, ReLU dlY = dlYSkip + dlY; dlY = relu(dlY); % Fully connect, softmax (labels) weights = parameters.fc1.Weights; bias = parameters.fc1.Bias; dlY1 = fullyconnect(dlY,weights,bias); dlY1 = softmax(dlY1); % Fully connect (angles) weights = parameters.fc2.Weights; bias = parameters.fc2.Bias; dlY2 = fullyconnect(dlY,weights,bias); end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からデータを抽出し、数値配列に連結します。4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。
0
と1
の間のピクセル値を正規化します。
function X = preprocessMiniBatch(data) % Extract image data from cell and concatenate X = cat(4,data{:}); % Normalize the images. X = X/255; end
参考
dlarray
| dlgradient
| dlfeval
| sgdmupdate
| dlconv
| batchnorm
| relu
| fullyconnect
| softmax
| minibatchqueue
| onehotdecode