Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

YOLO v3 深層学習を使用したオブジェクトの検出

この例では、YOLO v3 オブジェクト検出器に学習させる方法を説明します。

深層学習は、ロバストなオブジェクト検出器に学習させるために使用できる強力な機械学習手法です。オブジェクト検出の方法は、Faster R-CNN、You Only Look Once (YOLO) v2、シングル ショット検出器 (SSD) など複数あります。この例では、YOLO v3 オブジェクト検出器に学習させる方法を説明します。YOLO v3 は YOLO v2 を改良したもので、複数のスケールにおける検出を追加してより小さなオブジェクトを検出できるようになっています。さらに、学習で使用される損失関数は、境界ボックス回帰用の平均二乗誤差と、オブジェクト分類用のバイナリ交差エントロピーに分割されており、検出精度が向上しています。

事前学習済みのネットワークのダウンロード

学習の完了を待たなくて済むように、補助関数 downloadPretrainedYOLOv3Detector を使用して、事前学習済みのネットワークをダウンロードします。ネットワークに学習させる場合は、変数 doTrainingtrue に設定します。

doTraining = false;

if ~doTraining
    net = downloadPretrainedYOLOv3Detector();    
end

データの読み込み

この例では、295 枚のイメージを含んだ小さなラベル付きデータセットを使用します。各イメージには、1 または 2 個のラベル付けされた車両インスタンスが含まれています。小さなデータセットは YOLO v3 の学習手順を調べるうえで役立ちますが、実際にロバストなネットワークに学習させるにはラベル付けされたイメージがより多く必要になります。

車両のイメージを解凍し、車両のグラウンド トゥルース データを読み込みます。

unzip vehicleDatasetImages.zip
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;

% Add the full path to the local vehicle data folder.
vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);

メモ: 複数のクラスの場合、データは 3 列に整理することもできます。最初の列には、パスを含むイメージ ファイル名を格納し、2 列目には境界ボックスを格納します。3 列目は、各境界ボックスに対応するラベル名を含む cell ベクトルでなければなりません。境界ボックスとラベルの調整方法の詳細については、boxLabelDatastoreを参照してください。

すべての境界ボックスは、[x y width height] の形式でなければなりません。このベクトルは、境界ボックスの左上隅とサイズをピクセル単位で指定します。

データセットを、ネットワークに学習させるための学習セットとネットワークを評価するためのテスト セットに分割します。データの 90% を学習セットに使用し、残りをテスト セットに使用します。

rng(0);
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices));
trainingDataTbl = vehicleDataset(shuffledIndices(1:idx), :);
testDataTbl = vehicleDataset(shuffledIndices(idx+1:end), :);

イメージを読み込むためのイメージ データストアを作成します。

imdsTrain = imageDatastore(trainingDataTbl.imageFilename);
imdsTest = imageDatastore(testDataTbl.imageFilename);

グラウンド トゥルース境界ボックス用のデータストアを作成します。

bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:end));
bldsTest = boxLabelDatastore(testDataTbl(:, 2:end));

イメージ データストアとボックス ラベル データストアを組み合わせます。

trainingData = combine(imdsTrain, bldsTrain);
testData = combine(imdsTest, bldsTest);

validateInputData を使用して、次のような無効なイメージ、境界ボックス、またはラベルを検出します。

  • 無効なイメージ形式であるか、NaN を含むサンプル

  • ゼロ/NaN/Inf を含むか、空である境界ボックス

  • 欠損ラベル/非カテゴリカル ラベル。

境界ボックスの値は、有限、正、整数で、NaN 以外でなければなりません。また、正の高さと幅をもつイメージ境界内に収まらなくてはなりません。無効なサンプルは、破棄するか、適切な学習のために修正しなければなりません。

validateInputData(trainingData);
validateInputData(testData);

データ拡張

データ拡張は、学習中に元のデータをランダムに変換してネットワークの精度を高めるために使用されます。データ拡張を使用すると、ラベル付き学習サンプルの数を実際に増やさずに、学習データをさらに多様化させることができます。

関数 transform を使用して、カスタムのデータ拡張を学習データに適用します。この例の最後にリストされている補助関数 augmentData によって、入力データに以下の拡張が適用されます。

  • HSV 空間でのカラー ジッターの付加

  • 水平方向のランダムな反転

  • 10% のランダムなスケーリング

augmentedTrainingData = transform(trainingData, @augmentData);

同じイメージを 4 回読み取り、拡張された学習データを表示します。

% Visualize the augmented images.
augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1,1}, 'Rectangle', data{1,2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData, 'BorderSize', 10)

学習データの前処理

ネットワーク入力サイズを指定します。ネットワーク入力サイズを選択する際には、ネットワーク自体の実行に必要な最小サイズ、学習イメージのサイズ、および選択したサイズでデータを処理することによって発生する計算コストを考慮します。可能な場合、学習イメージのサイズに近く、ネットワークに必要な入力サイズより大きいネットワーク入力サイズを選択します。この例の実行にかかる計算コストを削減するため、ネットワーク入力サイズを [227 227 3] に指定します。

networkInputSize = [227 227 3];

