Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

YOLO v3 深層学習を使用したオブジェクトの検出

この例では、YOLO v3 オブジェクト検出器に学習させる方法を説明します。

深層学習は、ロバストなオブジェクト検出器に学習させるために使用できる強力な機械学習手法です。オブジェクト検出の方法は、Faster R-CNN、You Only Look Once (YOLO) v2、シングル ショット検出器 (SSD) など複数あります。この例では、YOLO v3 オブジェクト検出器に学習させる方法を説明します。YOLO v3 は YOLO v2 を改良したもので、複数のスケールにおける検出を追加してより小さなオブジェクトを検出できるようになっています。学習で使用される損失関数は、境界ボックス回帰用の平均二乗誤差と、オブジェクト分類用のバイナリ交差エントロピーに分割されており、検出精度が向上しています。

メモ: この例では、Computer Vision Toolbox™ Model for YOLO v3 Object Detection が必要です。Computer Vision Toolbox Model for YOLO v3 Object Detection はアドオン エクスプローラーからインストールできます。アドオンのインストールの詳細については、アドオンの取得と管理を参照してください。

事前学習済みのネットワークのダウンロード

学習の完了を待たなくて済むように、補助関数 downloadPretrainedYOLOv3Detector を使用して、事前学習済みのネットワークをダウンロードします。ネットワークに学習させる場合は、変数 doTrainingtrue に設定します。

doTraining = false;

if ~doTraining
    preTrainedDetector = downloadPretrainedYOLOv3Detector();    
end

データの読み込み

この例では、295 枚のイメージを含んだ小さなラベル付きデータセットを使用します。これらのイメージの多くは、Caltech の Cars 1999 データ セットおよび Cars 2001 データ セットからのものです (Caltech Computational Vision の Web サイトで入手可能)。Pietro Perona 氏によって作成されたもので、許可を得て使用しています。各イメージには、1 または 2 個のラベル付けされた車両インスタンスが含まれています。小さなデータセットは YOLO v3 の学習手順を調べるうえで役立ちますが、実際にロバストなネットワークに学習させるにはラベル付けされたイメージがより多く必要になります。

車両のイメージを解凍し、車両のグラウンド トゥルース データを読み込みます。

unzip vehicleDatasetImages.zip
data = load('vehicleDatasetGroundTruth.mat');
vehicleDataset = data.vehicleDataset;

% Add the full path to the local vehicle data folder.
vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);

メモ: 複数のクラスの場合、データは 3 列に整理することもできます。最初の列には、パスを含むイメージ ファイル名を格納し、2 列目には境界ボックスを格納します。3 列目は、各境界ボックスに対応するラベル名を含む cell ベクトルでなければなりません。境界ボックスとラベルの調整方法の詳細については、boxLabelDatastoreを参照してください。

すべての境界ボックスは、[x y width height] の形式でなければなりません。このベクトルは、境界ボックスの左上隅とサイズをピクセル単位で指定します。

データセットを、ネットワークに学習させるための学習セットとネットワークを評価するためのテスト セットに分割します。データの 60% を学習セットに使用し、残りをテスト セットに使用します。

rng(0);
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices));
trainingDataTbl = vehicleDataset(shuffledIndices(1:idx), :);
testDataTbl = vehicleDataset(shuffledIndices(idx+1:end), :);

イメージを読み込むためのイメージ データストアを作成します。

imdsTrain = imageDatastore(trainingDataTbl.imageFilename);
imdsTest = imageDatastore(testDataTbl.imageFilename);

グラウンド トゥルース境界ボックス用のデータストアを作成します。

bldsTrain = boxLabelDatastore(trainingDataTbl(:, 2:end));
bldsTest = boxLabelDatastore(testDataTbl(:, 2:end));

イメージ データストアとボックス ラベル データストアを組み合わせます。

trainingData = combine(imdsTrain, bldsTrain);
testData = combine(imdsTest, bldsTest);

