Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

予測用の多出力ネットワークの組み立て

この例では、予測用の複数の出力ネットワークを組み立てる方法を説明します。

予測用の dlnetwork オブジェクトを使用する代わりに、関数 assembleNetwork を使用して、予測の準備が整っている DAGNetwork 内にネットワークを組み立てることができます。これにより、データストアなどの他のデータ型で関数 predict を使用できます。

モデル関数とパラメーターの読み込み

MAT ファイル dlnetDigits.mat からモデル パラメーターを読み込みます。MAT ファイルには、カテゴリカル ラベルのスコアと数字イメージの回転角度の両方を予測する dlnetwork オブジェクト、ならびにクラス名が含まれます。

s = load("dlnetDigits.mat");
dlnet = s.dlnet;
classNames = s.classNames;

予測用のネットワークの組み立て

関数 layerGraph を使用して、dlnetwork オブジェクトから層グラフを抽出します。

lgraph = layerGraph(dlnet);

層グラフに出力層は含まれません。関数 addLayers および connectLayers を使用し、層グラフに分類層と回帰層を追加します。

layers = classificationLayer('Classes',classNames,'Name','coutput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'softmax','coutput');

layers = regressionLayer('Name','routput');
lgraph = addLayers(lgraph,layers);
lgraph = connectLayers(lgraph,'fc2','routput');

ネットワークのプロットを表示します。

figure
plot(lgraph)

Figure contains an axes. The axes contains an object of type graphplot.

関数 assembleNetwork を使用してネットワークを組み立てます。

net = assembleNetwork(lgraph)
net = 
  DAGNetwork with properties:

         Layers: [19x1 nnet.cnn.layer.Layer]
    Connections: [19x2 table]
     InputNames: {'in'}
    OutputNames: {'coutput'  'routput'}

新しいデータの予測の実行

テスト データを読み込みます。

[XTest,Y1Test,Y2Test] = digitTest4DArrayData;

組み立てたネットワークを使用して予測を行うには、関数 predict を使用します。分類出力のカテゴリカル ラベルを返すには、'ReturnCategorical' オプションを true に設定します。

[Y1Pred,Y2Pred] = predict(net,XTest,'ReturnCategorical',true);

分類精度を評価します。

accuracy = mean(Y1Pred==Y1Test)
accuracy = 0.9870

回帰精度を評価します。

angleRMSE = sqrt(mean((Y2Pred - Y2Test).^2))
angleRMSE = single
    6.0091

一部のイメージと、その予測を表示します。予測角度を赤、正解ラベルを緑で表示します。

idx = randperm(size(XTest,4),9);
figure
for i = 1:9
    subplot(3,3,i)
    I = XTest(:,:,:,idx(i));
    imshow(I)
    hold on
    
    sz = size(I,1);
    offset = sz/2;
    
    thetaPred = Y2Pred(idx(i));
    plot(offset*[1-tand(thetaPred) 1+tand(thetaPred)],[sz 0],'r--')
    
    thetaValidation = Y2Test(idx(i));
    plot(offset*[1-tand(thetaValidation) 1+tand(thetaValidation)],[sz 0],'g--')
    
    hold off
    label = string(Y1Pred(idx(i)));
    title("Label: " + label)
end

Figure contains 9 axes. Axes 1 with title Label: 8 contains 3 objects of type image, line. Axes 2 with title Label: 9 contains 3 objects of type image, line. Axes 3 with title Label: 1 contains 3 objects of type image, line. Axes 4 with title Label: 9 contains 3 objects of type image, line. Axes 5 with title Label: 6 contains 3 objects of type image, line. Axes 6 with title Label: 0 contains 3 objects of type image, line. Axes 7 with title Label: 2 contains 3 objects of type image, line. Axes 8 with title Label: 5 contains 3 objects of type image, line. Axes 9 with title Label: 9 contains 3 objects of type image, line.

参考

| | | | | |

関連するトピック