Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

evaluateDetectionPrecision

オブジェクト検出の適合率メトリクスの評価

説明

averagePrecision = evaluateDetectionPrecision(detectionResults,groundTruthData) は、groundTruthData と比較した detectionResults の平均適合率を返します。平均適合率を使用して、オブジェクト検出器のパフォーマンスを測定できます。マルチクラス検出器の場合、関数は averagePrecisiongroundTruthData で指定された順序で各オブジェクト クラスのスコアのベクトルとして返します。

[averagePrecision,recall,precision] = evaluateDetectionPrecision(___) は、前の構文の入力引数を使用して、適合率/再現率曲線をプロットするデータ点を返します。

[___] = evaluateDetectionPrecision(___,threshold) は、検出をグラウンド トゥルース ボックスに割り当てるためのオーバーラップしきい値を指定します。

すべて折りたたむ

この例では、事前学習済みの YOLO v2 オブジェクト検出器を評価する方法を説明します。

車両のグラウンド トゥルース データの読み込み

車両の学習データを含む table を読み込みます。1 列目には学習イメージが含まれ、残りの列にはラベル付き境界ボックスが含まれています。

data = load('vehicleTrainingData.mat');
trainingData = data.vehicleTrainingData;

ローカルの車両データ フォルダーへの絶対パスを追加します。

dataDir = fullfile(toolboxdir('vision'), 'visiondata');
trainingData.imageFilename = fullfile(dataDir, trainingData.imageFilename);

table のファイルを使用して imageDatastore を作成します。

imds = imageDatastore(trainingData.imageFilename);

table のラベル列を使用して boxLabelDatastore を作成します。

blds = boxLabelDatastore(trainingData(:,2:end));

検出用の YOLOv2 検出器の読み込み

学習用の layerGraph を含む検出器を読み込みます。

vehicleDetector = load('yolov2VehicleDetector.mat');
detector = vehicleDetector.detector;

結果の評価およびプロット

imageDatastore を指定して検出器を実行します。

results = detect(detector, imds);

グラウンド トゥルース データに対して結果を評価します。

[ap, recall, precision] = evaluateDetectionPrecision(results, blds);

適合率/再現率曲線をプロットします。

figure;
plot(recall, precision);
grid on
title(sprintf('Average precision = %.1f', ap))

プリロードされたグラウンド トゥルース情報を使用して ACF ベースの検出器を学習させます。学習イメージで検出器を実行します。検出器を評価し、適合率/再現率曲線を表示します。

グラウンド トゥルースの table を読み込みます。

load('stopSignsAndCars.mat')
stopSigns = stopSignsAndCars(:,1:2);
stopSigns.imageFilename = fullfile(toolboxdir('vision'),'visiondata', ...
    stopSigns.imageFilename);

ACF ベースの検出器を学習させます。

detector = trainACFObjectDetector(stopSigns,'NegativeSamplesFactor',2);
ACF Object Detector Training
The training will take 4 stages. The model size is 34x31.
Sample positive examples(~100% Completed)
Compute approximation coefficients...Completed.
Compute aggregated channel features...Completed.
--------------------------------------------
Stage 1:
Sample negative examples(~100% Completed)
Compute aggregated channel features...Completed.
Train classifier with 42 positive examples and 84 negative examples...Completed.
The trained classifier has 19 weak learners.
--------------------------------------------
Stage 2:
Sample negative examples(~100% Completed)
Found 84 new negative examples for training.
Compute aggregated channel features...Completed.
Train classifier with 42 positive examples and 84 negative examples...Completed.
The trained classifier has 20 weak learners.
--------------------------------------------
Stage 3:
Sample negative examples(~100% Completed)
Found 84 new negative examples for training.
Compute aggregated channel features...Completed.
Train classifier with 42 positive examples and 84 negative examples...Completed.
The trained classifier has 54 weak learners.
--------------------------------------------
Stage 4:
Sample negative examples(~100% Completed)
Found 84 new negative examples for training.
Compute aggregated channel features...Completed.
Train classifier with 42 positive examples and 84 negative examples...Completed.
The trained classifier has 61 weak learners.
--------------------------------------------
ACF object detector training is completed. Elapsed time is 19.2586 seconds.

結果を保存する table を作成します。

numImages = height(stopSigns);
results = table('Size',[numImages 2],...
       'VariableTypes',{'cell','cell'},...
       'VariableNames',{'Boxes','Scores'}); 

