最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

カスタム分類出力層の定義

ヒント

k 個の互いに排他的なクラスの交差エントロピー損失を含む分類出力層を構築するには、classificationLayer を使用します。分類問題に別の損失関数を使用する場合は、この例を指針として使用してカスタム分類出力層を定義できます。

この例では、残差平方和 (SSE) 損失を含むカスタム分類出力層を定義し、畳み込みニューラル ネットワークで使用する方法を説明します。

カスタム分類出力層を定義するために、この例で提供するテンプレートを使用できます。この例では、次のステップで説明を進めます。

  1. 層の命名 – MATLAB® で使用できるように層に名前を付けます。

  2. 層のプロパティの宣言 – 層のプロパティを指定します。

  3. コンストラクター関数の作成 (オプション) – 層の構築とそのプロパティ初期化の方法を指定します。コンストラクター関数を指定しない場合、プロパティは作成時に '' で初期化されます。

  4. 順方向損失関数の作成 – 予測と学習ターゲットの間の損失を指定します。

  5. 逆方向損失関数の作成 (オプション) – 予測についての損失の微分を指定します。逆方向損失関数を指定しない場合、順方向損失関数は dlarray オブジェクトをサポートしなければなりません。

分類 SSE 層は、分類問題の残差平方和損失を計算します。SSE は、2 つの連続確率変数間の誤差測定です。予測 Y と学習ターゲット T について、Y と T の間の SSE 損失は次で与えられます。

L=1Nn=1Ni=1K(YniTni)2,

ここで、N は観測値の数、K はクラスの数です。

分類出力層テンプレート

分類出力層のテンプレートを MATLAB の新しいファイルにコピーします。このテンプレートは、分類出力層の構造の概要を示しており、層の動作を定義する関数が含まれます。

classdef myClassificationLayer < nnet.layer.ClassificationLayer
        
    properties
        % (Optional) Layer properties.

        % Layer properties go here.
    end
 
    methods
        function layer = myClassificationLayer()           
            % (Optional) Create a myClassificationLayer.

            % Layer constructor function goes here.
        end

        function loss = forwardLoss(layer, Y, T)
            % Return the loss between the predictions Y and the training 
            % targets T.
            %
            % Inputs:
            %         layer - Output layer
            %         Y     – Predictions made by network
            %         T     – Training targets
            %
            % Output:
            %         loss  - Loss between Y and T

            % Layer forward loss function goes here.
        end
        
        function dLdY = backwardLoss(layer, Y, T)
            % (Optional) Backward propagate the derivative of the loss 
            % function.
            %
            % Inputs:
            %         layer - Output layer
            %         Y     – Predictions made by network
            %         T     – Training targets
            %
            % Output:
            %         dLdY  - Derivative of the loss with respect to the 
            %                 predictions Y

            % Layer backward loss function goes here.
        end
    end
end

層の命名

まず、層に名前を付けます。クラス ファイルの最初の行で、既存の名前 myClassificationLayersseClassificationLayer に置き換えます。

classdef sseClassificationLayer < nnet.layer.ClassificationLayer
    ...
end

次に、コンストラクター関数 myClassificationLayer (methods セクションの最初の関数) の名前を層と同じ名前に変更します。

    methods
        function layer = sseClassificationLayer()           
            ...
        end

        ...
     end

層の保存

層のクラス ファイルを sseClassificationLayer.m という名前の新しいファイルに保存します。このファイル名は層の名前に一致しなければなりません。この層を使用するには、このファイルを現在のフォルダーまたは MATLAB パス上のフォルダーに保存しなければなりません。

層のプロパティの宣言

層のプロパティを properties セクションで宣言します。

既定では、カスタム出力層には次のプロパティがあります。

  • Name層の名前。文字ベクトルまたは string スカラーとして指定します。層グラフに層を含めるには、空ではない一意の層の名前を指定しなければなりません。この層が含まれる系列ネットワークに学習させて Name'' に設定すると、学習時に層に名前が自動的に割り当てられます。

  • Description – 層についての 1 行の説明。文字ベクトルまたは string スカラーとして指定します。この説明は、層が Layer 配列に表示されるときに表示されます。層の説明を指定しない場合、"Classification Output" または "Regression Output" が表示されます。

  • Type – 層のタイプ。文字ベクトルまたは string スカラーとして指定します。Type の値は、層が Layer 配列に表示されるときに表示されます。層のタイプを指定しない場合、層のクラス名が表示されます。

カスタム分類層には次のプロパティもあります。

  • Classes出力層のクラス。categorical ベクトル、string 配列、文字ベクトルの cell 配列、または 'auto' として指定します。Classes'auto' の場合、学習時にクラスが自動的に設定されます。string 配列または文字ベクトルの cell 配列 str を指定すると、出力層のクラスが categorical(str,str) に設定されます。既定値は 'auto' です。

カスタム回帰層には次のプロパティもあります。

  • ResponseNames応答の名前。文字ベクトルの cell 配列または string 配列として指定します。学習時に、学習データに従って応答名が自動的に設定されます。既定値は {} です。

