Main Content

膨張畳み込みを使用したセマンティック セグメンテーション

膨張畳み込みを使用してセマンティック セグメンテーション ネットワークに学習させます。

セマンティック セグメンテーション ネットワークはイメージ内のすべてのピクセルを分類して、クラスごとにセグメント化されたイメージを作成します。セマンティック セグメンテーションの応用例としては、自動運転のための道路セグメンテーションや医療診断のための癌細胞セグメンテーションなどがあります。詳細については、深層学習を使用したセマンティック セグメンテーション入門 (Computer Vision Toolbox)を参照してください。

DeepLab [1] などのセマンティック セグメンテーション ネットワークでは、膨張畳み込み (Atrous 畳み込みとも呼ばれる) が広範に使用されます。これはパラメーターの数や計算量を増やさずに、層の受容野 (層で確認できる入力の領域) を増やすことができるからです。

学習データの読み込み

例では、説明のために 32 x 32 の三角形のイメージを含む単純なデータセットを使用します。データセットには、付随するピクセル ラベル グラウンド トゥルース データが含まれます。imageDatastorepixelLabelDatastore を使用して学習データを読み込みます。

dataFolder = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
imageFolderTrain = fullfile(dataFolder,'trainingImages');
labelFolderTrain = fullfile(dataFolder,'trainingLabels');

イメージの imageDatastore を作成します。

imdsTrain = imageDatastore(imageFolderTrain);

グラウンド トゥルース ピクセル ラベル用の pixelLabelDatastore を作成します。

classNames = ["triangle" "background"];
labels = [255 0];
pxdsTrain = pixelLabelDatastore(labelFolderTrain,classNames,labels)
pxdsTrain = 
  PixelLabelDatastore with properties:

                       Files: {200x1 cell}
                  ClassNames: {2x1 cell}
                    ReadSize: 1
                     ReadFcn: @readDatastoreImage
    AlternateFileSystemRoots: {}

セマンティック セグメンテーション ネットワークの作成

この例では、膨張畳み込みに基づく単純なセマンティック セグメンテーション ネットワークを使用します。

学習データのデータ ソースを作成して、各ラベルのピクセル数を取得します。

ds = combine(imdsTrain,pxdsTrain);
tbl = countEachLabel(pxdsTrain)
tbl=2×3 table
         Name         PixelCount    ImagePixelCount
    ______________    __________    _______________

    {'triangle'  }         10326       2.048e+05   
    {'background'}    1.9447e+05       2.048e+05   

ピクセル ラベルの大部分は背景用です。このクラスの不均衡によって、上位クラスを優先して学習プロセスにバイアスがかけられます。これを修正するには、クラスの重み付けを使用してクラスのバランスを調整します。クラスの重みの計算にはいくつかの方法を使用できます。一般的な方法の 1 つは、クラスの重みがクラスの頻度の逆となる逆頻度重み付けです。この方法では、過少に表現されたクラスに与えられる重みが増えます。逆頻度重み付けを使用してクラスの重みを計算します。

numberPixels = sum(tbl.PixelCount);
frequency = tbl.PixelCount / numberPixels;
classWeights = 1 ./ frequency;

入力サイズが入力イメージのサイズに対応するイメージ入力層を使用して、ピクセル分類用ネットワークを作成します。次に、畳み込み層、バッチ正規化層、および ReLU 層の 3 つのブロックを指定します。各畳み込み層で、膨張係数を増加させながら 3 行 3 列のフィルターを 32 個指定します。また、'Padding' オプションを 'same' に設定して、出力と同じサイズになるように入力をパディングします。ピクセルを分類するには K と 1 行 1 列の畳み込みを使用する畳み込み層を含め (K はクラスの数)、その後にソフトマックス層と逆のクラスの重みを持つ pixelClassificationLayer を含めます。

inputSize = [32 32 1];
filterSize = 3;
numFilters = 32;
numClasses = numel(classNames);

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(filterSize,numFilters,'DilationFactor',1,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,'DilationFactor',2,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,'DilationFactor',4,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(1,numClasses)
    softmaxLayer
    pixelClassificationLayer('Classes',classNames,'ClassWeights',classWeights)];

ネットワークの学習

学習オプションを指定します。

options = trainingOptions('sgdm', ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 64, ... 
    'InitialLearnRate', 1e-3);

trainNetwork を使用してネットワークに学習させます。

net = trainNetwork(ds,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:01 |       91.62% |       1.6825 |          0.0010 |
|      17 |          50 |       00:00:25 |       88.56% |       0.2393 |          0.0010 |
|      34 |         100 |       00:00:45 |       92.08% |       0.1672 |          0.0010 |
|      50 |         150 |       00:01:06 |       93.17% |       0.1472 |          0.0010 |
|      67 |         200 |       00:01:27 |       94.15% |       0.1313 |          0.0010 |
|      84 |         250 |       00:01:47 |       94.47% |       0.1166 |          0.0010 |
|     100 |         300 |       00:02:09 |       95.04% |       0.1100 |          0.0010 |
|========================================================================================|
Training finished: Max epochs completed.

ネットワークのテスト

テスト データを読み込みます。イメージの imageDatastore を作成します。グラウンド トゥルース ピクセル ラベル用の pixelLabelDatastore を作成します。

imageFolderTest = fullfile(dataFolder,'testImages');
imdsTest = imageDatastore(imageFolderTest);
labelFolderTest = fullfile(dataFolder,'testLabels');
pxdsTest = pixelLabelDatastore(labelFolderTest,classNames,labels);

テスト データと学習済みネットワークを使用して、予測を実行します。

pxdsPred = semanticseg(imdsTest,net,'MiniBatchSize',32,'WriteLocation',tempdir);
Running semantic segmentation network
-------------------------------------
* Processed 100 images.

evaluateSemanticSegmentation を使用して予測精度を評価します。

metrics = evaluateSemanticSegmentation(pxdsPred,pxdsTest);
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU, BF score.
* Processed 100 images.
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.95237          0.97352       0.72081      0.92889        0.46416  

セマンティック セグメンテーション ネットワークの評価の詳細については、evaluateSemanticSegmentation (Computer Vision Toolbox)を参照してください。

新しいイメージのセグメンテーション

テスト イメージ triangleTest.jpg を読み取って表示します。

imgTest = imread('triangleTest.jpg');
figure
imshow(imgTest)

Figure contains an axes object. The axes object contains an object of type image.

semanticseg を使用してテスト イメージをセグメント化し、labeloverlay を使用して結果を表示します。

C = semanticseg(imgTest,net);
B = labeloverlay(imgTest,C);
figure
imshow(B)

Figure contains an axes object. The axes object contains an object of type image.

参考

(Computer Vision Toolbox) | (Computer Vision Toolbox) | (Computer Vision Toolbox) | (Image Processing Toolbox) | (Computer Vision Toolbox) | (Computer Vision Toolbox) | | | (Computer Vision Toolbox) |

関連するトピック