このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
YOLO v4 深層学習を使用したオブジェクトの検出
この例では、You Only Look Once version 4 (YOLO v4) 深層学習ネットワークを使用して、イメージ内のオブジェクトを検出する方法を説明します。この例では、次の作業を行います。
YOLO v4 オブジェクト検出ネットワークの学習、検証、テスト用のデータセットを構成。また、ネットワーク効率を向上させるため、学習データセットのデータ拡張を実行。
YOLO v4 オブジェクト検出ネットワークの学習で使用するため、学習データからアンカー ボックスを計算。
関数
yolov4ObjectDetector
を使用して YOLO v4 オブジェクト検出器を作成し、関数trainYOLOv4ObjectDetector
を使用して検出器の学習を実行。
この例では、イメージに含まれる車両を検出するための事前学習済みの YOLO v4 オブジェクト検出器についても説明します。この事前学習済みのネットワークは、バックボーン ネットワークとして tiny-yolov4-coco を使用し、車両データセットで学習を行っています。YOLO v4 オブジェクト検出ネットワークの詳細については、YOLO v4 入門 (Computer Vision Toolbox)を参照してください。
データセットの読み込み
この例では、295 個のイメージを含んだ小さな車両データセットを使用します。これらのイメージの多くは、Caltech の Cars 1999 データセットおよび Cars 2001 データセットからのものです (Caltech Computational Vision の Web サイトで入手可能)。Pietro Perona 氏によって作成されたもので、許可を得て使用しています。各イメージには、1 個または 2 個のラベル付けされた車両インスタンスが含まれています。小さなデータセットは YOLO v4 の学習手順を調べるうえで役立ちますが、実際にロバストな検出器に学習させるにはラベル付けされたイメージがより多く必要になります。
車両のイメージを解凍し、車両のグラウンド トゥルース データを読み込みます。
unzip vehicleDatasetImages.zip data = load("vehicleDatasetGroundTruth.mat"); vehicleDataset = data.vehicleDataset;
車両データは 2 列の table に保存されています。1 列目にはイメージ ファイルのパスが含まれ、2 列目には境界ボックスが含まれています。
データ セットの最初の数行を表示します。
vehicleDataset(1:4,:)
ans=4×2 table
imageFilename vehicle
_________________________________ _________________
{'vehicleImages/image_00001.jpg'} {[220 136 35 28]}
{'vehicleImages/image_00002.jpg'} {[175 126 61 45]}
{'vehicleImages/image_00003.jpg'} {[108 120 45 33]}
{'vehicleImages/image_00004.jpg'} {[124 112 38 36]}
ローカルの車両データ フォルダーへの絶対パスを追加します。
vehicleDataset.imageFilename = fullfile(pwd,vehicleDataset.imageFilename);
データセットは学習、検証、テスト用のセットに分割します。データの 60% を学習用に、10% を検証用に、残りを学習済みの検出器のテスト用に選択します。
rng("default");
shuffledIndices = randperm(height(vehicleDataset));
idx = floor(0.6 * length(shuffledIndices) );
trainingIdx = 1:idx;
trainingDataTbl = vehicleDataset(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = vehicleDataset(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = vehicleDataset(shuffledIndices(testIdx),:);
imageDatastore
および boxLabelDatastore
を使用して、学習および評価中にイメージとラベル データを読み込むデータストアを作成します。
imdsTrain = imageDatastore(trainingDataTbl{:,"imageFilename"}); bldsTrain = boxLabelDatastore(trainingDataTbl(:,"vehicle")); imdsValidation = imageDatastore(validationDataTbl{:,"imageFilename"}); bldsValidation = boxLabelDatastore(validationDataTbl(:,"vehicle")); imdsTest = imageDatastore(testDataTbl{:,"imageFilename"}); bldsTest = boxLabelDatastore(testDataTbl(:,"vehicle"));
イメージ データストアとボックス ラベル データストアを組み合わせます。
trainingData = combine(imdsTrain,bldsTrain); validationData = combine(imdsValidation,bldsValidation); testData = combine(imdsTest,bldsTest);
データ セットに以下のいずれかが含まれる場合、validateInputData
を使用して、無効なイメージ、境界ボックス、またはラベルを検出します。
無効なイメージ形式または NaN 値を含むサンプル
ゼロ/NaN 値/Inf 値を含むか、空である境界ボックス
欠損ラベルまたは非カテゴリカル ラベル
境界ボックスの値は、有限の正の整数でなければならず、かつ NaN であってはなりません。境界ボックスの高さと幅の値は、正でなければならず、イメージ境界の内側に収まっていなければなりません。
validateInputData(trainingData); validateInputData(validationData); validateInputData(testData);
学習イメージとボックス ラベルのうちの 1 つを表示します。
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,"Rectangle",bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
reset(trainingData);
YOLO v4 オブジェクト検出器ネットワークの作成
学習に使用するネットワーク入力サイズを指定します。
inputSize = [416 416 3];
検出するオブジェクト クラスの名前を指定します。
className = "vehicle";
関数estimateAnchorBoxes
(Computer Vision Toolbox)を使用して、学習データ内のオブジェクトのサイズに基づいてアンカー ボックスを推定します。学習前のイメージのサイズ変更を考慮するには、アンカー ボックスを推定する学習データのサイズを変更します。関数 transform
を使用して学習データの前処理を行い、アンカー ボックスの数を定義してアンカー ボックスを推定します。補助関数 preprocessData
を使用して、ネットワークの入力サイズに合わせて学習データのサイズを変更します。
rng("default")
trainingDataForEstimation = transform(trainingData,@(data)preprocessData(data,inputSize));
numAnchors = 6;
[anchors,meanIoU] = estimateAnchorBoxes(trainingDataForEstimation,numAnchors);
すべての検出ヘッドで使用するアンカー ボックスとして、引数 anchorBoxes
を指定します。アンカー ボックスは [
M
x 1]
の cell 配列として指定します。ここで、M は検出ヘッドの数を表します。各検出ヘッドは、引数 anchors
に格納される [
N
x 2]
の行列で構成されます。ここで、N
は使用するアンカーの数です。特徴マップのサイズに基づいて、各検出ヘッドの anchorBoxes
を指定します。スケールが小さい場合は大きいアンカーを使用し、スケールが大きい場合は小さいアンカーを使用します。これを行うには、アンカーを面積ごとに降順に並べ替え、最初の 3 つを最初の検出ヘッドに割り当て、最後の 3 つを 2 番目の検出ヘッドに割り当てます。
area = anchors(:, 1).*anchors(:,2);
[~,idx] = sort(area,"descend");
anchors = anchors(idx,:);
anchorBoxes = {anchors(1:3,:)
anchors(4:6,:)};
アンカー ボックスの選択の詳細については、学習データからのアンカー ボックスの推定 (Computer Vision Toolbox) (Computer Vision Toolbox™) およびアンカー ボックスによるオブジェクトの検出 (Computer Vision Toolbox)を参照してください。
関数 yolov4ObjectDetector
を使用して、YOLO v4 オブジェクト検出器を作成します。COCO データセットで学習させた事前学習済みの YOLO v4 検出ネットワークの名前を指定します。クラス名および推定されるアンカー ボックスを指定します。
detector = yolov4ObjectDetector("tiny-yolov4-coco",className,anchorBoxes,InputSize=inputSize);
データ拡張の実行
学習精度を上げるため、データ拡張を実行します。関数 transform
を使用して、カスタムのデータ拡張を学習データに適用します。補助関数 augmentData
によって、入力データに以下の拡張が適用されます。
HSV 空間でのカラー ジッターの付加
水平方向のランダムな反転
10% のランダムなスケーリング
データ拡張は、テスト データと検証データには適用されないことに注意してください。理想的には、テスト データと検証データは元のデータを代表するもので、バイアスのない評価を行うために変更なしで使用されなければなりません。
augmentedTrainingData = transform(trainingData,@augmentData);
拡張された学習データのサンプルを読み取って表示します。
augmentedData = cell(4,1); for k = 1:4 data = read(augmentedTrainingData); augmentedData{k} = insertShape(data{1},"rectangle",data{2}); reset(augmentedTrainingData); end figure montage(augmentedData,BorderSize=10)
学習オプションの指定
trainingOptions
を使用してネットワーク学習オプションを指定します。Adam ソルバーを使用して、一定の学習率 0.001 でオブジェクト検出器を 80 エポック学習させます。検証損失が最小となる学習済み検出器を得るには、OutputNetwork
を "best-validation-loss"
に設定します。ValidationData
を検証データに設定し、ValidationFrequency
を 1000 に設定します。データをより頻繁に検証するには、ValidationFrequency
を減らして学習時間を増やします。ExecutionEnvironment
を使用して、ネットワークの学習に使用するハードウェア リソースを決定します。ExecutionEnvironment
の既定値は "auto"
で、利用可能な場合は GPU が選択され、そうでない場合は CPU が選択されます。学習プロセス中に部分的に学習させた検出器を保存できるように、CheckpointPath
を一時的な場所に設定します。停電やシステム障害などで学習が中断された場合に、保存したチェックポイントから学習を再開できます。
options = trainingOptions("adam", ... GradientDecayFactor=0.9, ... SquaredGradientDecayFactor=0.999, ... InitialLearnRate=0.001, ... LearnRateSchedule="none", ... MiniBatchSize=4, ... L2Regularization=0.0005, ... MaxEpochs=80, ... DispatchInBackground=true, ... ResetInputNormalization=true, ... Shuffle="every-epoch", ... VerboseFrequency=20, ... ValidationFrequency=1000, ... CheckpointPath=tempdir, ... ValidationData=validationData, ... OutputNetwork="best-validation-loss");
YOLO v4 オブジェクト検出器の学習
関数 trainYOLOv4ObjectDetector
を使用して YOLO v4 オブジェクト検出器に学習させます。この例は、24 GB メモリ搭載の NVIDIA™ RTX A5000 で実行されます。この設定を使用してこのネットワークに学習させるのに約 33 分かかりました。学習時間は使用するハードウェアによって異なります。ネットワークに学習させる代わりに、Computer Vision Toolbox™ に用意されている事前学習済みの YOLO v4 オブジェクト検出器を使用することもできます。
補助関数 downloadPretrainedYOLOv4Detector
を使用して、事前学習済みの検出器をダウンロードします。拡張した学習データで検出器に学習させるには、doTraining
の値を true
に設定します。
doTraining = false; if doTraining % Train the YOLO v4 detector. [detector,info] = trainYOLOv4ObjectDetector(augmentedTrainingData,detector,options); else % Load pretrained detector for the example. detector = downloadPretrainedYOLOv4Detector(); end
Downloading pretrained detector...
テスト イメージに対して検出器を実行します。
I = imread("highway.png");
[bboxes,scores,labels] = detect(detector,I);
結果を表示します。
I = insertObjectAnnotation(I,"rectangle",bboxes,scores);
figure
imshow(I)
テスト セットを使用した検出器の評価
大規模なイメージ セットで学習済みのオブジェクト検出器を評価し、パフォーマンスを測定します。Computer Vision Toolbox™ には、平均適合率や対数平均ミス率などの一般的なメトリクスを測定するためのオブジェクト検出器評価関数 (evaluateObjectDetection
(Computer Vision Toolbox)) が用意されています。この例では、平均適合率メトリクスを使用してパフォーマンスを評価します。平均適合率は、検出器が正しい分類を実行できること (適合率) と検出器がすべての関連オブジェクトを検出できること (再現率) を示す単一の数値です。
すべてのテスト イメージに対して検出器を実行します。できるだけ多くのオブジェクトを検出するには、検出しきい値を低い値に設定します。これは、検出器の適合率を、再現率の値の全範囲にわたって評価するのに役立ちます。
detectionResults = detect(detector,testData,Threshold=0.01);
平均適合率メトリクスを使用してオブジェクト検出器を評価します。
metrics = evaluateObjectDetection(detectionResults,testData); classID = 1; precision = metrics.ClassMetrics.Precision{classID}; recall = metrics.ClassMetrics.Recall{classID};
適合率/再現率 (PR) の曲線は、さまざまなレベルの再現率における検出器の適合率を示しています。すべてのレベルの再現率で適合率が 1 になるのが理想的です。より多くのデータを使用すると平均適合率を向上できますが、学習に必要な時間が長くなる場合があります。PR 曲線をプロットします。
figure plot(recall,precision) xlabel("Recall") ylabel("Precision") grid on title(sprintf("Average Precision = %.2f",metrics.ClassMetrics.mAP(classID)))
サポート関数
データ拡張を実行するための補助関数。
function data = augmentData(A) % Apply random horizontal flipping, and random X/Y scaling. Boxes that get % scaled outside the bounds are clipped if the overlap is above 0.25. Also, % jitter image color. data = cell(size(A)); for ii = 1:size(A,1) I = A{ii,1}; bboxes = A{ii,2}; labels = A{ii,3}; sz = size(I); if numel(sz) == 3 && sz(3) == 3 I = jitterColorHSV(I,... contrast=0.0,... Hue=0.1,... Saturation=0.2,... Brightness=0.2); end % Randomly flip image. tform = randomAffine2d(XReflection=true,Scale=[1 1.1]); rout = affineOutputView(sz,tform,BoundsStyle="centerOutput"); I = imwarp(I,tform,OutputView=rout); % Apply same transform to boxes. [bboxes,indices] = bboxwarp(bboxes,tform,rout,OverlapThreshold=0.25); labels = labels(indices); % Return original data only when all boxes are removed by warping. if isempty(indices) data(ii,:) = A(ii,:); else data(ii,:) = {I,bboxes,labels}; end end end function data = preprocessData(data,targetSize) % Resize the images and scale the pixels to between 0 and 1. Also scale the % corresponding bounding boxes. for ii = 1:size(data,1) I = data{ii,1}; imgSize = size(I); bboxes = data{ii,2}; I = im2single(imresize(I,targetSize(1:2))); scale = targetSize(1:2)./imgSize(1:2); bboxes = bboxresize(bboxes,scale); data(ii,1:2) = {I,bboxes}; end end
事前学習済みの YOLO v4 オブジェクト検出器をダウンロードするための補助関数。
function detector = downloadPretrainedYOLOv4Detector() % Download a pretrained yolov4 detector. if ~exist("yolov4TinyVehicleExample_24a.mat", "file") if ~exist("yolov4TinyVehicleExample_24a.zip", "file") disp("Downloading pretrained detector..."); pretrainedURL = "https://ssd.mathworks.com/supportfiles/vision/data/yolov4TinyVehicleExample_24a.zip"; websave("yolov4TinyVehicleExample_24a.zip", pretrainedURL); end unzip("yolov4TinyVehicleExample_24a.zip"); end pretrained = load("yolov4TinyVehicleExample_24a.mat"); detector = pretrained.detector; end
参考
アプリ
関数
trainYOLOv4ObjectDetector
(Computer Vision Toolbox) |estimateAnchorBoxes
(Computer Vision Toolbox) |analyzeNetwork
|combine
|transform
|read
|evaluateDetectionPrecision
(Computer Vision Toolbox)
オブジェクト
yolov4ObjectDetector
(Computer Vision Toolbox) |boxLabelDatastore
(Computer Vision Toolbox) |imageDatastore
|dlnetwork
|dlarray
関連するトピック
- アンカー ボックスによるオブジェクトの検出 (Computer Vision Toolbox)
- 学習データからのアンカー ボックスの推定 (Computer Vision Toolbox)
- ディープ ネットワーク デザイナーを使用した転移学習用のネットワークの準備
- 深層学習を使用したオブジェクト検出入門 (Computer Vision Toolbox)