validateInputData を使用して、次のような無効なイメージ、境界ボックス、またはラベルを検出します。

  • 無効なイメージ形式であるか、NaN を含むサンプル

  • ゼロ/NaN/Inf を含むか、空である境界ボックス

  • 欠損ラベル/非カテゴリカル ラベル。

境界ボックスの値は、有限、正、整数で、NaN 以外でなければなりません。また、正の高さと幅をもつイメージ境界内に収まらなくてはなりません。無効なサンプルは、破棄するか、適切な学習のために修正しなければなりません。

validateInputData(trainingData);
validateInputData(testData);

データ拡張

データ拡張は、学習中に元のデータをランダムに変換してネットワークの精度を高めるために使用されます。データ拡張を使用すると、ラベル付き学習サンプルの数を実際に増やさずに、学習データをさらに多様化させることができます。

関数 transform を使用して、カスタムのデータ拡張を学習データに適用します。この例の最後にリストされている補助関数 augmentData によって、入力データに以下の拡張が適用されます。

  • HSV 空間でのカラー ジッターの付加

  • 水平方向のランダムな反転

  • 10% のランダムなスケーリング

augmentedTrainingData = transform(trainingData, @augmentData);

同じイメージを 4 回読み取り、拡張された学習データを表示します。

% Visualize the augmented images.
augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1,1}, 'Rectangle', data{1,2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData, 'BorderSize', 10)

YOLO v3 オブジェクト検出器の定義

この例の YOLO v3 検出器は、SqueezeNet がベースとなっています。この検出器は、SqueezeNet の特徴抽出ネットワークを使用し、最後に 2 つの検出ヘッドが追加されています。2 番目の検出ヘッドのサイズは、最初の検出ヘッドの 2 倍となっているため、小さなオブジェクトをより的確に検出できます。検出したいオブジェクトのサイズに基づいて、サイズが異なる検出ヘッドを任意の数だけ指定できます。YOLO v3 検出器は、学習データを使用して推定されたアンカー ボックスを使用します。これにより、データ セットの種類に対応した初期の事前確率が改善され、ボックスを正確に予測できるように検出器に学習させることができます。アンカー ボックスの詳細については、アンカー ボックスによるオブジェクトの検出を参照してください。

YOLO v3 検出器内に存在する YOLO v3 ネットワークを次の図に示します。

ディープ ネットワーク デザイナー (Deep Learning Toolbox)を使用して、この図に示されているネットワークを作成することができます。

ネットワーク入力サイズを指定します。ネットワーク入力サイズを選択する際には、ネットワーク自体の実行に必要な最小サイズ、学習イメージのサイズ、および選択したサイズでデータを処理することによって発生する計算コストを考慮します。可能な場合、学習イメージのサイズに近く、ネットワークに必要な入力サイズより大きいネットワーク入力サイズを選択します。この例の実行にかかる計算コストを削減するため、ネットワーク入力サイズを [227 227 3] に指定します。

networkInputSize = [227 227 3];

この例で使用されている学習イメージは、227 x 227 より大きく、サイズがまちまちであるため、まず、transform を使用して、アンカー ボックスを計算するための学習データを前処理します。アンカーの数と平均 IoU との良好なトレードオフを実現するため、アンカーの数を 6 に指定します。関数 estimateAnchorBoxes を使用してアンカー ボックスを推定します。アンカー ボックスの推定の詳細については、学習データからのアンカー ボックスの推定を参照してください。事前学習済みの YOLOv3 オブジェクト検出器を使用する場合、特定の学習データセットで計算されたアンカー ボックスを指定する必要があります。推定プロセスは確定的なものではないことに注意してください。推定されたアンカー ボックスが他のハイパーパラメーターの調整中に変化しないように、rng を使用して推定前に乱数シードを設定します。

rng(0)
trainingDataForEstimation = transform(trainingData, @(data)preprocessData(data, networkInputSize));
numAnchors = 6;
[anchors, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors)
anchors = 6×2

    41    34
   163   130
    98    93
   144   125
    33    24
    69    66

