Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

カスタム分類出力層の定義

ヒント

k 個の互いに排他的なクラスの交差エントロピー損失を含む分類出力層を構築するには、classificationLayer を使用します。分類問題に別の損失関数を使用する場合は、この例を指針として使用してカスタム分類出力層を定義できます。

この例では、残差平方和 (SSE) 損失を含むカスタム分類出力層を定義し、畳み込みニューラル ネットワークで使用する方法を説明します。

カスタム分類出力層を定義するために、この例で提供するテンプレートを使用できます。この例では、次のステップで説明を進めます。

  1. 層の命名 – MATLAB® で使用できるように層に名前を付けます。

  2. 層のプロパティの宣言 – 層のプロパティを指定します。

  3. コンストラクター関数の作成 (オプション) – 層の構築とそのプロパティ初期化の方法を指定します。コンストラクター関数を指定しない場合、プロパティは作成時に '' で初期化されます。

  4. 順方向損失関数の作成 – 予測と学習ターゲットの間の損失を指定します。

  5. 逆方向損失関数の作成 (オプション) – 予測についての損失の微分を指定します。逆方向損失関数を指定しない場合、順方向損失関数は dlarray オブジェクトをサポートしなければなりません。

分類 SSE 層は、分類問題の残差平方和損失を計算します。SSE は、2 つの連続確率変数間の誤差測定です。予測 Y と学習ターゲット T について、Y と T の間の SSE 損失は次で与えられます。

L=1Nn=1Ni=1K(YniTni)2,

ここで、N は観測値の数、K はクラスの数です。

分類出力層テンプレート

分類出力層のテンプレートを MATLAB の新しいファイルにコピーします。このテンプレートは、分類出力層の構造の概要を示しており、層の動作を定義する関数が含まれます。

classdef myClassificationLayer < nnet.layer.ClassificationLayer % ...
        % & nnet.layer.Acceleratable % (Optional)
        
    properties
        % (Optional) Layer properties.

        % Layer properties go here.
    end
 
    methods
        function layer = myClassificationLayer()           
            % (Optional) Create a myClassificationLayer.

            % Layer constructor function goes here.
        end

        function loss = forwardLoss(layer,Y,T)
            % Return the loss between the predictions Y and the training 
            % targets T.
            %
            % Inputs:
            %         layer - Output layer
            %         Y     – Predictions made by network
            %         T     – Training targets
            %
            % Output:
            %         loss  - Loss between Y and T

            % Layer forward loss function goes here.
        end
        
        function dLdY = backwardLoss(layer,Y,T)
            % (Optional) Backward propagate the derivative of the loss 
            % function.
            %
            % Inputs:
            %         layer - Output layer
            %         Y     – Predictions made by network
            %         T     – Training targets
            %
            % Output:
            %         dLdY  - Derivative of the loss with respect to the 
            %                 predictions Y

            % Layer backward loss function goes here.
        end
    end
end

層の命名とスーパークラスの指定

まず、層に名前を付けます。クラス ファイルの最初の行で、既存の名前 myClassificationLayersseClassificationLayer に置き換えます。この層は高速化をサポートしているため、nnet.layer.Acceleratable クラスも追加します。カスタム層の高速化の詳細については、Custom Layer Function Accelerationを参照してください。

classdef sseClassificationLayer < nnet.layer.ClassificationLayer ...
        & nnet.layer.Acceleratable

    ...
end

次に、コンストラクター関数 myClassificationLayer (methods セクションの最初の関数) の名前を層と同じ名前に変更します。

    methods
        function layer = sseClassificationLayer()           
            ...
        end

        ...
     end

層の保存

層のクラス ファイルを sseClassificationLayer.m という名前の新しいファイルに保存します。このファイル名は層の名前に一致しなければなりません。この層を使用するには、このファイルを現在のフォルダーまたは MATLAB パス上のフォルダーに保存しなければなりません。