拡張された学習データを前処理して学習用に準備します。この例の最後にリストされている補助関数 preprocessData によって、入力データに以下の前処理演算が適用されます。

  • イメージのサイズをネットワークの入力サイズに変更します。

  • イメージのピクセルを [0 1] の範囲にスケーリングします。

preprocessedTrainingData = transform(augmentedTrainingData, @(data)preprocessData(data, networkInputSize));

前処理された学習データを読み取ります。

data = read(preprocessedTrainingData);

境界ボックスと共にイメージを表示します。

I = data{1,1};
bbox = data{1,2};
annotatedImage = insertShape(I, 'Rectangle', bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)

データストアをリセットします。

reset(preprocessedTrainingData);

YOLO v3 ネットワークの定義

この例の YOLO v3 ネットワークは、SqueezeNet がベースとなっています。このネットワークは、SqueezeNet の特徴抽出ネットワークを使用し、最後に 2 つの検出ヘッドが追加されています。2 番目の検出ヘッドのサイズは、最初の検出ヘッドの 2 倍となっているため、小さなオブジェクトをより的確に検出できます。検出したいオブジェクトのサイズに基づいて、サイズが異なる検出ヘッドを任意の数だけ指定できます。YOLO v3 ネットワークは、学習データを使用して推定されたアンカー ボックスを使用します。これにより、データセットの種類に対応した初期の事前確率が改善され、ボックスを正確に予測できるようにネットワークに学習させることができます。アンカー ボックスの詳細については、アンカー ボックスによるオブジェクトの検出を参照してください。

この例の YOLO v3 ネットワークを次の図に示します。

ディープ ネットワーク デザイナー (Deep Learning Toolbox)を使用して、この図に示されているネットワークを作成することができます。

この例で使用されている学習イメージは、227 x 227 より大きく、サイズがまちまちであるため、まず、transform を使用して、アンカー ボックスを計算するための学習データを前処理します。アンカーの数と平均 IoU との良好なトレードオフを実現するため、アンカーの数を 6 に指定します。関数 estimateAnchorBoxes を使用してアンカー ボックスを推定します。アンカー ボックスの推定の詳細については、学習データからのアンカー ボックスの推定を参照してください。事前学習済みの YOLOv3 オブジェクト検出器を使用する場合、特定の学習データセットで計算されたアンカー ボックスを指定する必要があります。推定プロセスは確定的なものではないことに注意してください。推定されたアンカー ボックスが他のハイパーパラメーターの調整中に変化しないように、rng を使用して推定前に乱数シードを設定します。

rng(0)
trainingDataForEstimation = transform(trainingData, @(data)preprocessData(data, networkInputSize));
numAnchors = 6;
[anchorBoxes, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors)
anchorBoxes = 6×2

    41    34
   163   130
    98    93
   144   125
    33    24
    69    66

meanIoU = 0.8507

2 つの検出ヘッドで使用するアンカー ボックスを選択するための anchorBoxMasks を指定します。anchorBoxMasks は、M 行 1 列の cell 配列です。ここで、M は検出ヘッドの数を表します。各検出ヘッドは、anchorBoxes に含まれるアンカーの行インデックスを表す 1 行 N 列の配列で構成されます。ここで、N は使用するアンカー ボックスの数です。サイズに基づいて、各検出ヘッドのアンカー ボックスを選択します。スケールが小さい場合は大きいアンカー ボックスを使用し、スケールが大きい場合は小さいアンカー ボックスを使用します。これを行うには、より大きいアンカー ボックスが先頭に来るようにアンカー ボックスを並べ替えてから、最初の 3 つのアンカー ボックスを最初の検出ヘッドに割り当て、次の 3 つのアンカー ボックスを 2 番目の検出ヘッドに割り当てます。

area = anchorBoxes(:, 1).*anchorBoxes(:, 2);
[~, idx] = sort(area, 'descend');
anchorBoxes = anchorBoxes(idx, :);
anchorBoxMasks = {[1,2,3]
    [4,5,6]
    };

Imagenet データセットで事前学習された SqueezeNet ネットワークを読み込みます。MobileNet-v2 や ResNet-18 といった他の事前学習済みのネットワークを選択して読み込むこともできます。YOLO v3 は、事前学習済みのネットワークを使用すると、より優れたパフォーマンスを発揮し、より高速に学習させることができます。

次に、特徴抽出ネットワークを作成します。最適な特徴抽出層を選択するには、試行錯誤が必要です。analyzeNetwork を使用すると、ネットワーク内に存在する可能性がある特徴抽出層の名前を検索できます。この例では、この例の最後にリストされている補助関数 squeezenetFeatureExtractor を使用して、特徴抽出層 'fire9-concat' の後の層を削除します。この層の後の層は、分類タスクに固有のものであり、オブジェクト検出の役には立ちません。

baseNetwork = squeezenet;
lgraph = squeezenetFeatureExtractor(baseNetwork, networkInputSize);