meanIoU = 0.8507

両方の検出ヘッドで使用する anchorBoxes を指定します。anchorBoxes は、M 行 1 列の cell 配列です。ここで、M は検出ヘッドの数を表します。各検出ヘッドは、anchors の N 行 2 列の行列で構成されます。ここで、N は使用するアンカーの数です。特徴マップのサイズに基づいて、各検出ヘッドの anchorBoxes を選択します。スケールが小さい場合は大きい anchors を使用し、スケールが大きい場合は小さい anchors を使用します。これを行うには、より大きいアンカー ボックスが先頭に来るように anchors を並べ替えてから、最初の 3 つを最初の検出ヘッドに割り当て、次の 3 つを 2 番目の検出ヘッドに割り当てます。

area = anchors(:, 1).*anchors(:, 2);
[~, idx] = sort(area, 'descend');
anchors = anchors(idx, :);
anchorBoxes = {anchors(1:3,:)
    anchors(4:6,:)
    };

Imagenet データ セットで事前学習された SqueezeNet ネットワークを読み込んでから、クラス名を指定します。COCO データ セットで学習させた tiny-yolov3-cocodarknet53-coco、Imagenet データ セットで学習させた MobileNet-v2 や ResNet-18 といった他の事前学習済みのネットワークを選択して読み込むこともできます。YOLO v3 は、事前学習済みのネットワークを使用すると、より優れたパフォーマンスを発揮し、より高速に学習させることができます。

baseNetwork = squeezenet;
classNames = trainingDataTbl.Properties.VariableNames(2:end);

次に、検出ネットワーク ソースを追加して、yolov3ObjectDetector オブジェクトを作成します。最適な検出ネットワーク ソースを選択するには、試行錯誤が必要です。analyzeNetwork を使用すると、ネットワーク内に存在する可能性がある検出ネットワーク ソースの名前を検索できます。この例では、層 fire9-concat と層 fire5-concatDetectionNetworkSource として使用します。

yolov3Detector = yolov3ObjectDetector(baseNetwork, classNames, anchorBoxes, 'DetectionNetworkSource', {'fire9-concat', 'fire5-concat'});

あるいは、SqueezeNet を使って上記で作成したネットワークの代わりに、MS-COCO などのより大規模なデータセットで学習させた、他の事前学習済みの YOLOv3 アーキテクチャを使用して、カスタム オブジェクト検出タスクで検出器の転移学習を行うこともできます。転移学習は、classNames と anchorBoxes を変更することによって実現できます。

学習データの前処理

拡張された学習データを前処理して学習用に準備します。yolov3ObjectDetectorpreprocessメソッドによって、入力データに以下の前処理演算が適用されます。

  • 縦横比を維持したまま、イメージのサイズをネットワークの入力サイズに変更します。

  • イメージのピクセルを [0 1] の範囲にスケーリングします。

preprocessedTrainingData = transform(augmentedTrainingData, @(data)preprocess(yolov3Detector, data));

前処理された学習データを読み取ります。

data = read(preprocessedTrainingData);

境界ボックスと共にイメージを表示します。

I = data{1,1};
bbox = data{1,2};
annotatedImage = insertShape(I, 'Rectangle', bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)

データストアをリセットします。

reset(preprocessedTrainingData);

学習オプションの指定

これらの学習オプションを指定します。

  • エポック数を 80 に設定します。

  • ミニ バッチ サイズを 8 に設定します。使用されるミニ バッチ サイズが大きい場合は、学習率を高めることで安定した学習が可能になります。.ただし、これは利用可能なメモリに応じて設定しなければなりません。

  • 学習率を 0.001 に設定します。

  • ウォームアップ期間を 1000 反復に設定します。このパラメーターは、式 learningRate×(iterationwarmupPeriod)4 に基づいて学習率を指数関数的に増やす反復処理の回数を表します。これは、学習率が高いときの勾配を安定させるのに役立ちます。

  • L2 正則化係数を 0.0005 に設定します。

  • ペナルティしきい値を 0.5 に設定します。グラウンド トゥルースとのオーバーラップが 0.5 未満の検出にペナルティが課されます。

  • 勾配の速度を [] に初期化します。これは、SGDM が勾配の速度を保存するのに使用されます。

