Main Content

このページは前リリースの情報です。該当の英語のページはこのリリースで削除されています。

Tversky 損失を使用するカスタム ピクセル分類層の定義

この例では、Tversky 損失を使用するカスタム ピクセル分類層を定義および作成する方法を説明します。

この層を使用して、セマンティック セグメンテーション ネットワークに学習させることができます。カスタム深層学習層の作成の詳細については、カスタム深層学習層の定義を参照してください。

Tversky 損失

Tversky 損失は、セグメント化された 2 つのイメージの間のオーバーラップを測定する Tversky 指数に基づきます [1]。1 つのイメージ Y と対応するグラウンド トゥルース T の間の Tversky 指数 TIc は、次のようになります。

TIc=m=1MYcmTcmm=1MYcmTcm+αm=1MYcmTcm+βm=1MYcmTcm

  • c はクラスに対応し、c はクラス c 以外に対応します。

  • M は、Y の最初の 2 つの次元に沿った要素の数です。

  • αβ は、各クラス偽陽性と偽陰性の損失に対する寄与を制御する重み係数です。

クラス数 C に対する損失 L は、次のようになります。

L=c=1C1-TIc

分類層テンプレート

分類層のテンプレートを MATLAB® の新しいファイルにコピーします。このテンプレートは、分類層の構造の概要を示しており、層の動作を定義する関数が含まれます。この例の残りの部分では、tverskyPixelClassificationLayer を完成する方法を示します。

classdef tverskyPixelClassificationLayer < nnet.layer.ClassificationLayer

   properties
      % Optional properties
   end

   methods

        function loss = forwardLoss(layer, Y, T)
            % Layer forward loss function goes here
        end
        
    end
end

層のプロパティの宣言

既定では、カスタム出力層には次のプロパティがあります。

  • Name – 層の名前。文字ベクトルまたは string スカラーとして指定します。層グラフにこの層を含めるには、空ではない一意の層の名前を指定しなければなりません。この層が含まれる系列ネットワークに学習させて Name'' に設定すると、学習時に層に名前が自動的に割り当てられます。

  • Description – 層についての 1 行の説明。文字ベクトルまたは string スカラーとして指定します。この説明は、層が Layer 配列に表示されるときに表示されます。層の説明を指定しない場合、層のクラス名が表示されます。

  • Type – 層のタイプ。文字ベクトルまたは string スカラーとして指定します。Type の値は、層が Layer 配列に表示されるときに表示されます。層のタイプを指定しない場合、'Classification layer' または 'Regression layer' が表示されます。

カスタム分類層には次のプロパティもあります。

  • Classes – 出力層のクラス。categorical ベクトル、string 配列、文字ベクトルの cell 配列、または 'auto' として指定します。Classes'auto' の場合、学習時にクラスが自動的に設定されます。string 配列または文字ベクトルの cell 配列 str を指定すると、出力層のクラスが categorical(str,str) に設定されます。既定値は 'auto' です。

層にその他のプロパティがない場合は、properties セクションを省略できます。

Tversky 損失には、ゼロ除算を防止するために小さい定数値が必要です。プロパティ Epsilon を指定して、この値を保持します。2 つの可変プロパティ AlphaBeta も必要です。それぞれ、偽陽性と偽陰性の重みを制御します。

classdef tverskyPixelClassificationLayer < nnet.layer.ClassificationLayer

    properties(Constant)
       % Small constant to prevent division by zero. 
       Epsilon = 1e-8;
    end

    properties
       % Default weighting coefficients for false positives and false negatives 
       Alpha = 0.5;
       Beta = 0.5;  
    end

    ...
end

コンストラクター関数の作成

層を構築する関数を作成し、層のプロパティを初期化します。層を作成するために必要な変数をコンストラクター関数への入力として指定します。

作成時に Name プロパティを割り当てるオプションの入力引数 name を指定します。

function layer = tverskyPixelClassificationLayer(name, alpha, beta)
    % layer =  tverskyPixelClassificationLayer(name) creates a Tversky
    % pixel classification layer with the specified name.
           
    % Set layer name          
    layer.Name = name;

    % Set layer properties
    layer.Alpha = alpha;
    layer.Beta = beta;

    % Set layer description
    layer.Description = 'Tversky loss';