層にその他のプロパティがない場合は、properties セクションを省略できます。

この例の層には追加のプロパティが必要ないため、properties セクションは削除できます。

コンストラクター関数の作成

層を構築する関数を作成し、層のプロパティを初期化します。層を作成するために必要な変数をコンストラクター関数への入力として指定します。

作成時に Name プロパティに割り当てる入力引数 name を指定します。関数の構文を説明するコメントを関数の上部に追加します。

        function layer = sseClassificationLayer(name)
            % layer = sseClassificationLayer(name) creates a sum of squares
            % error classification layer and specifies the layer name.

            ...
        end

層のプロパティの初期化

コメント % Layer constructor function goes here を、層のプロパティを初期化するコードに置き換えます。

層の Description プロパティを設定して、層に 1 行の説明を指定します。Name プロパティを入力引数 name に設定します。

        function layer = sseClassificationLayer(name)
            % layer = sseClassificationLayer(name) creates a sum of squares
            % error classification layer and specifies the layer name.
    
            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = 'Sum of squares error';
        end

順方向損失関数の作成

ネットワークで行った予測と学習ターゲットの間の SSE 損失を返す、forwardLoss という名前の関数を作成します。forwardLoss の構文は loss = forwardLoss(layer, Y, T) です。ここで、Y は前の層の出力であり、T は学習ターゲットを表します。

分類問題の場合、T の次元は問題のタイプによって異なります。

分類タスク入力サイズ観察値の次元
2 次元イメージ分類1 x 1 x K x N。ここで、K はクラスの数、N は観測値の数です。4
3 次元イメージ分類1 x 1 x 1 x K x N。ここで、K はクラスの数、N は観測値の数です。5
sequence-to-label 分類K 行 N 列。K はクラスの数、N は観測値の数です。2
sequence-to-sequence 分類K x N x S。ここで、K はクラス数、N は観測値の数、S はシーケンス長です。2

Y のサイズは前の層の出力によって異なります。YT と同じサイズになるように、正しいサイズを出力する層を出力層の前に含めなければなりません。たとえば、Y を必ず K 個のクラスの予測スコアを持つ 4 次元配列にするために、サイズが K の全結合層を含め、その後にソフトマックス層、出力層の順に配置することができます。

分類 SSE 層は、分類問題の残差平方和損失を計算します。SSE は、2 つの連続確率変数間の誤差測定です。予測 Y と学習ターゲット T について、Y と T の間の SSE 損失は次で与えられます。

L=1Nn=1Ni=1K(YniTni)2,

ここで、N は観測値の数、K はクラスの数です。

入力 Y および T は、この方程式の Y および T にそれぞれ対応しています。出力 loss は L に対応します。関数の構文を説明するコメントを関数の上部に追加します。

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the SSE loss between
            % the predictions Y and the training targets T.

            % Calculate sum of squares.
            sumSquares = sum((Y-T).^2);
    
            % Take mean over mini-batch.
            N = size(Y,4);
            loss = sum(sumSquares)/N;
        end

関数 forwardLossdlarray オブジェクトをサポートする関数のみを使用するため、関数 backwardLoss の定義はオプションです。dlarray オブジェクトをサポートしている関数の一覧については、dlarray をサポートする関数の一覧を参照してください。

完成した層

完成した分類出力層のクラス ファイルを表示します。

classdef sseClassificationLayer < nnet.layer.ClassificationLayer
    % Example custom classification layer with sum of squares error loss.
    
    methods
        function layer = sseClassificationLayer(name)
            % layer = sseClassificationLayer(name) creates a sum of squares
            % error classification layer and specifies the layer name.
    
            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = 'Sum of squares error';
        end
        
        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the SSE loss between
            % the predictions Y and the training targets T.

            % Calculate sum of squares.
            sumSquares = sum((Y-T).^2);
    
            % Take mean over mini-batch.
            N = size(Y,4);
            loss = sum(sumSquares)/N;
        end
    end
end

GPU 互換性

層の順方向関数が dlarray オブジェクトを完全にサポートしている場合、層は GPU 互換です。そうでない場合、GPU 互換にするには、層関数が入力をサポートし、gpuArray 型の出力を返さなければなりません。

多くの MATLAB 組み込み関数が入力引数 gpuArray および dlarray をサポートしています。dlarray オブジェクトをサポートしている関数の一覧については、dlarray をサポートする関数の一覧を参照してください。GPU で実行される関数の一覧については、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。深層学習に GPU を使用するには、Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU も必要です。MATLAB での GPU の使用の詳細は、MATLAB での GPU 計算 (Parallel Computing Toolbox)を参照してください。

forwardLoss で使用する MATLAB 関数はすべて、dlarray オブジェクトをサポートしているため、層は GPU 互換です。

出力層の有効性のチェック

カスタム分類出力層 sseClassificationLayer について層の有効性をチェックします。