numEpochs = 80;
miniBatchSize = 8;
learningRate = 0.001;
warmupPeriod = 1000;
l2Regularization = 0.0005;
penaltyThreshold = 0.5;
velocity = [];

モデルの学習

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。サポートされる Compute Capability の詳細については、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

関数 minibatchqueue を使用して、前処理された学習データをサポート関数 createBatchData でバッチに分割します。このサポート関数は、バッチ処理されたイメージ、およびそれぞれのクラス ID と組み合わせた境界ボックスを返します。学習用バッチ データの抽出を高速化するには、dispatchInBackground を "true" に設定して、並列プールの使用を確保する必要があります。

minibatchqueue は、GPU の可用性を自動的に検出します。GPU がない場合、または学習で GPU を使用しない場合は、OutputEnvironment パラメーターを "cpu" に設定します。

if canUseParallelPool
   dispatchInBackground = true;
else
   dispatchInBackground = false;
end

mbqTrain = minibatchqueue(preprocessedTrainingData, 2,...
        "MiniBatchSize", miniBatchSize,...
        "MiniBatchFcn", @(images, boxes, labels) createBatchData(images, boxes, labels, classNames), ...
        "MiniBatchFormat", ["SSCB", ""],...
        "DispatchInBackground", dispatchInBackground,...
        "OutputCast", ["", "double"]);

サポート関数 configureTrainingProgressPlotter を使用して、学習の進行状況のプロッターを作成し、カスタム学習ループを使用して検出器オブジェクトに学習させながら、プロットを確認します。

最後に、カスタム学習ループを指定します。それぞれの反復で次を行います。

  • minibatchqueue からデータを読み取ります。データがなくなった場合は、minibatchqueue をリセットしてシャッフルします。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配を評価します。サポート関数としてリストされている関数 modelGradients は、net の学習可能なパラメーターに関する損失勾配、対応するミニバッチの損失、および現在のバッチの状態を返します。

  • よりロバストな学習を実現するため、重み減衰係数を勾配に適用して正則化を実施します。

  • サポート関数 piecewiseLearningRateWithWarmup を使用し、反復に基づいて学習率を決定します。

  • 関数 sgdmupdate を使用して検出器パラメーターを更新します。

  • 移動平均を使用して、検出器の state パラメーターを更新します。

  • すべての反復について、学習率、合計損失、個々の損失 (ボックス損失、オブジェクト損失、クラス損失) を表示します。これらは、各反復においてそれぞれの損失がどのように変化しているかを解釈するために使用できます。たとえば、数回の反復後、ボックス損失にスパイクが突然発生していれば、予測に Inf または NaN があることを意味しています。

  • 学習の進行状況プロットを更新します。

数エポックにわたって損失が飽和する場合は、学習を終了することもできます。

if doTraining
    
    % Create subplots for the learning rate and mini-batch loss.
    fig = figure;
    [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(fig);

    iteration = 0;
    % Custom training loop.
    for epoch = 1:numEpochs
          
        reset(mbqTrain);
        shuffle(mbqTrain);
        
        while(hasdata(mbqTrain))
            iteration = iteration + 1;
           
            [XTrain, YTrain] = next(mbqTrain);
            
            % Evaluate the model gradients and loss using dlfeval and the
            % modelGradients function.
            [gradients, state, lossInfo] = dlfeval(@modelGradients, yolov3Detector, XTrain, YTrain, penaltyThreshold);
    
            % Apply L2 regularization.
            gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, yolov3Detector.Learnables);
    
            % Determine the current learning rate value.
            currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs);
    
            % Update the detector learnable parameters using the SGDM optimizer.
            [yolov3Detector.Learnables, velocity] = sgdmupdate(yolov3Detector.Learnables, gradients, velocity, currentLR);
    
            % Update the state parameters of dlnetwork.
            yolov3Detector.State = state;
              
            % Display progress.
            displayLossInfo(epoch, iteration, currentLR, lossInfo);  
                
            % Update training plot with new points.
            updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, lossInfo.totalLoss);
        end        
    end
