Main Content

dlnetwork オブジェクトを使用した予測の実行

この例では、ミニバッチをループ処理することにより、dlnetwork オブジェクトを使用して予測を行う方法を示します。

データ セットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、関数 minibatchpredict を使用し、データのミニバッチをループ処理して予測を行います。

dlnetwork オブジェクトの読み込み

学習済みの dlnetwork オブジェクトと対応するクラス名をロードします。このニューラル ネットワークには 1 つの入力と 2 つの出力があります。手書きの数字のイメージを入力として受け取り、数字のラベルと回転角度を予測します。

load dlnetDigits

予測用データの読み込み

予測用の数字テスト データを読み込みます。

load DigitsDataTest

クラス名を表示します。

classNames
classNames = 10x1 cell
    {'0'}
    {'1'}
    {'2'}
    {'3'}
    {'4'}
    {'5'}
    {'6'}
    {'7'}
    {'8'}
    {'9'}

いくつかのイメージと、それに対応するラベルと回転角度を表示します。

numObservations = size(XTest,4);
numPlots = 9;
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = labelsTest(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i)))
end

予測の実行

関数minibatchpredictを使用して予測を行い、関数scores2labelを使用して分類スコアをラベルに変換します。既定では、関数 minibatchpredict は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment オプションを使用します。

[scoresTest,Y2Test] = minibatchpredict(net,XTest);
Y1Test = scores2label(scoresTest,classNames);

予測の一部を可視化します。

idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = Y1Test(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i)))
end

参考

| | | |

関連するトピック