オブジェクト クラスの名前、検出するオブジェクト クラスの数、およびアンカー ボックスごとの予測要素の数を指定します。アンカー ボックスごとの予測数は、5 にオブジェクト クラスの数を加えた数値に設定します。"5" は 4 つの境界ボックス属性と 1 つのオブジェクト信頼度を表しています。事前学習済みの YOLOv3 ネットワークを使用する場合、ネットワークの学習用に指定したのと同じ順序でクラス名を指定します。

classNames = {'vehicle'};
numClasses = size(classNames, 2);
numPredictorsPerAnchor = 5 + numClasses;

特徴抽出ネットワークに検出ヘッドを追加します。各検出ヘッドは、境界ボックス座標 (x、y、幅、高さ)、オブジェクトの信頼度、および各アンカー ボックス マスクに対するクラスの確率を予測します。そのため、各検出ヘッドにおける最終畳み込み層の出力フィルターの数は、アンカー ボックス マスクの数と、アンカー ボックスあたりの予測要素の数を乗算したものになります。サポート関数 addFirstDetectionHeadaddSecondDetectionHead を使用して、特徴抽出ネットワークに検出ヘッドを追加します。

lgraph = addFirstDetectionHead(lgraph, anchorBoxMasks{1}, numPredictorsPerAnchor);
lgraph = addSecondDetectionHead(lgraph, anchorBoxMasks{2}, numPredictorsPerAnchor);

最後に、最初の検出ヘッドを特徴抽出層に接続し、2 番目の検出ヘッドを最初の検出ヘッドの出力に接続して接続ヘッドどうしを接続します。さらに、2 番目の検出ヘッドのアップサンプリングされた特徴と、'fire5-concat' 層からの特徴をマージし、より意味のあるセマンティック情報を 2 番目の検出ヘッドから取得します。

lgraph = connectLayers(lgraph, 'fire9-concat', 'conv1Detection1');
lgraph = connectLayers(lgraph, 'relu1Detection1', 'upsample1Detection2');
lgraph = connectLayers(lgraph, 'fire5-concat', 'depthConcat1Detection2/in2');

この検出ヘッドは、ネットワークの出力層を構成します。出力特徴量を抽出するには、M 行 1 列の配列を使用して検出ヘッドの名前を指定します。M は検出ヘッドの数です。ネットワークで出現する順に、検出ヘッドの名前を指定します。

networkOutputs = ["conv2Detection1"
    "conv2Detection2"
    ];

あるいは、SqueezeNet を使って上記で作成したネットワークの代わりに、MS-COCO などのより大規模なデータセットで学習させた、他の事前学習済みの YOLOv3 アーキテクチャを使用して、カスタム オブジェクト検出タスクで検出器の転移学習を行うこともできます。転移学習は、最終畳み込み層のフィルター数の値を変更するか、上述した新しい検出ヘッドを作成することによって実現できます。後者の場合は、関連する層を抽出するために squeezenetFeatureExtractor を参照します。転移学習のワークフローは、カスタム オブジェクト検出のクラスが、事前学習済みのネットワークで学習させたクラスまたはクラスのサブクラスの 1 つとして存在している場合に推奨されます。

学習オプションの指定

これらの学習オプションを指定します。

  • エポック数を 70 に設定します。

  • ミニ バッチ サイズを 8 に設定します。使用されるミニ バッチ サイズが大きい場合は、学習率を高めることで安定した学習が可能になります。.ただし、これは利用可能なメモリに応じて設定しなければなりません。

  • 学習率を 0.001 に設定します。

  • ウォームアップ期間を 1000 反復に設定します。このパラメーターは、式 learningRate×(iterationwarmupPeriod)4 に基づいて学習率を指数関数的に増やす反復処理の回数を表します。これは、学習率が高いときの勾配を安定させるのに役立ちます。

  • L2 正則化係数を 0.0005 に設定します。

  • ペナルティしきい値を 0.5 に設定します。グラウンド トゥルースとのオーバーラップが 0.5 未満の検出にペナルティが課されます。

  • 勾配の速度を [] に初期化します。これは、SGDM が勾配の速度を保存するのに使用されます。

numEpochs = 70;
miniBatchSize = 8;
learningRate = 0.001;
warmupPeriod = 1000;
l2Regularization = 0.0005;
penaltyThreshold = 0.5;
velocity = [];

モデルの学習

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 NVIDIA® GPU が必要です。

関数 minibatchqueue を使用して、前処理された学習データをサポート関数 createBatchData でバッチに分割します。このサポート関数は、バッチ処理されたイメージ、およびそれぞれのクラス ID と組み合わせた境界ボックスを返します。学習用バッチ データの抽出を高速化するには、dispatchInBackground を "true" に設定して、並列プールの使用を確保する必要があります。

minibatchqueue は、GPU の可用性を自動的に検出します。GPU がない場合、または学習で GPU を使用しない場合は、OutputEnvironment パラメーターを "cpu" に設定します。

if canUseParallelPool
   dispatchInBackground = true;
else
   dispatchInBackground = false;
end

