Main Content

セマンティック セグメンテーション ネットワークの学習

学習データを読み込みます。

dataSetDir = fullfile(toolboxdir("vision"),"visiondata","triangleImages");
imageDir = fullfile(dataSetDir,"trainingImages");
labelDir = fullfile(dataSetDir,"trainingLabels");

イメージのイメージ データストアを作成します。

imds = imageDatastore(imageDir);

グラウンド トゥルース ピクセル ラベル用の pixelLabelDatastore を作成します。

classNames = ["triangle" "background"];
labelIDs   = [255 0];
pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);

学習イメージとグラウンド トゥルース ピクセル ラベルを可視化します。

I = read(imds);
C = read(pxds);

I = imresize(I,5,"nearest");
L = imresize(uint8(C{1}),5,"nearest");
imshowpair(I,L,"montage")

Figure contains an axes object. The axes object contains an object of type image.

イメージ データストアとピクセル ラベル データストアを学習用に統合します。

trainingData = pixelLabelImageDatastore(imds,pxds);

セマンティック セグメンテーション ネットワークを作成します。このネットワークでは、ダウンサンプリングおよびアップサンプリングの設計に基づいてシンプルなセマンティック セグメンテーション ネットワークを使用します。

numFilters = 64;
filterSize = 3;
numClasses = 2;
layers = [
    imageInputLayer([32 32 1])
    convolution2dLayer(filterSize,numFilters,Padding=1)
    reluLayer()
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(filterSize,numFilters,Padding=1)
    reluLayer()
    transposedConv2dLayer(4,numFilters,Stride=2,Cropping=1);
    convolution2dLayer(1,numClasses);
    softmaxLayer()
    ];

学習オプションを設定します。

opts = trainingOptions("sgdm", ...
    InitialLearnRate=1e-3, ...
    MaxEpochs=100, ...
    MiniBatchSize=64);

ピクセル分類に適した損失関数を定義します。

function loss = modelLoss(Y,T)
    mask = ~isnan(T);
    T(isnan(T)) = 0;
    loss = crossentropy(Y,T,Mask=mask,NormalizationFactor="mask-included");
end

ネットワークに学習をさせます。

net = trainnet(trainingData,layers,@modelLoss,opts);
    Iteration    Epoch    TimeElapsed    LearnRate    TrainingLoss
    _________    _____    ___________    _________    ____________
            1        1       00:00:06        0.001          41.892
           50       17       00:00:21        0.001         0.93931
          100       34       00:00:35        0.001          0.7432
          150       50       00:00:52        0.001          0.4558
          200       67       00:01:10        0.001         0.48874
          250       84       00:01:32        0.001         0.43741
          300      100       00:01:44        0.001         0.32055
Training stopped: Max epochs completed

テスト イメージを読み取って表示します。

testImage = imread("triangleTest.jpg");
imshow(testImage)

Figure contains an axes object. The axes object contains an object of type image.

テスト イメージをセグメント化し、結果を表示します。

C = semanticseg(testImage,net,Classes=classNames);
B = labeloverlay(testImage,C);
imshow(B)

Figure contains an axes object. The axes object contains an object of type image.

参考

(Deep Learning Toolbox)

関連するトピック