Main Content

イメージ データおよび特徴データにおけるネットワークの学習

この例では、イメージと特徴の両方の入力データを使用して、手書きの数字を分類するネットワークの学習を行う方法について説明します。

学習データの読み込み

数字のイメージ、ラベル、時計回りの回転角度を読み込みます。

[X1Train,TTrain,X2Train] = digitTrain4DArrayData;

関数 trainNetwork を使用して複数の入力をもつネットワークに学習させるには、学習予測子と応答を含む単一のデータストアを作成します。数値配列をデータストアに変換するには、arrayDatastore を使用します。次に、関数 combine を使用し、それらを組み合わせて単一のデータストアにします。

dsX1Train = arrayDatastore(X1Train,IterationDimension=4);
dsX2Train = arrayDatastore(X2Train);
dsTTrain = arrayDatastore(TTrain);
dsTrain = combine(dsX1Train,dsX2Train,dsTTrain);

ランダムに選ばれた 20 枚の学習イメージを表示します。

numObservationsTrain = numel(TTrain);
idx = randperm(numObservationsTrain,20);

figure
tiledlayout("flow");
for i = 1:numel(idx)
    nexttile
    imshow(X1Train(:,:,:,idx(i)))
    title("Angle: " + X2Train(idx(i)))
end

ネットワーク アーキテクチャの定義

次のネットワークを定義します。

  • イメージ入力用に、入力データと一致するサイズのイメージ入力層を指定します。

  • 特徴入力用に、入力特徴の数と一致するサイズの特徴入力層を指定します。

  • イメージ入力分岐用に、畳み込み層、バッチ正規化層、および ReLU 層のブロックを指定します。ここで畳み込み層は 5 行 5 列のフィルターを 16 個もちます。

  • バッチ正規化層の出力を特徴ベクトルに変換するために、サイズが 50 の全結合層を含めます。

  • 最初の全結合層の出力を特徴入力と連結するために、フラット化層を使用し、全結合層の "SSCB"(spatial、spatial、channel、batch) 出力をフラットにして "CB" の形式にします。

  • フラット化層の出力を、最初の次元 (チャネル次元) に沿って特徴入力と連結します。

  • 分類出力用に、クラスの数に一致する出力サイズの全結合層と、それに続くソフトマックス層と分類出力層を含めます。

ネットワークの主分岐を含む層配列を作成し、その配列を層グラフに変換します。

[h,w,numChannels,numObservations] = size(X1Train);
numFeatures = 1;
numClasses = numel(categories(TTrain));

imageInputSize = [h w numChannels];
filterSize = 5;
numFilters = 16;

layers = [
    imageInputLayer(imageInputSize,Normalization="none")
    convolution2dLayer(filterSize,numFilters)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(50)
    flattenLayer
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

lgraph = layerGraph(layers);

層グラフに特徴入力層を追加し、それを連結層の 2 番目の入力に結合します。

featInput = featureInputLayer(numFeatures,Name="features");
lgraph = addLayers(lgraph,featInput);
lgraph = connectLayers(lgraph,"features","cat/in2");

ネットワークをプロットで可視化します。

figure
plot(lgraph)

学習オプションの指定

学習オプションを指定します。

  • SGDM オプティマイザーを使用して学習させます。

  • 学習を 15 エポック行います。

  • 学習率を 0.01 にして学習を行います。

  • 学習の進行状況をプロットに表示。

  • 詳細出力を非表示にします。

options = trainingOptions("sgdm", ...
    MaxEpochs=15, ...
    InitialLearnRate=0.01, ...
    Plots="training-progress", ...
    Verbose=0);

ネットワークの学習

関数 trainNetwork を使用してネットワークに学習させます。

net = trainNetwork(dsTrain,lgraph,options);

ネットワークのテスト

真のラベルをもつテスト セットで予測を比較して、ネットワークの分類精度をテストします。

テスト データを読み込み、イメージと特徴を含む結合されたデータストアを作成します。

[X1Test,TTest,X2Test] = digitTest4DArrayData;
dsX1Test = arrayDatastore(X1Test,IterationDimension=4);
dsX2Test = arrayDatastore(X2Test);
dsTest = combine(dsX1Test,dsX2Test);

関数 classify を使用してテスト データを分類します。

YTest = classify(net,dsTest);

混同チャートで予測を可視化します。

figure
confusionchart(TTest,YTest)

分類精度を評価します。

accuracy = mean(YTest == TTest)
accuracy = 0.9834

一部のイメージと、その予測を表示します。

idx = randperm(size(X1Test,4),9);
figure
tiledlayout(3,3)
for i = 1:9
    nexttile
    I = X1Test(:,:,:,idx(i));
    imshow(I)

    label = string(YTest(idx(i)));
    title("Predicted Label: " + label)
end

参考

| | | | | | | |

関連する例

詳細