学習イメージで検出器を実行します。結果を table として保存します。

for i = 1 : numImages
    I = imread(stopSigns.imageFilename{i});
    [bboxes, scores] = detect(detector,I);
    results.Boxes{i} = bboxes;
    results.Scores{i} = scores;
end 

グラウンド トゥルース データに対して結果を評価します。適合率の統計を取得します。

[ap,recall,precision] = evaluateDetectionPrecision(results,stopSigns(:,2));

適合率/再現率曲線をプロットします。

figure
plot(recall,precision)
grid on
title(sprintf('Average Precision = %.1f',ap))

入力引数

すべて折りたたむ

オブジェクトの位置とスコア。検出された各オブジェクトの境界ボックスとスコアを含む 2 列のテーブルとして指定します。マルチクラス検出の場合、3 番目の列には各検出の予測ラベルが含まれます。境界ボックスは M 行 4 列の cell 配列に保存しなければなりません。スコアは M 行 1 列の cell 配列に保存し、ラベルは categorical ベクトルとして保存しなければなりません。

オブジェクトを検出するときに、imageDatastore を使用して検出結果の table を作成できます。

        ds = imageDatastore(stopSigns.imageFilename);
        detectionResults = detect(detector,ds);

データ型: table

ラベル付きのグラウンド トゥルース。データストアまたは table として指定します。

各境界ボックスは [x,y,width,height] の形式でなければなりません。

  • データストア — 関数 read および関数 readall が、境界ボックスとラベルの cell ベクトルの列を 2 つ以上もつ cell 配列または table を返すデータストア。境界ボックスは、[x,y,width,height] 形式の M 行 4 列の行列の cell 配列内になければなりません。データストアの関数 read および関数 readall は、次のいずれかの形式を返さなければなりません。

    • {boxes,labels} — boxLabelDatastore はこのタイプのデータストアを作成します。

    • {images,boxes,labels} — 統合されたデータストア。たとえば、combine(imds,blds) を使用しています。

    boxLabelDatastore を参照してください。

  • table — 1 つ以上の列。すべての列に境界ボックスが含まれています。各列は、stopSign、carRear、carFront などの単一のオブジェクト クラスを表す M 行 4 列の行列を含む cell ベクトルでなければなりません。これらの列には、[x,y,width,height] 形式の、M 個の境界ボックスの 4 要素 double 配列が含まれます。この形式は、対応するイメージでの境界ボックスの左上隅の位置とサイズを指定します。

グラウンド トゥルース ボックスに割り当てられた検出のオーバーラップしきい値。数値スカラーとして指定します。オーバーラップ率は、Intersection over Union として計算されます。

出力引数

すべて折りたたむ

すべての検出結果の平均適合率。数値スカラーまたはベクトルとして返されます。"適合率" は、グラウンド トゥルースに基づく、検出器内オブジェクトのすべての陽性インスタンスに対する真陽性インスタンスの比率です。マルチクラス検出器の場合、平均適合率は各オブジェクト クラスの平均適合率スコアのベクトルです。

各検出からの再現率値。M 行 1 列の数値スカラーのベクトルまたは cell 配列として返されます。M の長さは、クラスに割り当てられた検出数に 1 を加えた値と等しくなります。たとえば、検出結果にクラス ラベル 'car' を持つ 4 つの検出が含まれている場合、recall には 5 つの要素が含まれます。再現率の最初の値は常に 0 です。

"再現率" は、グラウンド トゥルースに基づいた、検出器内の真陽性と偽陰性の合計に対する真陽性インスタンスの比率です。マルチクラス検出器の場合、recallprecision は cell 配列で、各セルには各オブジェクト クラスのデータ点が含まれます。

各検出からの適合率値。M 行 1 列の数値スカラーのベクトルまたは cell 配列として返されます。M の長さは、クラスに割り当てられた検出数に 1 を加えた値と等しくなります。たとえば、検出結果にクラス ラベル 'car' を持つ 4 つの検出が含まれている場合、precision には 5 つの要素が含まれます。precision の最初の値は常に 1 です。

"適合率" は、グラウンド トゥルースに基づく、検出器内オブジェクトのすべての陽性インスタンスに対する真陽性インスタンスの比率です。マルチクラス検出器の場合、recallprecision は cell 配列で、各セルには各オブジェクト クラスのデータ点が含まれます。

R2017a で導入