else
    yolov3Detector = preTrainedDetector;
end

モデルの評価

Computer Vision Toolbox™ には、平均適合率 (evaluateDetectionPrecision) や対数平均ミス率 (evaluateDetectionMissRate) などの一般的なメトリクスを測定するオブジェクト検出器の評価関数が用意されています。この例では、平均適合率メトリクスを使用します。平均適合率は、検出器が正しい分類を実行できること (適合率) と検出器がすべての関連オブジェクトを検出できること (再現率) を示す単一の数値です。

results = detect(yolov3Detector,testData,'MiniBatchSize',8);

% Evaluate the object detector using Average Precision metric.
[ap,recall,precision] = evaluateDetectionPrecision(results,testData);

適合率/再現率 (PR) の曲線は、さまざまなレベルの再現率における検出器の適合率を示しています。すべてのレベルの再現率で適合率が 1 になるのが理想的です。

% Plot precision-recall curve.
figure
plot(recall,precision)
xlabel('Recall')
ylabel('Precision')
grid on
title(sprintf('Average Precision = %.2f', ap))

YOLO v3 を使用したオブジェクトの検出

検出器を使用してオブジェクトを検出します。

% Read the datastore.
data = read(testData);

% Get the image.
I = data{1};

[bboxes,scores,labels] = detect(yolov3Detector,I);

% Display the detections on image.
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);

figure
imshow(I)

サポート関数

モデル勾配関数

関数 modelGradients は、yolov3ObjectDetector オブジェクト、入力データ XTrain と対応するグラウンド トゥルース ボックス YTrain から成るミニバッチ、指定されたペナルティしきい値を入力引数として取り、yolov3ObjectDetector の学習可能パラメーターの損失勾配、対応するミニバッチの損失情報、および現在のバッチの状態を返します。

モデル勾配関数は、以下の演算を行って、合計損失と勾配を計算します。

  • forward メソッドを使用して、イメージの入力バッチから予測を生成します。

  • 後処理を行うため、CPU 上にある予測を収集します。

  • YOLO v3 のグリッド セルの座標から得られた予測を境界ボックスの座標に変換します。これにより、yolov3ObjectDetectoranchorBoxGenerator メソッドを使用してグラウンド トゥルース データと簡単に比較できるようになります。

  • 変換された予測とグラウンド トゥルース データを使用して、損失計算ターゲットを生成します。これらのターゲットは、境界ボックスの位置 (x、y、幅、高さ)、オブジェクトの信頼度、およびクラスの確率を対象として生成されます。サポート関数 generateTargets を参照してください。

  • 予測された境界ボックス座標とターゲット ボックスとの平均二乗誤差を計算します。サポート関数 bboxOffsetLoss を参照してください。

  • 予測されたオブジェクトの信頼度スコアとターゲット オブジェクトの信頼度スコアとのバイナリ交差エントロピーを決定します。サポート関数 objectnessLoss を参照してください。

  • 予測されたオブジェクトのクラスとターゲットとのバイナリ交差エントロピーを決定します。サポート関数 classConfidenceLoss を参照してください。

  • すべての損失を加算して合計損失を計算します。

  • 合計損失に対する学習可能パラメーターの勾配を計算します。

function [gradients, state, info] = modelGradients(detector, XTrain, YTrain, penaltyThreshold)
inputImageSize = size(XTrain,1:2);

% Gather the ground truths in the CPU for post processing
YTrain = gather(extractdata(YTrain));

% Extract the predictions from the detector.
[gatheredPredictions, YPredCell, state] = forward(detector, XTrain);

