Main Content

膨張畳み込みを使用したセマンティック セグメンテーション

膨張畳み込みを使用してセマンティック セグメンテーション ネットワークに学習させます。

セマンティック セグメンテーション ネットワークはイメージ内のすべてのピクセルを分類して、クラスごとにセグメント化されたイメージを作成します。セマンティック セグメンテーションの応用例としては、自動運転のための道路セグメンテーションや医療診断のための癌細胞セグメンテーションなどがあります。詳細については、深層学習を使用したセマンティック セグメンテーション入門を参照してください。

Deeplab v3+ [1] などのセマンティック セグメンテーション ネットワークでは、膨張畳み込み (Atrous 畳み込みとも呼ばれる) が広範に使用されます。これはパラメーターの数や計算量を増やさずに、層の受容野 (層で確認できる入力の領域) を増やすことができるからです。

学習データの読み込み

例では、説明のために 32 x 32 の三角形のイメージを含む単純なデータセットを使用します。データセットには、付随するピクセル ラベル グラウンド トゥルース データが含まれます。imageDatastorepixelLabelDatastore を使用して学習データを読み込みます。

dataFolder = fullfile(toolboxdir("vision"),"visiondata","triangleImages");
imageFolderTrain = fullfile(dataFolder,"trainingImages");
labelFolderTrain = fullfile(dataFolder,"trainingLabels");

イメージの imageDatastore を作成します。

imdsTrain = imageDatastore(imageFolderTrain);

グラウンド トゥルース ピクセル ラベル用の pixelLabelDatastore を作成します。

classNames = ["triangle" "background"];
labels = [255 0];
pxdsTrain = pixelLabelDatastore(labelFolderTrain,classNames,labels)
pxdsTrain = 
  PixelLabelDatastore with properties:

                       Files: {200×1 cell}
                  ClassNames: {2×1 cell}
                    ReadSize: 1
                     ReadFcn: @readDatastoreImage
    AlternateFileSystemRoots: {}

セマンティック セグメンテーション ネットワークの作成

この例では、膨張畳み込みに基づく単純なセマンティック セグメンテーション ネットワークを使用します。

学習データのデータ ソースを作成して、各ラベルのピクセル数を取得します。

ds = combine(imdsTrain,pxdsTrain);
tbl = countEachLabel(pxdsTrain)
tbl=2×3 table
         Name         PixelCount    ImagePixelCount
    ______________    __________    _______________

    {'triangle'  }         10326       2.048e+05   
    {'background'}    1.9447e+05       2.048e+05   

ピクセル ラベルの大部分は背景用です。このクラスの不均衡によって、上位クラスを優先して学習プロセスにバイアスがかけられます。これを修正するには、クラス加重を使用してクラスのバランスを調整します。クラス加重の計算にはいくつかの方法を使用できます。一般的な方法の 1 つは、クラス加重がクラスの頻度の逆となる逆頻度重み付けです。この方法では、過少に表現されたクラスに与えられる重みが増えます。逆頻度重み付けを使用してクラス加重を計算します。

numberPixels = sum(tbl.PixelCount);
frequency = tbl.PixelCount / numberPixels;
classWeights = dlarray(1 ./ frequency,"C");

入力サイズが入力イメージのサイズに対応するイメージ入力層を使用して、ピクセル分類用ネットワークを作成します。次に、畳み込み層、バッチ正規化層、および ReLU 層の 3 つのブロックを指定します。各畳み込み層で、膨張係数を増加させながら 3 行 3 列のフィルターを 32 個指定します。また、名前と値の引数 Padding"same" として設定して、出力と同じサイズになるように入力をパディングします。ピクセルを分類するには K 個の 1 行 1 列の畳み込みを使用する畳み込み層を含め (K はクラスの数)、その後にソフトマックス層を含めます。ピクセルの分類は、組み込みトレーナー trainnet 内で、カスタム モデル損失を使用して行われます。

inputSize = [32 32 1];
filterSize = 3;
numFilters = 32;
numClasses = numel(classNames);

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(filterSize,numFilters,DilationFactor=1,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,DilationFactor=2,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(filterSize,numFilters,DilationFactor=4,Padding="same")
    batchNormalizationLayer
    reluLayer
    
    convolution2dLayer(1,numClasses)
    softmaxLayer];

モデル損失関数

セマンティック セグメンテーション ネットワークには、さまざまな損失関数を使用して学習させることができます。組み込みトレーナーtrainnet (Deep Learning Toolbox)は、カスタム損失関数だけでなく、"crossentropy" や "mse" などのいくつかの標準損失関数もサポートしています。カスタム損失関数では、ネットワークの予測を実際のグラウンド トゥルースまたはターゲット値と比較することにより、学習データの各バッチの損失を手動で計算します。カスタム損失関数は、関数構文 loss = f(Y1,...,Yn,T1,...,Tm) をもつ関数ハンドルを使用します。ここで、Y1,...,Yn は、n 個のネットワーク予測に対応する dlarray オブジェクトであり、T1,...,Tm は、m 個のターゲットに対応する dlarray オブジェクトです。