end

順方向損失関数の作成

ネットワークで行った予測と学習ターゲットの間の重み付き交差エントロピー損失を返す、forwardLoss という名前の関数を作成します。forwardLoss の構文は loss = forwardLoss(layer,Y,T) です。ここで、Y は前の層の出力であり、T は学習ターゲットを表します。

セマンティック セグメンテーションの問題の場合、T の次元は Y の次元と一致する必要があります。Y は、H x W x K x N の 4 次元配列です。K はクラスの数、N はミニバッチ サイズです。

Y のサイズは前の層の出力によって異なります。YT と同じサイズになるように、正しいサイズを出力する層を出力層の前に含めなければなりません。たとえば、Y を必ず K 個のクラスの予測スコアを持つ 4 次元配列にするために、サイズが K の全結合層か K 個のフィルターを持つ畳み込み層を含め、その後にソフトマックス層、出力層の順に配置することができます。

function loss = forwardLoss(layer, Y, T)
    % loss = forwardLoss(layer, Y, T) returns the Tversky loss between
    % the predictions Y and the training targets T.

    Pcnot = 1-Y;
    Gcnot = 1-T;
    TP = sum(sum(Y.*T,1),2);
    FP = sum(sum(Y.*Gcnot,1),2);
    FN = sum(sum(Pcnot.*T,1),2);

    numer = TP + layer.Epsilon;
    denom = TP + layer.Alpha*FP + layer.Beta*FN + layer.Epsilon;
    
    % Compute Tversky index
    lossTIc = 1 - numer./denom;
    lossTI = sum(lossTIc,3);
    
    % Return average Tversky index loss
    N = size(Y,4);
    loss = sum(lossTI)/N;

end

逆方向損失関数

関数 forwardLoss は自動微分を完全にサポートしているため、逆方向損失の関数を作成する必要はありません。

自動微分をサポートする関数の一覧については、dlarray をサポートする関数の一覧を参照してください。

完成した層

完成した層は、tverskyPixelClassificationLayer.m で提供されています。これは、この例にサポート ファイルとして添付されています。

classdef tverskyPixelClassificationLayer < nnet.layer.ClassificationLayer
    % This layer implements the Tversky loss function for training
    % semantic segmentation networks.
    
    % References
    % Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour.
    % "Tversky loss function for image segmentation using 3D fully
    % convolutional deep networks." International Workshop on Machine
    % Learning in Medical Imaging. Springer, Cham, 2017.
    % ----------
    
    
    properties(Constant)
        % Small constant to prevent division by zero.
        Epsilon = 1e-8;
    end
    
    properties
        % Default weighting coefficients for False Positives and False
        % Negatives
        Alpha = 0.5;
        Beta = 0.5;
    end

    
    methods
        
        function layer = tverskyPixelClassificationLayer(name, alpha, beta)
            % layer =  tverskyPixelClassificationLayer(name, alpha, beta) creates a Tversky
            % pixel classification layer with the specified name and properties alpha and beta.
            
            % Set layer name.          
            layer.Name = name;
            
            layer.Alpha = alpha;
            layer.Beta = beta;
            
            % Set layer description.
            layer.Description = 'Tversky loss';
        end
        
        
        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the Tversky loss between
            % the predictions Y and the training targets T.   

            Pcnot = 1-Y;
            Gcnot = 1-T;
            TP = sum(sum(Y.*T,1),2);
            FP = sum(sum(Y.*Gcnot,1),2);
            FN = sum(sum(Pcnot.*T,1),2); 
            
            numer = TP + layer.Epsilon;
            denom = TP + layer.Alpha*FP + layer.Beta*FN + layer.Epsilon;
            
            % Compute tversky index
            lossTIc = 1 - numer./denom;
            lossTI = sum(lossTIc,3);
            
            % Return average tversky index loss.
            N = size(Y,4);
            loss = sum(lossTI)/N;
            
        end     
    end
end

GPU 互換性

tverskyPixelClassificationLayerforwardLoss で使用する MATLAB 関数はすべて、gpuArray 入力をサポートしているため、層は GPU 互換です。

出力層の有効性のチェック

層のインスタンスを作成します。

layer = tverskyPixelClassificationLayer('tversky',0.7,0.3);