mbqTrain = minibatchqueue(preprocessedTrainingData, 2,...
        "MiniBatchSize", miniBatchSize,...
        "MiniBatchFcn", @(images, boxes, labels) createBatchData(images, boxes, labels, classNames), ...
        "MiniBatchFormat", ["SSCB", ""],...
        "DispatchInBackground", dispatchInBackground,...
        "OutputCast", ["", "double"]);

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。次に、サポート関数 configureTrainingProgressPlotter を使用して、学習の進行状況のプロッターを作成します。

最後に、カスタム学習ループを指定します。それぞれの反復で次を行います。

  • minibatchqueue. からデータを読み取ります。データがなくなった場合は、minibatchqueue をリセットしてシャッフルします。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配を評価します。サポート関数としてリストされている関数 modelGradients は、net の学習可能なパラメーターに関する損失勾配、対応するミニバッチの損失、および現在のバッチの状態を返します。

  • よりロバストな学習を実現するため、重み減衰係数を勾配に適用して正則化を実施します。

  • サポート関数 piecewiseLearningRateWithWarmup を使用し、反復に基づいて学習率を決定します。

  • 関数 sgdmupdate を使用してネットワーク パラメーターを更新します。

  • 移動平均を使用して、netstate パラメーターを更新します。

  • すべての反復について、学習率、合計損失、個々の損失 (ボックス損失、オブジェクト損失、クラス損失) を表示します。これらは、各反復においてそれぞれの損失がどのように変化しているかを解釈するために使用できます。たとえば、数回の反復後、ボックス損失にスパイクが突然発生していれば、予測に Inf または NaN があることを意味しています。

  • 学習の進行状況プロットを更新します。

数エポックにわたって損失が飽和する場合は、学習を終了することもできます。

if doTraining
    % Convert layer graph to dlnetwork.
    net = dlnetwork(lgraph);
    
    % Create subplots for the learning rate and mini-batch loss.
    fig = figure;
    [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(fig);

    iteration = 0;
    % Custom training loop.
    for epoch = 1:numEpochs
          
        reset(mbqTrain);
        shuffle(mbqTrain);
        
        while(hasdata(mbqTrain))
            iteration = iteration + 1;
           
            [XTrain, YTrain] = next(mbqTrain);
            
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients, state, lossInfo] = dlfeval(@modelGradients, net, XTrain, YTrain, anchorBoxes, anchorBoxMasks, penaltyThreshold, networkOutputs);
    
            % Apply L2 regularization.
            gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, net.Learnables);
    
            % Determine the current learning rate value.
            currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs);
    
            % Update the network learnable parameters using the SGDM optimizer.
            [net, velocity] = sgdmupdate(net, gradients, velocity, currentLR);
    
            % Update the state parameters of dlnetwork.
            net.State = state;
              
            % Display progress.
            displayLossInfo(epoch, iteration, currentLR, lossInfo);  
                
            % Update training plot with new points.
            updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, lossInfo.totalLoss);
        end        
    end
end

モデルの評価

Computer Vision System Toolbox™ には、平均適合率 (evaluateDetectionPrecision) や対数平均ミス率 (evaluateDetectionMissRate) などの一般的なメトリクスを測定するオブジェクト検出器の評価関数が用意されています。この例では、平均適合率メトリクスを使用します。平均適合率は、検出器が正しい分類を実行できること (適合率) と検出器がすべての関連オブジェクトを検出できること (再現率) を示す単一の数値です。

以下の手順を実行して、テスト データで学習させた dlnetwork オブジェクトの net を評価します。

  • 信頼度しきい値に 0.5 を指定し、信頼度スコアがこの値より高い検出のみ保持します。

  • オーバーラップしきい値に 0.5 を指定し、オーバーラップしている検出を削除します。

  • 学習データと同じ前処理変換をテスト データに適用します。データ拡張はテスト データには適用されないことに注意してください。テスト データは元のデータを代表するもので、バイアスのない評価を行うために変更なしで使用されなければなりません。

  • preprocessedTestData に対して検出器を実行し、検出結果を収集します。サポート関数 yolov3Detect を使用して、境界ボックス、オブジェクトの信頼度スコア、およびクラス ラベルを取得します。

  • 予測された resultspreprocessedTestData を引数として、evaluateDetectionPrecision を呼び出します。

confidenceThreshold = 0.5;
overlapThreshold = 0.5;

% Create the test datastore.
preprocessedTestData = transform(testData, @(data)preprocessData(data, networkInputSize));

% Create a table to hold the bounding boxes, scores, and labels returned by
% the detector. 
numImages = size(testDataTbl, 1);
results = table('Size', [0 3], ...
    'VariableTypes', {'cell','cell','cell'}, ...
    'VariableNames', {'Boxes','Scores','Labels'});

mbqTest = minibatchqueue(preprocessedTestData, 1, ...
    "MiniBatchSize", miniBatchSize, ...
    "MiniBatchFormat", "SSCB");

% Run detector on images in the test set and collect results.
while hasdata(mbqTest)
    % Read the datastore and get the image.
    XTest = next(mbqTest);
    
    % Run the detector.
    [bboxes, scores, labels] = yolov3Detect(net, XTest, networkOutputs, anchorBoxes, anchorBoxMasks, confidenceThreshold, overlapThreshold, classNames);
    
    % Collect the results.
    tbl = table(bboxes, scores, labels, 'VariableNames', {'Boxes','Scores','Labels'});
    results = [results; tbl];
