Fast R-CNN 一時停止標識検出器の学習
学習データを読み込みます。
data = load('rcnnStopSigns.mat', 'stopSigns', 'fastRCNNLayers'); stopSigns = data.stopSigns; fastRCNNLayers = data.fastRCNNLayers;
イメージ ファイルに絶対パスを追加します。
stopSigns.imageFilename = fullfile(toolboxdir('vision'),'visiondata', ... stopSigns.imageFilename);
学習のためにデータをランダムにシャッフルします。
rng(0); shuffledIdx = randperm(height(stopSigns)); stopSigns = stopSigns(shuffledIdx,:);
table のファイルを使用して imageDatastore を作成します。
imds = imageDatastore(stopSigns.imageFilename);
table のラベル列を使用して boxLabelDatastore を作成します。
blds = boxLabelDatastore(stopSigns(:,2:end));
データストアを統合します。
ds = combine(imds, blds);
一時停止標識の学習イメージのサイズが異なっています。データを前処理して、イメージとボックスのサイズを事前定義されたサイズに変更します。
ds = transform(ds,@(data)preprocessData(data,[920 968 3]));
ネットワーク学習オプションを設定します。
options = trainingOptions('sgdm', ... 'MiniBatchSize', 10, ... 'InitialLearnRate', 1e-3, ... 'MaxEpochs', 10, ... 'CheckpointPath', tempdir);
Fast R-CNN 検出器に学習させます。学習は、完了するのに 2 ~ 3 分かかることがあります。
frcnn = trainFastRCNNObjectDetector(ds, fastRCNNLayers , options, ... 'NegativeOverlapRange', [0 0.1], ... 'PositiveOverlapRange', [0.7 1]);
******************************************************************* Training a Fast R-CNN Object Detector for the following object classes: * stopSign --> Extracting region proposals from training datastore...done. Training on single GPU. |=======================================================================================================| | Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Mini-batch | Base Learning | | | | (hh:mm:ss) | Loss | Accuracy | RMSE | Rate | |=======================================================================================================| | 1 | 1 | 00:00:29 | 0.3787 | 93.59% | 0.96 | 0.0010 | | 10 | 10 | 00:05:14 | 0.3032 | 98.52% | 0.95 | 0.0010 | |=======================================================================================================| Detector training complete. *******************************************************************
Fast R-CNN 検出器をテスト イメージでテストします。
img = imread('stopSignTest.jpg');
検出器を実行します。
[bbox, score, label] = detect(frcnn, img);
検出結果を表示します。
detectedImg = insertObjectAnnotation(img,'rectangle',bbox,score);
figure
imshow(detectedImg)
サポート関数
function data = preprocessData(data,targetSize) % Resize image and bounding boxes to the targetSize. scale = targetSize(1:2)./size(data{1},[1 2]); data{1} = imresize(data{1},targetSize(1:2)); bboxes = round(data{2}); data{2} = bboxresize(bboxes,scale); end