dlnetwork
オブジェクトを使用した予測の実行
この例では、ミニバッチをループ処理することにより、dlnetwork
オブジェクトを使用して予測を行う方法を示します。
データ セットが大きい場合、またはメモリが限られたハードウェアで予測を行う場合、関数 minibatchpredict
を使用し、データのミニバッチをループ処理して予測を行います。
dlnetwork
オブジェクトの読み込み
学習済みの dlnetwork
オブジェクトと対応するクラス名をロードします。このニューラル ネットワークには 1 つの入力と 2 つの出力があります。手書きの数字のイメージを入力として受け取り、数字のラベルと回転角度を予測します。
load dlnetDigits
予測用データの読み込み
予測用の数字テスト データを読み込みます。
load DigitsDataTest
クラス名を表示します。
classNames
classNames = 10x1 cell
{'0'}
{'1'}
{'2'}
{'3'}
{'4'}
{'5'}
{'6'}
{'7'}
{'8'}
{'9'}
いくつかのイメージと、それに対応するラベルと回転角度を表示します。
numObservations = size(XTest,4); numPlots = 9; idx = randperm(numObservations,numPlots); figure for i = 1:numPlots nexttile(i) I = XTest(:,:,:,idx(i)); label = labelsTest(idx(i)); imshow(I) title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i))) end
予測の実行
関数minibatchpredict
を使用して予測を行い、関数scores2label
を使用して分類スコアをラベルに変換します。既定では、関数 minibatchpredict
は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment
オプションを使用します。
[scoresTest,Y2Test] = minibatchpredict(net,XTest); Y1Test = scores2label(scoresTest,classNames);
予測の一部を可視化します。
idx = randperm(numObservations,numPlots); figure for i = 1:numPlots nexttile(i) I = XTest(:,:,:,idx(i)); label = Y1Test(idx(i)); imshow(I) title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i))) end
参考
dlarray
| dlnetwork
| predict
| minibatchqueue
| onehotdecode