end

% Evaluate the object detector using Average Precision metric.
[ap, recall, precision] = evaluateDetectionPrecision(results, preprocessedTestData);

適合率/再現率 (PR) の曲線は、さまざまなレベルの再現率における検出器の適合率を示しています。すべてのレベルの再現率で適合率が 1 になるのが理想的です。

% Plot precision-recall curve.
figure
plot(recall, precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))

YOLO v3 を使用したオブジェクトの検出

ネットワークを使用してオブジェクトを検出します。

  • イメージを読み取ります。

  • イメージを dlarray に変換します。可能であれば GPU を使用します。

  • サポート関数 yolov3Detect を使用して、予測された境界ボックス、信頼度スコア、およびクラス ラベルを取得します。

  • 境界ボックスと信頼度スコアと共にイメージを表示します。

% Read the datastore.
reset(preprocessedTestData)
data = read(preprocessedTestData);

% Get the image.
I = data{1};

% Convert to dlarray.
XTest = dlarray(I, 'SSCB');

executionEnvironment = "auto";

% If GPU is available, then convert data to gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    XTest = gpuArray(XTest);
end

[bboxes, scores, labels] = yolov3Detect(net, XTest, networkOutputs, anchorBoxes, anchorBoxMasks, confidenceThreshold, overlapThreshold, classNames);

% Clear the persistent variables used in the yolov3Detect function to avoid retaining their values in memory.
clear yolov3Detect  

% Display the detections on image.
if ~isempty(scores{1})
    I = insertObjectAnnotation(I, 'rectangle', bboxes{1}, scores{1});
end
figure
imshow(I)

サポート関数

モデル勾配関数

関数 modelGradients は、dlnetwork オブジェクトの net、入力データ XTrain および対応するグラウンド トゥルース ボックス YTrain から成るミニバッチ、アンカー ボックス、アンカー ボックス マスク、指定されたペナルティしきい値、およびネットワーク出力名を入力引数として取り、net の学習可能なパラメーターの損失勾配、対応するミニバッチの損失、および現在のバッチの状態を返します。

モデル勾配関数は、以下の演算を行って、合計損失と勾配を計算します。

  • サポート関数 yolov3Forward を使用して、イメージの入力バッチから予測を生成します。

  • 後処理を行うため、CPU 上にある予測を収集します。

  • YOLO v3 のグリッド セルの座標から得られた予測を境界ボックスの座標に変換します。これにより、サポート関数 generateTiledAnchorsapplyAnchorBoxOffsets を使用してグラウンド トゥルース データと簡単に比較できるようになります。

  • 変換された予測とグラウンド トゥルース データを使用して、損失計算ターゲットを生成します。これらのターゲットは、境界ボックスの位置 (x、y、幅、高さ)、オブジェクトの信頼度、およびクラスの確率を対象として生成されます。サポート関数 generateTargets を参照してください。

  • 予測された境界ボックス座標とターゲット ボックスとの平均二乗誤差を計算します。サポート関数 bboxOffsetLoss を参照してください。

  • 予測されたオブジェクトの信頼度スコアとターゲット オブジェクトの信頼度スコアとのバイナリ交差エントロピーを決定します。サポート関数 objectnessLoss を参照してください。

  • 予測されたオブジェクトのクラスとターゲットとのバイナリ交差エントロピーを決定します。サポート関数 classConfidenceLoss を参照してください。

  • すべての損失を加算して合計損失を計算します。

  • 合計損失に対する学習可能パラメーターの勾配を計算します。

function [gradients, state, info] = modelGradients(net, XTrain, YTrain, anchors, mask, penaltyThreshold, networkOutputs)
inputImageSize = size(XTrain,1:2);

% Gather the ground truths in the CPU for post processing
YTrain = gather(extractdata(YTrain));

% Extract the predictions from the network.
[YPredCell, state] = yolov3Forward(net,XTrain,networkOutputs,mask);

% Gather the activations in the CPU for post processing and extract dlarray data. 
gatheredPredictions = cellfun(@ gather, YPredCell(:,1:6),'UniformOutput',false); 
gatheredPredictions = cellfun(@ extractdata, gatheredPredictions, 'UniformOutput', false);

% Convert predictions from grid cell coordinates to box coordinates.
tiledAnchors = generateTiledAnchors(gatheredPredictions(:,2:5),anchors,mask);
gatheredPredictions(:,2:5) = applyAnchorBoxOffsets(tiledAnchors, gatheredPredictions(:,2:5), inputImageSize);