層のプロパティの宣言

層のプロパティを properties セクションで宣言します。

既定では、カスタム出力層には次のプロパティがあります。

  • Name層の名前。文字ベクトルまたは string スカラーとして指定します。Layer 配列入力の場合、関数 trainnettrainNetworkassembleNetworklayerGraph、および dlnetwork は、名前が "" の層に自動的に名前を割り当てます。

  • Description — 層についての 1 行の説明。文字ベクトルまたは string スカラーとして指定します。この説明は、層が Layer 配列に表示されるときに表示されます。層の説明を指定しない場合、"Classification Output" または "Regression Output" が表示されます。

  • Type — 層のタイプ。文字ベクトルまたは string スカラーとして指定します。Type の値は、層が Layer 配列に表示されるときに表示されます。層のタイプを指定しない場合、層のクラス名が表示されます。

カスタム分類層には次のプロパティもあります。

  • Classes出力層のクラス。categorical ベクトル、string 配列、文字ベクトルの cell 配列、または "auto" として指定します。Classes"auto" の場合、学習時にクラスが自動的に設定されます。string 配列または文字ベクトルの cell 配列 str を指定すると、出力層のクラスが categorical(str,str) に設定されます。

カスタム回帰層には次のプロパティもあります。

  • ResponseNames応答の名前。文字ベクトルの cell 配列または string 配列として指定します。学習時に、学習データに従って応答名が自動的に設定されます。既定値は {} です。

層にその他のプロパティがない場合は、properties セクションを省略できます。

この例の層には追加のプロパティが必要ないため、properties セクションは削除できます。

コンストラクター関数の作成

層を構築する関数を作成し、層のプロパティを初期化します。層を作成するために必要な変数をコンストラクター関数への入力として指定します。

作成時に Name プロパティに割り当てる入力引数 name を指定します。関数の構文を説明するコメントを関数の上部に追加します。

        function layer = sseClassificationLayer(name)
            % layer = sseClassificationLayer(name) creates a sum of squares
            % error classification layer and specifies the layer name.

            ...
        end

層のプロパティの初期化

コメント % Layer constructor function goes here を、層のプロパティを初期化するコードに置き換えます。

層の Description プロパティを設定して、層に 1 行の説明を指定します。Name プロパティを入力引数 name に設定します。

        function layer = sseClassificationLayer(name)
            % layer = sseClassificationLayer(name) creates a sum of squares
            % error classification layer and specifies the layer name.
    
            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = 'Sum of squares error';
        end

順方向損失関数の作成

ネットワークで行った予測と学習ターゲットの間の SSE 損失を返す、forwardLoss という名前の関数を作成します。forwardLoss の構文は loss = forwardLoss(layer, Y, T) です。ここで、Y は前の層の出力であり、T は学習ターゲットを表します。

分類問題の場合、T の次元は問題のタイプによって異なります。

分類タスク
形状データ形式
2 次元イメージ分類1×1×K×N。ここで、K はクラスの数、N は観測値の数です。"SSCB"
3 次元イメージ分類1×1×1×K×N。ここで、K はクラスの数、N は観測値の数です。"SSSCB"
sequence-to-label 分類K 行 N 列。K はクラスの数、N は観測値の数です。"CB"
sequence-to-sequence 分類K×N×S。ここで、K はクラス数、N は観測値の数、S はシーケンス長です。"CBT"

Y のサイズは前の層の出力によって異なります。YT と同じサイズになるように、正しいサイズを出力する層を出力層の前に含めなければなりません。たとえば、Y を必ず K 個のクラスの予測スコアを持つ 4 次元配列にするために、サイズが K の全結合層を含め、その後にソフトマックス層、出力層の順に配置することができます。

分類 SSE 層は、分類問題の残差平方和損失を計算します。SSE は、2 つの連続確率変数間の誤差測定です。予測 Y と学習ターゲット T について、Y と T の間の SSE 損失は次で与えられます。

