メインコンテンツ

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

dlnetwork オブジェクトを使用した予測の実行

この例では、ミニバッチをループ処理することにより、dlnetwork オブジェクトを使用して予測を行う方法を示します。

データ セットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、関数 minibatchpredict を使用し、データのミニバッチをループ処理して予測を行います。

dlnetwork オブジェクトの読み込み

学習済みの dlnetwork オブジェクトおよび対応するクラス名をワークスペースに読み込みます。このニューラル ネットワークには 1 つの入力と 2 つの出力があります。手書きの数字のイメージを入力として受け取り、数字のラベルと回転角度を予測します。

load dlnetDigits

予測用データの読み込み

予測用の数字テスト データを読み込みます。

load DigitsDataTest

クラス名を表示します。

classNames
classNames = 10×1 cell
    {'0'}
    {'1'}
    {'2'}
    {'3'}
    {'4'}
    {'5'}
    {'6'}
    {'7'}
    {'8'}
    {'9'}

いくつかのイメージと、それに対応するラベルと回転角度を表示します。

numObservations = size(XTest,4);
numPlots = 9;
idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = labelsTest(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i)))
end

Figure contains 9 axes objects. Hidden axes object 1 with title Label: 8 Angle: 5 contains an object of type image. Hidden axes object 2 with title Label: 9 Angle: -45 contains an object of type image. Hidden axes object 3 with title Label: 1 Angle: -11 contains an object of type image. Hidden axes object 4 with title Label: 9 Angle: -40 contains an object of type image. Hidden axes object 5 with title Label: 6 Angle: -42 contains an object of type image. Hidden axes object 6 with title Label: 0 Angle: -18 contains an object of type image. Hidden axes object 7 with title Label: 2 Angle: -9 contains an object of type image. Hidden axes object 8 with title Label: 5 Angle: -17 contains an object of type image. Hidden axes object 9 with title Label: 9 Angle: -27 contains an object of type image.

予測の実行

minibatchpredict関数を使用して予測を行い、scores2label関数を使用して分類スコアをラベルに変換します。既定では、関数 minibatchpredict は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を手動で選択するには、minibatchpredict 関数の ExecutionEnvironment 引数を使用します。

[scoresTest,Y2Test] = minibatchpredict(net,XTest);
Y1Test = scores2label(scoresTest,classNames);

予測の一部を可視化します。

idx = randperm(numObservations,numPlots);

figure
for i = 1:numPlots
    nexttile(i)
    I = XTest(:,:,:,idx(i));
    label = Y1Test(idx(i));
    imshow(I)
    title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i)))
end

Figure contains 9 axes objects. Hidden axes object 1 with title Label: 9 Angle: 20.3954 contains an object of type image. Hidden axes object 2 with title Label: 1 Angle: 3.7015 contains an object of type image. Hidden axes object 3 with title Label: 9 Angle: 23.5494 contains an object of type image. Hidden axes object 4 with title Label: 9 Angle: -36.4954 contains an object of type image. Hidden axes object 5 with title Label: 4 Angle: 16.428 contains an object of type image. Hidden axes object 6 with title Label: 7 Angle: 3.0644 contains an object of type image. Hidden axes object 7 with title Label: 1 Angle: 33.1356 contains an object of type image. Hidden axes object 8 with title Label: 4 Angle: 30.7531 contains an object of type image. Hidden axes object 9 with title Label: 9 Angle: 0.55887 contains an object of type image.

参考

| | | |

トピック