Main Content

Fast R-CNN 一時停止標識検出器の学習

学習データを読み込みます。

data = load('rcnnStopSigns.mat', 'stopSigns', 'fastRCNNLayers');
stopSigns = data.stopSigns;
fastRCNNLayers = data.fastRCNNLayers;

イメージ ファイルに絶対パスを追加します。

stopSigns.imageFilename = fullfile(toolboxdir('vision'),'visiondata', ...
    stopSigns.imageFilename);

学習のためにデータをランダムにシャッフルします。

rng(0);
shuffledIdx = randperm(height(stopSigns));
stopSigns = stopSigns(shuffledIdx,:);

table のファイルを使用して imageDatastore を作成します。

imds = imageDatastore(stopSigns.imageFilename);

table のラベル列を使用して boxLabelDatastore を作成します。

blds = boxLabelDatastore(stopSigns(:,2:end));

データストアを統合します。

ds = combine(imds, blds);

一時停止標識の学習イメージのサイズが異なっています。データを前処理して、イメージとボックスのサイズを事前定義されたサイズに変更します。

ds = transform(ds,@(data)preprocessData(data,[920 968 3]));

ネットワーク学習オプションを設定します。

options = trainingOptions('sgdm', ...
    'MiniBatchSize', 10, ...
    'InitialLearnRate', 1e-3, ...
    'MaxEpochs', 10, ...
    'CheckpointPath', tempdir);

Fast R-CNN 検出器に学習させます。学習は、完了するのに 2 ~ 3 分かかることがあります。

frcnn = trainFastRCNNObjectDetector(ds, fastRCNNLayers , options, ...
    'NegativeOverlapRange', [0 0.1], ...
    'PositiveOverlapRange', [0.7 1]);
*******************************************************************
Training a Fast R-CNN Object Detector for the following object classes:

* stopSign

--> Extracting region proposals from training datastore...done.

Training on single GPU.
|=======================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Mini-batch  |  Mini-batch  |  Base Learning  |
|         |             |   (hh:mm:ss)   |     Loss     |   Accuracy   |     RMSE     |      Rate       |
|=======================================================================================================|
|       1 |           1 |       00:00:29 |       0.3787 |       93.59% |         0.96 |          0.0010 |
|      10 |          10 |       00:05:14 |       0.3032 |       98.52% |         0.95 |          0.0010 |
|=======================================================================================================|

Detector training complete.
*******************************************************************

Fast R-CNN 検出器をテスト イメージでテストします。

img = imread('stopSignTest.jpg');

検出器を実行します。

[bbox, score, label] = detect(frcnn, img);

検出結果を表示します。

detectedImg = insertObjectAnnotation(img,'rectangle',bbox,score);
figure
imshow(detectedImg)

サポート関数

function data = preprocessData(data,targetSize)
% Resize image and bounding boxes to the targetSize.
scale = targetSize(1:2)./size(data{1},[1 2]);
data{1} = imresize(data{1},targetSize(1:2));
bboxes = round(data{2});
data{2} = bboxresize(bboxes,scale);
end