% Generate target for predictions from the ground truth data.
[boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = generateTargets(gatheredPredictions, YTrain, inputImageSize, anchors, mask, penaltyThreshold);

% Compute the loss.
boxLoss = bboxOffsetLoss(YPredCell(:,[2 3 7 8]),boxTarget,objectMaskTarget,boxErrorScale);
objLoss = objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
clsLoss = classConfidenceLoss(YPredCell(:,6),classTarget,objectMaskTarget);
totalLoss = boxLoss + objLoss + clsLoss;

info.boxLoss = boxLoss;
info.objLoss = objLoss;
info.clsLoss = clsLoss;
info.totalLoss = totalLoss;

% Compute gradients of learnables with regard to loss.
gradients = dlgradient(totalLoss, net.Learnables);
end

function [YPredCell, state] = yolov3Forward(net, XTrain, networkOutputs, anchorBoxMask)
% Predict the output of network and extract the confidence score, x, y,
% width, height, and class.
YPredictions = cell(size(networkOutputs));
[YPredictions{:}, state] = forward(net, XTrain, 'Outputs', networkOutputs);
YPredCell = extractPredictions(YPredictions, anchorBoxMask);

% Append predicted width and height to the end as they are required
% for computing the loss.
YPredCell(:,7:8) = YPredCell(:,4:5);

% Apply sigmoid and exponential activation.
YPredCell(:,1:6) = applyActivations(YPredCell(:,1:6));
end

function boxLoss = bboxOffsetLoss(boxPredCell, boxDeltaTarget, boxMaskTarget, boxErrorScaleTarget)
% Mean squared error for bounding box position.
lossX = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,1),boxDeltaTarget(:,1),boxMaskTarget(:,1),boxErrorScaleTarget));
lossY = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,2),boxDeltaTarget(:,2),boxMaskTarget(:,1),boxErrorScaleTarget));
lossW = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,3),boxDeltaTarget(:,3),boxMaskTarget(:,1),boxErrorScaleTarget));
lossH = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,4),boxDeltaTarget(:,4),boxMaskTarget(:,1),boxErrorScaleTarget));
boxLoss = lossX+lossY+lossW+lossH;
end

function objLoss = objectnessLoss(objectnessPredCell, objectnessDeltaTarget, boxMaskTarget)
% Binary cross-entropy loss for objectness score.
objLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),objectnessPredCell,objectnessDeltaTarget,boxMaskTarget(:,2)));
end

function clsLoss = classConfidenceLoss(classPredCell, classTarget, boxMaskTarget)
% Binary cross-entropy loss for class confidence score.
clsLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),classPredCell,classTarget,boxMaskTarget(:,3)));
end

拡張関数とデータ処理関数

function data = augmentData(A)
% Apply random horizontal flipping, and random X/Y scaling. Boxes that get
% scaled outside the bounds are clipped if the overlap is above 0.25. Also,
% jitter image color.

data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    if numel(sz) == 3 && sz(3) == 3
        I = jitterColorHSV(I,...
            'Contrast',0.0,...
            'Hue',0.1,...
            'Saturation',0.2,...
            'Brightness',0.2);
    end
    
    % Randomly flip image.
    tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
    rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
    I = imwarp(I,tform,'OutputView',rout);
    
    % Apply same transform to boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,'OverlapThreshold',0.25);
    labels = labels(indices);
    
    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I, bboxes, labels};
    end
end
end


function data = preprocessData(data, targetSize)
% Resize the images and scale the pixels to between 0 and 1. Also scale the
% corresponding bounding boxes.

for ii = 1:size(data,1)
    I = data{ii,1};
    imgSize = size(I);
    
    % Convert an input image with single channel to 3 channels.
    if numel(imgSize) < 3 
        I = repmat(I,1,1,3);
    end
    bboxes = data{ii,2};

    I = im2single(imresize(I,targetSize(1:2)));
    scale = targetSize(1:2)./imgSize(1:2);
    bboxes = bboxresize(bboxes,scale);
    
    data(ii, 1:2) = {I, bboxes};
end
end

function [XTrain, YTrain] = createBatchData(data, groundTruthBoxes, groundTruthClasses, classNames)
% Returns images combined along the batch dimension in XTrain and
% normalized bounding boxes concatenated with classIDs in YTrain

% Concatenate images along the batch dimension.
XTrain = cat(4, data{:,1});

% Get class IDs from the class names.
classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
[~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);

% Append the label indexes and training image size to scaled bounding boxes
% and create a single cell array of responses.
combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
len = max( cellfun(@(x)size(x,1), combinedResponses ) );
paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
YTrain = cat(4, paddedBBoxes{:,1});
end

ネットワーク作成関数

function lgraph = squeezenetFeatureExtractor(net, imageInputSize)
% The squeezenetFeatureExtractor function removes the layers after 'fire9-concat'
% in SqueezeNet and also removes any data normalization used by the image input layer.

% Convert to layerGraph.
lgraph = layerGraph(net);

lgraph = removeLayers(lgraph, {'drop9' 'conv10' 'relu_conv10' 'pool10' 'prob' 'ClassificationLayer_predictions'});
inputLayer = imageInputLayer(imageInputSize,'Normalization','none','Name','data');
lgraph = replaceLayer(lgraph,'data',inputLayer);
end

function lgraph = addFirstDetectionHead(lgraph,anchorBoxMasks,numPredictorsPerAnchor)
% The addFirstDetectionHead function adds the first detection head.