カスタム残差平方和分類層を定義します。この層を作成するには、ファイル sseClassificationLayer.m を現在のフォルダーに保存します。層のインスタンスを作成します。

layer = sseClassificationLayer('sse');

checkLayer を使用して層が有効であることをチェックします。層への典型的な入力における 1 つの観測値のサイズになるように有効な入力サイズを指定します。層には 1 x 1 x K x N の配列を入力する必要があります。K はクラスの数、N はミニバッチ内の観測値の数です。

validInputSize = [1 1 10];
checkLayer(layer,validInputSize,'ObservationDimension',4);
Skipping GPU tests. No compatible GPU device found.
 
Running nnet.checklayer.TestOutputLayerWithoutBackward
........
Done nnet.checklayer.TestOutputLayerWithoutBackward
__________

Test Summary:
	 8 Passed, 0 Failed, 0 Incomplete, 2 Skipped.
	 Time elapsed: 0.21531 seconds.

テストの概要では、パスしたテスト、失敗したテスト、不完全なテスト、およびスキップされたテストの数が報告されます。

ネットワークにカスタム分類出力層を含める

Deep Learning Toolbox では、カスタム出力層を他の出力層と同じように使用できます。この節では、前に作成したカスタム分類出力層を使用して分類用のネットワークを作成し、学習を行う方法を説明します。

例の学習データを読み込みます。

[XTrain,YTrain] = digitTrain4DArrayData;

カスタム残差平方和分類層を定義します。この層を作成するには、ファイル sseClassificationLayer.m を現在のフォルダーに保存します。層のインスタンスを作成します。カスタム分類出力層 sseClassificationLayer を含む層配列を作成します。

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    sseClassificationLayer('sse')]
layers = 
  7x1 Layer array with layers:

     1   ''      Image Input             28x28x1 images with 'zerocenter' normalization
     2   ''      Convolution             20 5x5 convolutions with stride [1  1] and padding [0  0  0  0]
     3   ''      Batch Normalization     Batch normalization
     4   ''      ReLU                    ReLU
     5   ''      Fully Connected         10 fully connected layer
     6   ''      Softmax                 softmax
     7   'sse'   Classification Output   Sum of squares error

学習オプションを設定し、ネットワークに学習させます。

options = trainingOptions('sgdm');
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:01 |        9.38% |       0.9944 |          0.0100 |
|       2 |          50 |       00:00:06 |       75.00% |       0.3561 |          0.0100 |
|       3 |         100 |       00:00:10 |       92.97% |       0.1316 |          0.0100 |
|       4 |         150 |       00:00:16 |       96.88% |       0.0915 |          0.0100 |
|       6 |         200 |       00:00:21 |       95.31% |       0.0738 |          0.0100 |
|       7 |         250 |       00:00:27 |       96.88% |       0.0485 |          0.0100 |
|       8 |         300 |       00:00:33 |       99.22% |       0.0203 |          0.0100 |
|       9 |         350 |       00:00:39 |       99.22% |       0.0264 |          0.0100 |
|      11 |         400 |       00:00:44 |      100.00% |       0.0069 |          0.0100 |
|      12 |         450 |       00:00:49 |      100.00% |       0.0045 |          0.0100 |
|      13 |         500 |       00:00:54 |      100.00% |       0.0078 |          0.0100 |
|      15 |         550 |       00:01:00 |      100.00% |       0.0059 |          0.0100 |
|      16 |         600 |       00:01:04 |      100.00% |       0.0021 |          0.0100 |
|      17 |         650 |       00:01:09 |      100.00% |       0.0040 |          0.0100 |
|      18 |         700 |       00:01:14 |      100.00% |       0.0024 |          0.0100 |
|      20 |         750 |       00:01:19 |      100.00% |       0.0028 |          0.0100 |
|      21 |         800 |       00:01:23 |      100.00% |       0.0020 |          0.0100 |
|      22 |         850 |       00:01:29 |      100.00% |       0.0017 |          0.0100 |
|      24 |         900 |       00:01:35 |      100.00% |       0.0020 |          0.0100 |
|      25 |         950 |       00:01:40 |      100.00% |       0.0013 |          0.0100 |
|      26 |        1000 |       00:01:46 |      100.00% |       0.0012 |          0.0100 |
|      27 |        1050 |       00:01:52 |       99.22% |       0.0104 |          0.0100 |
|      29 |        1100 |       00:01:58 |      100.00% |       0.0013 |          0.0100 |
|      30 |        1150 |       00:02:05 |      100.00% |       0.0012 |          0.0100 |
|      30 |        1170 |       00:02:08 |       99.22% |       0.0077 |          0.0100 |
|========================================================================================|

新しいデータについて予測を行い、精度を計算することによって、ネットワーク性能を評価します。

[XTest,YTest] = digitTest4DArrayData;
YPred = classify(net, XTest);
accuracy = mean(YTest == YPred)
accuracy = 0.9844

参考

| |

関連するトピック