この例では、データに見られるクラスの不均衡を考慮する 2 つの異なる損失関数のいずれかを選択できます。その損失関数は次のとおりです。

  1. 関数crossentropy (Deep Learning Toolbox)を使用する加重クロスエントロピー損失。加重クロスエントロピー損失では、学習中にそのクラスの誤差をスケーリングすることにより、少数しか存在しないクラスにも比重を置きます。

  2. Tversky 損失 [2] を計算するカスタム損失関数 tverskyLoss。Tversky 損失は、クラスの不均衡に特化した損失です。

Tversky 損失は、セグメント化された 2 つのイメージの間のオーバーラップを測定する Tversky 指数に基づいています。1 つのイメージ Y と対応するグラウンド トゥルース T の間の Tversky 指数 TIc は、次のようになります。

TIc=m=1MYcmTcmm=1MYcmTcm+αm=1MYcmTcm+βm=1MYcmTcm

  • c はクラスに対応し、c はクラス c 以外に対応します。

  • M は、Y の最初の 2 つの次元に沿った要素の数です。

  • αβ は、各クラス偽陽性と偽陰性の損失に対する寄与を制御する重み係数です。

クラス数 C に対する損失 L は、次のようになります。

L=c=1C1-TIc

学習の際に使用する損失関数を選択します。

lossFunction = "tverskyLoss"
lossFunction = 
"tverskyLoss"
if strcmp(lossFunction,"tverskyLoss")
    % Declare Tversky loss weighting coefficients for false positives and
    % false negatives. These coefficients are set and passed to the
    % training loss function using trainnet.
    alpha = 0.7;
    beta = 0.3;
    lossFcn = @(Y,T) tverskyLoss(Y,T,alpha,beta);
else
    % Use weighted cross-entropy loss during training.
    lossFcn = @(Y,T) crossentropy(Y,T,classWeights,NormalizationFactor="all-elements");
end

ネットワークの学習

学習オプションを指定します。

options = trainingOptions("sgdm",...
    MaxEpochs=100,...
    MiniBatchSize= 64,... 
    InitialLearnRate=1e-2,...
    Verbose=false);

trainnet (Deep Learning Toolbox)を使用してネットワークに学習させます。損失を損失関数 lossFcn として指定します。

net = trainnet(ds,layers,lossFcn,options);

ネットワークのテスト

テスト データを読み込みます。イメージの imageDatastore を作成します。グラウンド トゥルース ピクセル ラベル用の pixelLabelDatastore を作成します。

imageFolderTest = fullfile(dataFolder,"testImages");
imdsTest = imageDatastore(imageFolderTest);
labelFolderTest = fullfile(dataFolder,"testLabels");
pxdsTest = pixelLabelDatastore(labelFolderTest,classNames,labels);

テスト データと学習済みネットワークを使用して、予測を実行します。

pxdsPred = semanticseg(imdsTest,net,...
    Classes=classNames,...
    MiniBatchSize=32,...
    WriteLocation=tempdir);
Running semantic segmentation network
-------------------------------------
* Processed 100 images.

evaluateSemanticSegmentationを使用して予測精度を評価します。

metrics = evaluateSemanticSegmentation(pxdsPred,pxdsTest);
Evaluating semantic segmentation results
----------------------------------------
* Selected metrics: global accuracy, class accuracy, IoU, weighted IoU, BF score.
* Processed 100 images.
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    MeanIoU    WeightedIoU    MeanBFScore
    ______________    ____________    _______    ___________    ___________

       0.99674          0.98562       0.96447      0.99362        0.92831  

新しいイメージのセグメンテーション

テスト イメージ triangleTest.jpg を読み取り、semanticseg を使用してテスト イメージをセグメント化します。labeloverlay を使用して結果を表示します。

imgTest = imread("triangleTest.jpg");
[C,scores] = semanticseg(imgTest,net,classes=classNames);

B = labeloverlay(imgTest,C);
montage({imgTest,B})

Figure contains an axes object. The axes object contains an object of type image.

サポート関数

function loss = tverskyLoss(Y,T,alpha,beta)
    % loss = tverskyLoss(Y,T,alpha,beta) returns the Tversky loss
    % between the predictions Y and the training targets T.   
    
    Pcnot = 1-Y;
    Gcnot = 1-T;
    TP = sum(sum(Y.*T,1),2);
    FP = sum(sum(Y.*Gcnot,1),2);
    FN = sum(sum(Pcnot.*T,1),2); 
    
    epsilon = 1e-8;
    numer = TP + epsilon;
    denom = TP + alpha*FP + beta*FN + epsilon;
    
    % Compute tversky index.
    lossTIc = 1 - numer./denom;
    lossTI = sum(lossTIc,3);
    
    % Return average Tversky index loss.
    N = size(Y,4);
    loss = sum(lossTI)/N;
end

参考文献

[1] Chen, Liang-Chieh et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation." ECCV (2018).

[2] Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour. "Tversky loss function for image segmentation using 3D fully convolutional deep networks." International Workshop on Machine Learning in Medical Imaging. Springer, Cham, 2017.