numAnchorsScale1 = size(anchorBoxMasks, 2);
% Compute the number of filters for last convolution layer.
numFilters = numAnchorsScale1*numPredictorsPerAnchor;
firstDetectionSubNetwork = [
    convolution2dLayer(3,256,'Padding','same','Name','conv1Detection1','WeightsInitializer','he')
    reluLayer('Name','relu1Detection1')
    convolution2dLayer(1,numFilters,'Padding','same','Name','conv2Detection1','WeightsInitializer','he')
    ];
lgraph = addLayers(lgraph,firstDetectionSubNetwork);
end

function lgraph = addSecondDetectionHead(lgraph,anchorBoxMasks,numPredictorsPerAnchor)
% The addSecondDetectionHead function adds the second detection head.

numAnchorsScale2 = size(anchorBoxMasks, 2);
% Compute the number of filters for the last convolution layer.
numFilters = numAnchorsScale2*numPredictorsPerAnchor;
    
secondDetectionSubNetwork = [
    upsampleLayer(2,'upsample1Detection2')
    depthConcatenationLayer(2, 'Name', 'depthConcat1Detection2');
    convolution2dLayer(3,128,'Padding','same','Name','conv1Detection2','WeightsInitializer','he')
    reluLayer('Name','relu1Detection2')
    convolution2dLayer(1,numFilters,'Padding','same','Name','conv2Detection2','WeightsInitializer','he')
    ];
lgraph = addLayers(lgraph,secondDetectionSubNetwork);
end

学習率スケジュール関数

function currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs)
% The piecewiseLearningRateWithWarmup function computes the current
% learning rate based on the iteration number.
persistent warmUpEpoch;

if iteration <= warmupPeriod
    % Increase the learning rate for number of iterations in warmup period.
    currentLR = learningRate * ((iteration/warmupPeriod)^4);
    warmUpEpoch = epoch;
elseif iteration >= warmupPeriod && epoch < warmUpEpoch+floor(0.6*(numEpochs-warmUpEpoch))
    % After warm up period, keep the learning rate constant if the remaining number of epochs is less than 60 percent. 
    currentLR = learningRate;
    
elseif epoch >= warmUpEpoch + floor(0.6*(numEpochs-warmUpEpoch)) && epoch < warmUpEpoch+floor(0.9*(numEpochs-warmUpEpoch))
    % If the remaining number of epochs is more than 60 percent but less
    % than 90 percent multiply the learning rate by 0.1.
    currentLR = learningRate*0.1;
    
else
    % If remaining epochs are more than 90 percent multiply the learning
    % rate by 0.01.
    currentLR = learningRate*0.01;
end

end

予測関数

function [bboxes,scores,labels] = yolov3Detect(net, XTest, networkOutputs, anchors, anchorBoxMask, confidenceThreshold, overlapThreshold, classes)
% The yolov3Detect function detects the bounding boxes, scores, and labels in an image.

imageSize = size(XTest, [1,2]);

% Find the input image layer and get the network input size. To retain 'networkInputSize' in memory and avoid
% recalculating it, declare it as persistent. 
persistent networkInputSize

if isempty(networkInputSize)
    networkInputIdx = arrayfun( @(x)isa(x,'nnet.cnn.layer.ImageInputLayer'), net.Layers);
    networkInputSize = net.Layers(networkInputIdx).InputSize;  
end

% Predict and filter the detections based on confidence threshold.
predictions = yolov3Predict(net,XTest,networkOutputs,anchorBoxMask);
predictions = cellfun(@ gather, predictions,'UniformOutput',false);
predictions = cellfun(@ extractdata, predictions, 'UniformOutput', false);
tiledAnchors = generateTiledAnchors(predictions(:,2:5),anchors,anchorBoxMask);
predictions(:,2:5) = applyAnchorBoxOffsets(tiledAnchors, predictions(:,2:5), networkInputSize);

numMiniBatch = size(XTest, 4);

bboxes = cell(numMiniBatch, 1);
scores = cell(numMiniBatch, 1);
labels = cell(numMiniBatch, 1);

for ii = 1:numMiniBatch
    fmap = cellfun(@(x) x(:,:,:,ii), predictions, 'UniformOutput', false);
    [bboxes{ii}, scores{ii}, labels{ii}] = ...
        generateYOLOv3Detections(fmap, confidenceThreshold, overlapThreshold, imageSize, classes);
end

end

function YPredCell = yolov3Predict(net,XTrain,networkOutputs,anchorBoxMask)
% Predict the output of network and extract the confidence, x, y,
% width, height, and class.
YPredictions = cell(size(networkOutputs));
[YPredictions{:}] = predict(net, XTrain);
YPredCell = extractPredictions(YPredictions, anchorBoxMask);

% Apply activation to the predicted cell array.
YPredCell = applyActivations(YPredCell);
end

ユーティリティ関数

function YPredCell = applyActivations(YPredCell)
YPredCell(:,1:3) = cellfun(@ sigmoid, YPredCell(:,1:3), 'UniformOutput', false);
YPredCell(:,4:5) = cellfun(@ exp, YPredCell(:,4:5), 'UniformOutput', false);    
YPredCell(:,6) = cellfun(@ sigmoid, YPredCell(:,6), 'UniformOutput', false);
end

