Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

dlnetwork オブジェクトを使用した予測の実行

この例では、データをミニバッチに分割することにより、dlnetwork オブジェクトを使用して予測を行う方法を示します。

データセットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、データをミニバッチに分割して予測を行います。SeriesNetwork または DAGNetwork オブジェクトで予測を行う場合、関数 predict は入力データをミニバッチに自動的に分割します。dlnetwork オブジェクトでは、データを手動でミニバッチに分割しなければなりません。

dlnetwork オブジェクトの読み込み

学習済み dlnetwork オブジェクトと対応するクラスを読み込みます。

s = load("digitsCustom.mat");
dlnet = s.dlnet;
classes = s.classes;

予測用データの読み込み

予測用の数字データを読み込みます。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true);

予測の実行

テスト データのミニバッチをループ処理して、カスタム予測ループを使って予測を行います。

minibatchqueueを使用して、イメージのミニバッチを処理および管理します。ミニバッチ サイズとして 128 を指定します。イメージ データストアの読み取りサイズ プロパティをミニバッチ サイズに設定します。

各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、データをバッチに連結し、イメージを正規化。

  • イメージを次元 'SSCB' (spatial、spatial、channel、batch) で書式設定。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。

  • GPU が利用できる場合、GPU で予測を実行。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

miniBatchSize = 128;
imds.ReadSize = miniBatchSize;

mbq = minibatchqueue(imds,...
    "MiniBatchSize",miniBatchSize,...
    "MiniBatchFcn", @preprocessMiniBatch,...
    "MiniBatchFormat","SSCB");

データのミニバッチをループ処理し、関数 predict を使用して予測を行います。関数 onehotdecode を使用して、クラス ラベルを決定します。予測クラス ラベルを保存します。

numObservations = numel(imds.Files);
YPred = strings(1,numObservations);

predictions = [];

% Loop over mini-batches.
while hasdata(mbq)
    
    % Read mini-batch of data.
    dlX = next(mbq);
       
    % Make predictions using the predict function.
    dlYPred = predict(dlnet,dlX);
   
    % Determine corresponding classes.
    predBatch = onehotdecode(dlYPred,classes,1);
    predictions = [predictions predBatch];
  
end

予測の一部を可視化します。

idx = randperm(numObservations,9);

figure
for i = 1:9
    subplot(3,3,i)
    I = imread(imds.Files{idx(i)});    
    label = predictions(idx(i));
    imshow(I)
    title("Label: " + string(label))
  
end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からデータを抽出し、数値配列に連結します。4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。

  2. 01 の間のピクセル値を正規化します。

function X = preprocessMiniBatch(data)    
    % Extract image data from cell and concatenate
    X = cat(4,data{:});
    
    % Normalize the images.
    X = X/255;
end

参考

| | | |

関連するトピック