このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
dlnetwork
オブジェクトを使用した予測の実行
この例では、データをミニバッチに分割することにより、dlnetwork
オブジェクトを使用して予測を行う方法を示します。
データセットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、データをミニバッチに分割して予測を行います。SeriesNetwork
または DAGNetwork
オブジェクトで予測を行う場合、関数 predict
は入力データをミニバッチに自動的に分割します。dlnetwork
オブジェクトでは、データを手動でミニバッチに分割しなければなりません。
dlnetwork
オブジェクトの読み込み
学習済み dlnetwork
オブジェクトと対応するクラスを読み込みます。
s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;
予測用データの読み込み
予測用の数字データを読み込みます。
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ... 'nndatasets','DigitDataset'); imds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders',true);
予測の実行
テスト データのミニバッチをループ処理して、カスタム予測ループを使って予測を行います。
minibatchqueue
を使用して、イメージのミニバッチを処理および管理します。ミニバッチ サイズとして 128 を指定します。イメージ データストアの読み取りサイズ プロパティをミニバッチ サイズに設定します。
各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、データをバッチに連結し、イメージを正規化。イメージを次元
'SSCB'
(spatial、spatial、channel、batch) で書式設定。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で予測を実行。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
miniBatchSize = 128; imds.ReadSize = miniBatchSize; mbq = minibatchqueue(imds,... "MiniBatchSize",miniBatchSize,... "MiniBatchFcn", @preprocessMiniBatch,... "MiniBatchFormat","SSCB");
データのミニバッチをループ処理し、関数 predict
を使用して予測を行います。関数 onehotdecode
を使用して、クラス ラベルを決定します。予測クラス ラベルを保存します。
numObservations = numel(imds.Files); YPred = strings(1,numObservations); predictions = []; % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. dlX = next(mbq); % Make predictions using the predict function. dlYPred = predict(dlnet,dlX); % Determine corresponding classes. predBatch = onehotdecode(dlYPred,classes,1); predictions = [predictions predBatch]; end
予測の一部を可視化します。
idx = randperm(numObservations,9); figure for i = 1:9 subplot(3,3,i) I = imread(imds.Files{idx(i)}); label = predictions(idx(i)); imshow(I) title("Label: " + string(label)) end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からデータを抽出し、数値配列に連結します。4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。
0
と1
の間のピクセル値を正規化します。
function X = preprocessMiniBatch(data) % Extract image data from cell and concatenate X = cat(4,data{:}); % Normalize the images. X = X/255; end
参考
dlarray
| dlnetwork
| predict
| minibatchqueue
| onehotdecode