checkLayer を使用して、層の有効性をチェックします。層への典型的な入力における 1 つの観測値のサイズになるように有効な入力サイズを指定します。層には H x W x K x N の配列を入力する必要があります。K はクラスの数、N はミニバッチ内の観測値の数です。

numClasses = 2;
validInputSize = [4 4 numClasses];
checkLayer(layer,validInputSize, 'ObservationDimension',4)
Skipping GPU tests. No compatible GPU device found.
 
Skipping code generation compatibility tests. To check validity of the layer for code generation, specify the CheckCodegenCompatibility and ObservationDimension options.
 
Running nnet.checklayer.TestOutputLayerWithoutBackward
........
Done nnet.checklayer.TestOutputLayerWithoutBackward
__________

Test Summary:
	 8 Passed, 0 Failed, 0 Incomplete, 2 Skipped.
	 Time elapsed: 1.3954 seconds.

テストの概要では、パスしたテスト、失敗したテスト、未完了のテスト、およびスキップされたテストの数が報告されます。

セマンティック セグメンテーション ネットワークでのカスタム層の使用

tverskyPixelClassificationLayer を使用するセマンティック セグメンテーション ネットワークを作成します。

layers = [
    imageInputLayer([32 32 1])
    convolution2dLayer(3,64,'Padding',1)
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,64,'Padding',1)
    reluLayer
    transposedConv2dLayer(4,64,'Stride',2,'Cropping',1)
    convolution2dLayer(1,2)
    softmaxLayer
    tverskyPixelClassificationLayer('tversky',0.3,0.7)];

imageDatastorepixelLabelDatastore を使用してセマンティック セグメンテーション用の学習データを読み込みます。

dataSetDir = fullfile(toolboxdir('vision'),'visiondata','triangleImages');
imageDir = fullfile(dataSetDir,'trainingImages');
labelDir = fullfile(dataSetDir,'trainingLabels');

imds = imageDatastore(imageDir);

classNames = ["triangle" "background"];
labelIDs = [255 0];
pxds = pixelLabelDatastore(labelDir, classNames, labelIDs);

データストア combine を使用してイメージとピクセル ラベル データを関連付けます。

ds = combine(imds,pxds);

学習オプションを設定し、ネットワークに学習させます。

options = trainingOptions('adam', ...
    'InitialLearnRate',1e-3, ...
    'MaxEpochs',100, ...
    'LearnRateDropFactor',5e-1, ...
    'LearnRateDropPeriod',20, ...
    'LearnRateSchedule','piecewise', ...
    'MiniBatchSize',50);

net = trainNetwork(ds,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:02 |       50.32% |       1.2933 |          0.0010 |
|      13 |          50 |       00:00:24 |       98.83% |       0.0988 |          0.0010 |
|      25 |         100 |       00:00:43 |       99.33% |       0.0547 |          0.0005 |
|      38 |         150 |       00:01:04 |       99.38% |       0.0472 |          0.0005 |
|      50 |         200 |       00:01:24 |       99.48% |       0.0401 |          0.0003 |
|      63 |         250 |       00:01:44 |       99.47% |       0.0384 |          0.0001 |
|      75 |         300 |       00:02:03 |       99.54% |       0.0349 |          0.0001 |
|      88 |         350 |       00:02:25 |       99.51% |       0.0352 |      6.2500e-05 |
|     100 |         400 |       00:02:44 |       99.56% |       0.0331 |      6.2500e-05 |
|========================================================================================|
Training finished: Max epochs completed.

テスト イメージをセグメント化し、セグメンテーション結果を表示して、学習済みネットワークを評価します。

I = imread('triangleTest.jpg');
[C,scores] = semanticseg(I,net);

B = labeloverlay(I,C);
montage({I,B})

Figure contains an axes object. The axes object contains an object of type image.

参考文献

[1] Salehi, Seyed Sadegh Mohseni, Deniz Erdogmus, and Ali Gholipour. "Tversky loss function for image segmentation using 3D fully convolutional deep networks." International Workshop on Machine Learning in Medical Imaging. Springer, Cham, 2017.

参考

| | | (Computer Vision Toolbox) | (Computer Vision Toolbox)

関連するトピック