メインコンテンツ

車両検出のための YOLO v2 ネットワークの学習

車両検出用の学習データをワークスペースに読み込みます。

data = load("vehicleTrainingData.mat");
trainingData = data.vehicleTrainingData;

学習サンプルが保存されているディレクトリを指定します。ファイル名の絶対パスを学習データに追加します。

dataDir = fullfile(toolboxdir("vision"),"visiondata");
trainingData.imageFilename = fullfile(dataDir,trainingData.imageFilename);

学習のためにデータをランダムにシャッフルします。

rng(0)
shuffledIdx = randperm(height(trainingData));
trainingData = trainingData(shuffledIdx,:);

table のファイルを使用して imageDatastore を作成します。

imds = imageDatastore(trainingData.imageFilename);

table のラベル列を使用して boxLabelDatastore を作成します。

blds = boxLabelDatastore(trainingData(:,2:end));

データストアを統合します。

ds = combine(imds,blds);

table のラベル列を使用してクラス名を指定します。

classes = trainingData.Properties.VariableNames(2:end);

アンカー ボックスを指定します。

anchorBoxes = [8 8; 32 48; 40 24; 72 48];

事前に初期化された YOLO v2 オブジェクト検出ネットワークを読み込みます。

load("yolov2VehicleDetectorNet.mat","net");

YOLO v2 オブジェクト検出ネットワークを作成します。

detector = yolov2ObjectDetector(net,classes,anchorBoxes)
detector = 
  yolov2ObjectDetector with properties:

                  Network: [1×1 dlnetwork]
                InputSize: [128 128 3]
        TrainingImageSize: [128 128]
              AnchorBoxes: [4×2 double]
               ClassNames: vehicle
    ReorganizeLayerSource: ''
              LossFactors: [5 1 1 1]
                ModelName: ''

ネットワーク学習オプションを構成します。

options = trainingOptions("sgdm", ...
    InitialLearnRate=0.001, ...
    Verbose=true, ...
    MiniBatchSize=16, ...
    MaxEpochs=30, ...
    Shuffle="never", ...
    VerboseFrequency=30, ...
    CheckpointPath=tempdir);

YOLO v2 ネットワークに学習させます。

[trainedDetector,info] = trainYOLOv2ObjectDetector(ds,detector,options);
*************************************************************************
Training a YOLO v2 Object Detector for the following object classes:

* vehicle

Training on single CPU.
|========================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |     RMSE     |     Loss     |      Rate       |
|========================================================================================|
|       1 |           1 |       00:00:01 |         7.13 |         50.8 |          0.0010 |
|       2 |          30 |       00:00:15 |         1.32 |          1.8 |          0.0010 |
|       4 |          60 |       00:00:31 |         0.93 |          0.9 |          0.0010 |
|       5 |          90 |       00:00:45 |         0.64 |          0.4 |          0.0010 |
|       7 |         120 |       00:01:00 |         0.58 |          0.3 |          0.0010 |
|       9 |         150 |       00:01:14 |         0.64 |          0.4 |          0.0010 |
|      10 |         180 |       00:01:27 |         0.46 |          0.2 |          0.0010 |
|      12 |         210 |       00:01:41 |         0.40 |          0.2 |          0.0010 |
|      14 |         240 |       00:01:54 |         0.58 |          0.3 |          0.0010 |
|      15 |         270 |       00:02:08 |         0.40 |          0.2 |          0.0010 |
|      17 |         300 |       00:02:21 |         0.37 |          0.1 |          0.0010 |
|      19 |         330 |       00:02:35 |         0.50 |          0.2 |          0.0010 |
|      20 |         360 |       00:02:47 |         0.37 |          0.1 |          0.0010 |
|      22 |         390 |       00:03:00 |         0.36 |          0.1 |          0.0010 |
|      24 |         420 |       00:03:14 |         0.43 |          0.2 |          0.0010 |
|      25 |         450 |       00:03:26 |         0.54 |          0.3 |          0.0010 |
|      27 |         480 |       00:03:39 |         0.54 |          0.3 |          0.0010 |
|      29 |         510 |       00:03:52 |         0.66 |          0.4 |          0.0010 |
|      30 |         540 |       00:04:05 |         0.38 |          0.1 |          0.0010 |
|========================================================================================|
Training finished: Max epochs completed.
Detector training complete.
*************************************************************************

反復ごとの学習損失を調べ、学習精度を確認します。

figure
plot(info.TrainingLoss)
grid on
xlabel("Number of Iterations")
ylabel("Training Loss for Each Iteration")

Figure contains an axes object. The axes object with xlabel Number of Iterations, ylabel Training Loss for Each Iteration contains an object of type line.

テスト イメージをワークスペースに読み取ります。

img = imread("detectcars.png");

学習済みの YOLO v2 オブジェクト検出器をテスト イメージに対して実行し、車両検出を行います。

[bboxes,scores] = detect(trainedDetector,img);

検出結果を表示します。

if(~isempty(bboxes))
    img = insertObjectAnnotation(img,"rectangle",bboxes,scores);
end
figure
imshow(img)

Figure contains an axes object. The hidden axes object contains an object of type image.