セマンティック セグメンテーション ネットワークの学習
学習データを読み込みます。
dataSetDir = fullfile(toolboxdir("vision"),"visiondata","triangleImages"); imageDir = fullfile(dataSetDir,"trainingImages"); labelDir = fullfile(dataSetDir,"trainingLabels");
イメージのイメージ データストアを作成します。
imds = imageDatastore(imageDir);
グラウンド トゥルース ピクセル ラベル用の pixelLabelDatastore
を作成します。
classNames = ["triangle" "background"]; labelIDs = [255 0]; pxds = pixelLabelDatastore(labelDir,classNames,labelIDs);
学習イメージとグラウンド トゥルース ピクセル ラベルを可視化します。
I = read(imds); C = read(pxds); I = imresize(I,5,"nearest"); L = imresize(uint8(C{1}),5,"nearest"); imshowpair(I,L,"montage")
イメージ データストアとピクセル ラベル データストアを学習用に統合します。
trainingData = pixelLabelImageDatastore(imds,pxds);
セマンティック セグメンテーション ネットワークを作成します。このネットワークでは、ダウンサンプリングおよびアップサンプリングの設計に基づいてシンプルなセマンティック セグメンテーション ネットワークを使用します。
numFilters = 64; filterSize = 3; numClasses = 2; layers = [ imageInputLayer([32 32 1]) convolution2dLayer(filterSize,numFilters,Padding=1) reluLayer() maxPooling2dLayer(2,Stride=2) convolution2dLayer(filterSize,numFilters,Padding=1) reluLayer() transposedConv2dLayer(4,numFilters,Stride=2,Cropping=1); convolution2dLayer(1,numClasses); softmaxLayer() ];
学習オプションを設定します。
opts = trainingOptions("sgdm", ... InitialLearnRate=1e-3, ... MaxEpochs=100, ... MiniBatchSize=64);
ピクセル分類に適した損失関数を定義します。
function loss = modelLoss(Y,T) mask = ~isnan(T); T(isnan(T)) = 0; loss = crossentropy(Y,T,Mask=mask,NormalizationFactor="mask-included"); end
ネットワークに学習をさせます。
net = trainnet(trainingData,layers,@modelLoss,opts);
Iteration Epoch TimeElapsed LearnRate TrainingLoss _________ _____ ___________ _________ ____________ 1 1 00:00:06 0.001 41.892 50 17 00:00:21 0.001 0.93931 100 34 00:00:35 0.001 0.7432 150 50 00:00:52 0.001 0.4558 200 67 00:01:10 0.001 0.48874 250 84 00:01:32 0.001 0.43741 300 100 00:01:44 0.001 0.32055 Training stopped: Max epochs completed
テスト イメージを読み取って表示します。
testImage = imread("triangleTest.jpg");
imshow(testImage)
テスト イメージをセグメント化し、結果を表示します。
C = semanticseg(testImage,net,Classes=classNames); B = labeloverlay(testImage,C); imshow(B)
参考
trainnet
(Deep Learning Toolbox)