L=1Nn=1Ni=1K(YniTni)2,

ここで、N は観測値の数、K はクラスの数です。

入力 Y および T は、この方程式の Y および T にそれぞれ対応しています。出力 loss は L に対応します。関数の構文を説明するコメントを関数の上部に追加します。

        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the SSE loss between
            % the predictions Y and the training targets T.

            % Calculate sum of squares.
            sumSquares = sum((Y-T).^2);
    
            % Take mean over mini-batch.
            N = size(Y,4);
            loss = sum(sumSquares)/N;
        end

関数 forwardLossdlarray オブジェクトをサポートする関数のみを使用するため、関数 backwardLoss の定義はオプションです。dlarray オブジェクトをサポートしている関数の一覧については、dlarray をサポートする関数の一覧を参照してください。

完成した層

完成した分類出力層のクラス ファイルを表示します。

classdef sseClassificationLayer < nnet.layer.ClassificationLayer ... 
        & nnet.layer.Acceleratable
    % Example custom classification layer with sum of squares error loss.
    
    methods
        function layer = sseClassificationLayer(name)
            % layer = sseClassificationLayer(name) creates a sum of squares
            % error classification layer and specifies the layer name.
    
            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = 'Sum of squares error';
        end
        
        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the SSE loss between
            % the predictions Y and the training targets T.

            % Calculate sum of squares.
            sumSquares = sum((Y-T).^2);
    
            % Take mean over mini-batch.
            N = size(Y,4);
            loss = sum(sumSquares)/N;
        end
    end
end

GPU 互換性

層の順方向関数が dlarray オブジェクトを完全にサポートしている場合、層は GPU 互換です。そうでない場合、GPU 互換にするには、層関数が入力をサポートし、gpuArray (Parallel Computing Toolbox) 型の出力を返さなければなりません。

多くの MATLAB 組み込み関数が入力引数 gpuArray (Parallel Computing Toolbox) および dlarray をサポートしています。dlarray オブジェクトをサポートしている関数の一覧については、dlarray をサポートする関数の一覧を参照してください。GPU で実行される関数の一覧については、GPU での MATLAB 関数の実行 (Parallel Computing Toolbox)を参照してください。深層学習に GPU を使用するには、サポートされている GPU デバイスもなければなりません。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。MATLAB での GPU の使用の詳細は、MATLAB での GPU 計算 (Parallel Computing Toolbox)を参照してください。

forwardLoss で使用する MATLAB 関数はすべて、dlarray オブジェクトをサポートしているため、層は GPU 互換です。

出力層の有効性のチェック

カスタム分類出力層 sseClassificationLayer について層の有効性をチェックします。

サポート ファイルとしてこの例に添付されている層 sseClassificationLayer のインスタンスを作成します。

layer = sseClassificationLayer('sse');

checkLayer を使用して層が有効であることをチェックします。層への典型的な入力における 1 つの観測値のサイズになるように有効な入力サイズを指定します。層には 1 x 1 x K x N の配列を入力する必要があります。K はクラスの数、N はミニバッチ内の観測値の数です。

validInputSize = [1 1 10];
checkLayer(layer,validInputSize,'ObservationDimension',4);
Skipping GPU tests. No compatible GPU device found.
 
Skipping code generation compatibility tests. To check validity of the layer for code generation, specify the CheckCodegenCompatibility and ObservationDimension options.
 
Running nnet.checklayer.TestOutputLayerWithoutBackward
........
Done nnet.checklayer.TestOutputLayerWithoutBackward
__________

Test Summary:
	 8 Passed, 0 Failed, 0 Incomplete, 2 Skipped.
	 Time elapsed: 1.0955 seconds.

テストの概要では、パスしたテスト、失敗したテスト、未完了のテスト、およびスキップされたテストの数が報告されます。

ネットワークにカスタム分類出力層を含める