function predictions = extractPredictions(YPredictions, anchorBoxMask)
predictions = cell(size(YPredictions, 1),6);
for ii = 1:size(YPredictions, 1)
    % Get the required info on feature size.
    numChannelsPred = size(YPredictions{ii},3);
    numAnchors = size(anchorBoxMask{ii},2);
    numPredElemsPerAnchors = numChannelsPred/numAnchors;
    allIds = (1:numChannelsPred);
    
    stride = numPredElemsPerAnchors;
    endIdx = numChannelsPred;

    % X positions.
    startIdx = 1;
    predictions{ii,2} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    xIds = startIdx:stride:endIdx;
    
    % Y positions.
    startIdx = 2;
    predictions{ii,3} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    yIds = startIdx:stride:endIdx;
    
    % Width.
    startIdx = 3;
    predictions{ii,4} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    wIds = startIdx:stride:endIdx;
    
    % Height.
    startIdx = 4;
    predictions{ii,5} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    hIds = startIdx:stride:endIdx;
    
    % Confidence scores.
    startIdx = 5;
    predictions{ii,1} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
    confIds = startIdx:stride:endIdx;
    
    % Accummulate all the non-class indexes
    nonClassIds = [xIds yIds wIds hIds confIds];
    
    % Class probabilities.
    % Get the indexes which do not belong to the nonClassIds
    classIdx = setdiff(allIds,nonClassIds);
    predictions{ii,6} = YPredictions{ii}(:,:,classIdx,:);
end
end

function tiledAnchors = generateTiledAnchors(YPredCell,anchorBoxes,anchorBoxMask)
% Generate tiled anchor offset.
tiledAnchors = cell(size(YPredCell));
for i=1:size(YPredCell,1)
    anchors = anchorBoxes(anchorBoxMask{i}, :);
    [h,w,~,n] = size(YPredCell{i,1});
    [tiledAnchors{i,2}, tiledAnchors{i,1}] = ndgrid(0:h-1,0:w-1,1:size(anchors,1),1:n);
    [~,~,tiledAnchors{i,3}] = ndgrid(0:h-1,0:w-1,anchors(:,2),1:n);
    [~,~,tiledAnchors{i,4}] = ndgrid(0:h-1,0:w-1,anchors(:,1),1:n);
end
end

function tiledAnchors = applyAnchorBoxOffsets(tiledAnchors,YPredCell,inputImageSize)
% Convert grid cell coordinates to box coordinates.
for i=1:size(YPredCell,1)
    [h,w,~,~] = size(YPredCell{i,1});  
    tiledAnchors{i,1} = (tiledAnchors{i,1}+YPredCell{i,1})./w;
    tiledAnchors{i,2} = (tiledAnchors{i,2}+YPredCell{i,2})./h;
    tiledAnchors{i,3} = (tiledAnchors{i,3}.*YPredCell{i,3})./inputImageSize(2);
    tiledAnchors{i,4} = (tiledAnchors{i,4}.*YPredCell{i,4})./inputImageSize(1);
end
end

function [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(f)
% Create the subplots to display the loss and learning rate.
figure(f);
clf
subplot(2,1,1);
ylabel('Learning Rate');
xlabel('Iteration');
learningRatePlotter = animatedline;
subplot(2,1,2);
ylabel('Total Loss');
xlabel('Iteration');
lossPlotter = animatedline;
end

function displayLossInfo(epoch, iteration, currentLR, lossInfo)
% Display loss information for each iteration.
disp("Epoch : " + epoch + " | Iteration : " + iteration + " | Learning Rate : " + currentLR + ...
   " | Total Loss : " + double(gather(extractdata(lossInfo.totalLoss))) + ...
   " | Box Loss : " + double(gather(extractdata(lossInfo.boxLoss))) + ...
   " | Object Loss : " + double(gather(extractdata(lossInfo.objLoss))) + ...
   " | Class Loss : " + double(gather(extractdata(lossInfo.clsLoss))));
end

function updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, totalLoss)
% Update loss and learning rate plots.
addpoints(lossPlotter, iteration, double(extractdata(gather(totalLoss))));
addpoints(learningRatePlotter, iteration, currentLR);
drawnow
end

function net = downloadPretrainedYOLOv3Detector()
% Download a pretrained yolov3 detector.
if ~exist('yolov3SqueezeNetVehicleExample_20b.mat', 'file')
    if ~exist('yolov3SqueezeNetVehicleExample_20b.zip', 'file')
        disp('Downloading pretrained detector (8.9 MB)...');
        pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/yolov3SqueezeNetVehicleExample_20b.zip';
        websave('yolov3SqueezeNetVehicleExample_20b.zip', pretrainedURL);
    end
    unzip('yolov3SqueezeNetVehicleExample_20b.zip');
end
pretrained = load("yolov3SqueezeNetVehicleExample_20b.mat");
net = pretrained.net;
end

参考文献

1.Redmon, Joseph, and Ali Farhadi. "YOLOv3: An Incremental Improvement." Preprint, submitted April 8, 2018. https://arxiv.org/abs/1804.02767.