% Generate target for predictions from the ground truth data.
[boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = generateTargets(gatheredPredictions,...
    YTrain, inputImageSize, detector.AnchorBoxes, penaltyThreshold);

% Compute the loss.
boxLoss = bboxOffsetLoss(YPredCell(:,[2 3 7 8]),boxTarget,objectMaskTarget,boxErrorScale);
objLoss = objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
clsLoss = classConfidenceLoss(YPredCell(:,6),classTarget,objectMaskTarget);
totalLoss = boxLoss + objLoss + clsLoss;

info.boxLoss = boxLoss;
info.objLoss = objLoss;
info.clsLoss = clsLoss;
info.totalLoss = totalLoss;

% Compute gradients of learnables with regard to loss.
gradients = dlgradient(totalLoss, detector.Learnables);
end

function boxLoss = bboxOffsetLoss(boxPredCell, boxDeltaTarget, boxMaskTarget, boxErrorScaleTarget)
% Mean squared error for bounding box position.
lossX = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,1),boxDeltaTarget(:,1),boxMaskTarget(:,1),boxErrorScaleTarget));
lossY = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,2),boxDeltaTarget(:,2),boxMaskTarget(:,1),boxErrorScaleTarget));
lossW = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,3),boxDeltaTarget(:,3),boxMaskTarget(:,1),boxErrorScaleTarget));
lossH = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,4),boxDeltaTarget(:,4),boxMaskTarget(:,1),boxErrorScaleTarget));
boxLoss = lossX+lossY+lossW+lossH;
end

function objLoss = objectnessLoss(objectnessPredCell, objectnessDeltaTarget, boxMaskTarget)
% Binary cross-entropy loss for objectness score.
objLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),objectnessPredCell,objectnessDeltaTarget,boxMaskTarget(:,2)));
end

function clsLoss = classConfidenceLoss(classPredCell, classTarget, boxMaskTarget)
% Binary cross-entropy loss for class confidence score.
clsLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),classPredCell,classTarget,boxMaskTarget(:,3)));
end

拡張関数とデータ処理関数

function data = augmentData(A)
% Apply random horizontal flipping, and random X/Y scaling. Boxes that get
% scaled outside the bounds are clipped if the overlap is above 0.25. Also,
% jitter image color.

data = cell(size(A));
for ii = 1:size(A,1)
    I = A{ii,1};
    bboxes = A{ii,2};
    labels = A{ii,3};
    sz = size(I);

    if numel(sz) == 3 && sz(3) == 3
        I = jitterColorHSV(I,...
            'Contrast',0.0,...
            'Hue',0.1,...
            'Saturation',0.2,...
            'Brightness',0.2);
    end
    
    % Randomly flip image.
    tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
    rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
    I = imwarp(I,tform,'OutputView',rout);
    
    % Apply same transform to boxes.
    [bboxes,indices] = bboxwarp(bboxes,tform,rout,'OverlapThreshold',0.25);
    labels = labels(indices);
    
    % Return original data only when all boxes are removed by warping.
    if isempty(indices)
        data(ii,:) = A(ii,:);
    else
        data(ii,:) = {I, bboxes, labels};
    end
end
end


function data = preprocessData(data, targetSize)
% Resize the images and scale the pixels to between 0 and 1. Also scale the
% corresponding bounding boxes.

for ii = 1:size(data,1)
    I = data{ii,1};
    imgSize = size(I);
    
    % Convert an input image with single channel to 3 channels.
    if numel(imgSize) < 3 
        I = repmat(I,1,1,3);
    end
    bboxes = data{ii,2};

    I = im2single(imresize(I,targetSize(1:2)));
    scale = targetSize(1:2)./imgSize(1:2);
    bboxes = bboxresize(bboxes,scale);
    
    data(ii, 1:2) = {I, bboxes};
end
end