Deep Learning Toolbox では、カスタム出力層を他の出力層と同じように使用できます。この節では、前に作成したカスタム分類出力層を使用して分類用のネットワークを作成し、学習を行う方法を説明します。

例の学習データを読み込みます。

[XTrain,YTrain] = digitTrain4DArrayData;

サポート ファイルとしてこの例に添付されているカスタム分類出力層 sseClassificationLayer を含む層配列を作成します。

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    sseClassificationLayer('sse')]
layers = 
  7x1 Layer array with layers:

     1   ''      Image Input             28x28x1 images with 'zerocenter' normalization
     2   ''      2-D Convolution         20 5x5 convolutions with stride [1  1] and padding [0  0  0  0]
     3   ''      Batch Normalization     Batch normalization
     4   ''      ReLU                    ReLU
     5   ''      Fully Connected         10 fully connected layer
     6   ''      Softmax                 softmax
     7   'sse'   Classification Output   Sum of squares error

学習オプションを設定し、ネットワークに学習させます。

options = trainingOptions('sgdm');
net = trainNetwork(XTrain,YTrain,layers,options);
Training on single CPU.
Initializing input data normalization.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:00 |        9.38% |       0.9944 |          0.0100 |
|       2 |          50 |       00:00:06 |       74.22% |       0.3545 |          0.0100 |
|       3 |         100 |       00:00:10 |       92.97% |       0.1301 |          0.0100 |
|       4 |         150 |       00:00:14 |       96.09% |       0.0964 |          0.0100 |
|       6 |         200 |       00:00:19 |       96.09% |       0.0716 |          0.0100 |
|       7 |         250 |       00:00:24 |       97.66% |       0.0456 |          0.0100 |
|       8 |         300 |       00:00:31 |       99.22% |       0.0201 |          0.0100 |
|       9 |         350 |       00:00:38 |       99.22% |       0.0254 |          0.0100 |
|      11 |         400 |       00:00:43 |      100.00% |       0.0074 |          0.0100 |
|      12 |         450 |       00:00:48 |      100.00% |       0.0047 |          0.0100 |
|      13 |         500 |       00:00:54 |      100.00% |       0.0084 |          0.0100 |
|      15 |         550 |       00:01:00 |      100.00% |       0.0062 |          0.0100 |
|      16 |         600 |       00:01:05 |      100.00% |       0.0020 |          0.0100 |
|      17 |         650 |       00:01:08 |      100.00% |       0.0041 |          0.0100 |
|      18 |         700 |       00:01:11 |      100.00% |       0.0023 |          0.0100 |
|      20 |         750 |       00:01:15 |      100.00% |       0.0025 |          0.0100 |
|      21 |         800 |       00:01:20 |      100.00% |       0.0019 |          0.0100 |
|      22 |         850 |       00:01:26 |      100.00% |       0.0017 |          0.0100 |
|      24 |         900 |       00:01:32 |      100.00% |       0.0020 |          0.0100 |
|      25 |         950 |       00:01:38 |      100.00% |       0.0013 |          0.0100 |
|      26 |        1000 |       00:01:46 |      100.00% |       0.0012 |          0.0100 |
|      27 |        1050 |       00:01:55 |       99.22% |       0.0104 |          0.0100 |
|      29 |        1100 |       00:02:02 |      100.00% |       0.0013 |          0.0100 |
|      30 |        1150 |       00:02:09 |      100.00% |       0.0011 |          0.0100 |
|      30 |        1170 |       00:02:12 |       99.22% |       0.0088 |          0.0100 |
|========================================================================================|
Training finished: Max epochs completed.

新しいデータについて予測を行い、精度を計算することによって、ネットワーク性能を評価します。

[XTest,YTest] = digitTest4DArrayData;
YPred = classify(net, XTest);
accuracy = mean(YTest == YPred)
accuracy = 0.9842

参考

| | | | |

関連するトピック