function [XTrain, YTrain] = createBatchData(data, groundTruthBoxes, groundTruthClasses, classNames)
% Returns images combined along the batch dimension in XTrain and
% normalized bounding boxes concatenated with classIDs in YTrain

% Concatenate images along the batch dimension.
XTrain = cat(4, data{:,1});

% Get class IDs from the class names.
classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
[~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);

% Append the label indexes and training image size to scaled bounding boxes
% and create a single cell array of responses.
combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
len = max( cellfun(@(x)size(x,1), combinedResponses ) );
paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
YTrain = cat(4, paddedBBoxes{:,1});
end

学習率スケジュール関数

function currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs)
% The piecewiseLearningRateWithWarmup function computes the current
% learning rate based on the iteration number.
persistent warmUpEpoch;

if iteration <= warmupPeriod
    % Increase the learning rate for number of iterations in warmup period.
    currentLR = learningRate * ((iteration/warmupPeriod)^4);
    warmUpEpoch = epoch;
elseif iteration >= warmupPeriod && epoch < warmUpEpoch+floor(0.6*(numEpochs-warmUpEpoch))
    % After warm up period, keep the learning rate constant if the remaining number of epochs is less than 60 percent. 
    currentLR = learningRate;
    
elseif epoch >= warmUpEpoch + floor(0.6*(numEpochs-warmUpEpoch)) && epoch < warmUpEpoch+floor(0.9*(numEpochs-warmUpEpoch))
    % If the remaining number of epochs is more than 60 percent but less
    % than 90 percent multiply the learning rate by 0.1.
    currentLR = learningRate*0.1;
    
else
    % If remaining epochs are more than 90 percent multiply the learning
    % rate by 0.01.
    currentLR = learningRate*0.01;
end

end

ユーティリティ関数

function [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(f)
% Create the subplots to display the loss and learning rate.
figure(f);
clf
subplot(2,1,1);
ylabel('Learning Rate');
xlabel('Iteration');
learningRatePlotter = animatedline;
subplot(2,1,2);
ylabel('Total Loss');
xlabel('Iteration');
lossPlotter = animatedline;
end

function displayLossInfo(epoch, iteration, currentLR, lossInfo)
% Display loss information for each iteration.
disp("Epoch : " + epoch + " | Iteration : " + iteration + " | Learning Rate : " + currentLR + ...
   " | Total Loss : " + double(gather(extractdata(lossInfo.totalLoss))) + ...
   " | Box Loss : " + double(gather(extractdata(lossInfo.boxLoss))) + ...
   " | Object Loss : " + double(gather(extractdata(lossInfo.objLoss))) + ...
   " | Class Loss : " + double(gather(extractdata(lossInfo.clsLoss))));
end

function updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, totalLoss)
% Update loss and learning rate plots.
addpoints(lossPlotter, iteration, double(extractdata(gather(totalLoss))));
addpoints(learningRatePlotter, iteration, currentLR);
drawnow
end

function detector = downloadPretrainedYOLOv3Detector()
% Download a pretrained yolov3 detector.
if ~exist('yolov3SqueezeNetVehicleExample_21aSPKG.mat', 'file')
    if ~exist('yolov3SqueezeNetVehicleExample_21aSPKG.zip', 'file')
        disp('Downloading pretrained detector...');
        pretrainedURL = 'https://ssd.mathworks.com/supportfiles/vision/data/yolov3SqueezeNetVehicleExample_21aSPKG.zip';
        websave('yolov3SqueezeNetVehicleExample_21aSPKG.zip', pretrainedURL);
    end
    unzip('yolov3SqueezeNetVehicleExample_21aSPKG.zip');
end
pretrained = load("yolov3SqueezeNetVehicleExample_21aSPKG.mat");
detector = pretrained.detector;
end

参考文献

[1] Redmon, Joseph, and Ali Farhadi. "YOLOv3: An Incremental Improvement." Preprint, submitted April 8, 2018. https://arxiv.org/abs/1804.02767.

参考

| | | | (Deep Learning Toolbox) | |

関連するトピック