Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

説明可能 FCDD ネットワークを使用したイメージ異常の検出

この例では、1 クラスの完全畳み込みデータ記述 (Fully Convolutional Data Description: FCDD) 異常検出ネットワークを使用して、錠剤イメージの欠陥を検出する方法を示します。

異常検出の重要な目標は、学習済みネットワークがイメージを異常として分類する理由を人間の観察者が理解できるようにすることです。FCDD では "説明可能な分類" を行うことができます。これは、ニューラル ネットワークが分類決定にどのように到達したかを正当化する情報でクラス予測を補足するものです [1]。FCDD ネットワークは、各ピクセルの異常を示す確率のヒートマップを返します。分類器は、異常スコア ヒートマップの平均値に基づいて、イメージを正常または異常としてラベル付けします。

分類データ セット用の錠剤イメージのダウンロード

この例では、PillQC データ セットを使用します。データ セットには 3 つのクラスのイメージ (欠陥のない normal イメージ、錠剤に欠けのある chip イメージ、および汚れのある dirt イメージ) が格納されています。データ セットには、149 個の normal イメージ、43 個の chip イメージ、および 138 個の dirt イメージがあります。このデータ セットのサイズは 3.57 MB です。

dataDir をデータ セットの目的の場所として設定します。補助関数 downloadPillQCData を使用してデータ セットをダウンロードします。この関数は、この例にサポート ファイルとして添付されています。関数は ZIP ファイルをダウンロードし、データを chipdirt、および normal の各サブディレクトリに解凍します。

dataDir = fullfile(tempdir,"PillDefects");
downloadPillQCData(dataDir)

次のイメージは、各クラスのイメージの例を示しています。左は欠陥のない正常な錠剤、中央は汚れが付着した錠剤、右は欠けのある錠剤です。このデータ セット内のイメージには影、焦点のブレ、および背景色のバリエーションが含まれていますが、この例で使用したアプローチは、これらのイメージ取得アーティファクトに対してロバストです。

montageImage.png

データの読み込みと前処理

イメージ データを読み取って管理する imageDatastore を作成します。ディレクトリの名前に従って、各イメージに chipdirt、または normal のラベルを付けます。

imageDir = fullfile(dataDir,"pillQC-main","images");
imds = imageDatastore(imageDir,IncludeSubfolders=true,LabelSource="foldernames");

学習セット、キャリブレーション セット、テスト セットへのデータの分割

関数splitAnomalyDataを使用して、学習セット、キャリブレーション セット、およびテスト セットを作成します。この例では、Outlier Exposure (OE) を使用する FCDD アプローチを実装します。OE は、正常なイメージを主として、少数の異常なイメージを加えた学習データで構成されます。主に正常なシーンのみのサンプルを使用して学習させているにもかかわらず、モデルは正常なシーンと異常なシーンを区別する方法を学習します。

正常なイメージの 50% と各異常クラスのわずかな割合 (5%) を学習データ セットに割り当てます。正常なイメージの 10% と各異常クラスの 20% をキャリブレーション セットに割り当てます。残りのイメージをテスト セットに割り当てます。

normalTrainRatio  = 0.5;
anomalyTrainRatio = 0.05;
normalCalRatio  = 0.10;
anomalyCalRatio = 0.20;
normalTestRatio  = 1 - (normalTrainRatio + normalCalRatio);
anomalyTestRatio = 1 - (anomalyTrainRatio + anomalyCalRatio);

anomalyClasses = ["chip","dirt"];
[imdsTrain,imdsCal,imdsTest] = splitAnomalyData(imds,anomalyClasses, ...
    NormalLabelsRatio=[normalTrainRatio normalCalRatio normalTestRatio], ...
    AnomalyLabelsRatio=[anomalyTrainRatio anomalyCalRatio anomalyTestRatio]);
Splitting anomaly dataset
-------------------------
* Finalizing... Done.
* Number of files and proportions per class in all the datasets:

                     Input                  Train                Validation                Test        
              NumFiles     Ratio     NumFiles     Ratio      NumFiles     Ratio     NumFiles     Ratio 
              ___________________    ____________________    ___________________    ___________________

    chip         43        0.1303        2        0.02381        9       0.17647       32        0.1641
    dirt        138       0.41818        7       0.083333       28       0.54902      103       0.52821
    normal      149       0.45152       75        0.89286       14       0.27451       60       0.30769

さらに学習データを 2 つのデータストアに分割し、1 つは正常データのみを含み、もう 1 つは異常データのみを含むようにします。

[imdsNormalTrain,imdsAnomalyTrain] = splitAnomalyData(imdsTrain,anomalyClasses, ...
    NormalLabelsRatio=[1 0 0],AnomalyLabelsRatio=[0 1 0],Verbose=false);

学習データの拡張

関数transformを補助関数 augmentDataForPillAnomalyDetector によって指定されたカスタム前処理演算と共に使用して、学習データを拡張します。補助関数は、この例にサポート ファイルとして添付されています。

関数 augmentDataForPillAnomalyDetector は、各入力イメージに 90 度の回転と水平方向および垂直方向の反転をランダムに適用します。

imdsNormalTrain = transform(imdsNormalTrain,@augmentDataForPillAnomalyDetector);
imdsAnomalyTrain = transform(imdsAnomalyTrain,@augmentDataForPillAnomalyDetector);

関数 transform を補助関数 addLabelData で指定された演算と共に使用して、バイナリ ラベルをキャリブレーション データ セットとテスト データ セットに追加します。補助関数はこの例の最後で定義されており、normal クラスのイメージにバイナリ ラベル 0 を割り当て、chip クラスまたは dirt クラスのイメージにバイナリ ラベル 1 を割り当てます。

dsCal = transform(imdsCal,@addLabelData,IncludeInfo=true);
dsTest = transform(imdsTest,@addLabelData,IncludeInfo=true);

拡張された 9 つの学習イメージのサンプルを可視化します。

exampleData = readall(subset(imdsNormalTrain,1:9));
montage(exampleData(:,1));

FCDD モデルの作成

この例では、Fully Convolutional Data Description (FCDD) モデル [1] を使用します。FCDD の基本的な考え方は、入力イメージ内の各領域に異常なコンテンツが含まれる確率を記述する異常スコア マップを生成するようにネットワークに学習させることです。

関数pretrainedEncoderNetworkは、事前学習済みのバックボーンとして使用するために、ImageNet で事前に学習させた Inception-v3 ネットワークの最初の 3 つのダウンサンプリング ステージを返します。

backbone = pretrainedEncoderNetwork("inceptionv3",3);

Inception-v3 バックボーンで関数fcddAnomalyDetectorを使用して、FCDD 異常検出器ネットワークを作成します。

net = fcddAnomalyDetector(backbone);

ネットワークの学習または事前学習済みネットワークのダウンロード

既定では、この例は補助関数 downloadTrainedNetwork を使用して、FCDD 異常検出器の事前学習済みバージョンをダウンロードします。補助関数は、この例にサポート ファイルとして添付されています。事前学習済みのネットワークを使用することで、学習の完了を待たずに例全体を実行できます。

ネットワークに学習させるには、次のコードで変数 doTrainingtrue に設定します。フィールドに値を入力して、学習に使用するエポック数 numEpochs を指定します。関数trainFCDDAnomalyDetectorを使用して、モデルに学習させます。

可能であれば、1 つ以上の GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。学習には NVIDIA Titan RTX™ で約 3 分を要します。

doTraining = false;
numEpochs = 200;
if doTraining
    options = trainingOptions("adam", ...
        Shuffle="every-epoch",...
        MaxEpochs=numEpochs,InitialLearnRate=1e-4, ...
        MiniBatchSize=32,...
        BatchNormalizationStatistics="moving");
    detector = trainFCDDAnomalyDetector(imdsNormalTrain,imdsAnomalyTrain,net,options);
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(dataDir,"trainedPillAnomalyDetector-"+modelDateTime+".mat"),"detector");
else
    trainedPillAnomalyDetectorNet_url = "https://ssd.mathworks.com/supportfiles/"+ ...
        "vision/data/trainedFCDDPillAnomalyDetectorSpkg.zip";
    downloadTrainedNetwork(trainedPillAnomalyDetectorNet_url,dataDir);
    load(fullfile(dataDir,"folderForSupportFilesInceptionModel", ...
        "trainedPillFCDDNet.mat"));
end

異常のしきい値の設定

異常検出器の異常スコアしきい値を選択します。異常検出器は、スコアがしきい値を上回るか下回るかに基づいてイメージを分類します。この例では、正常なイメージと異常なイメージの両方を含むキャリブレーション データ セットを使用してしきい値を選定します。

キャリブレーション セット内の各イメージについて、平均異常スコアとグラウンド トゥルース ラベルを取得します。

scores = predict(detector,dsCal);
labels = imdsCal.Labels ~= "normal";

正常クラスと異常クラスの平均異常スコアのヒストグラムをプロットします。モデルが予測した異常スコアによって、分布がきれいに分かれます。

numBins = 20;
[~,edges] = histcounts(scores,numBins);
figure
hold on
hNormal = histogram(scores(labels==0),edges);
hAnomaly = histogram(scores(labels==1),edges);
hold off
legend([hNormal,hAnomaly],"Normal","Anomaly")
xlabel("Mean Anomaly Score")
ylabel("Counts")

関数anomalyThresholdを使用して、異常しきい値の最適値を計算します。最初の 2 つの入力引数を、キャリブレーション データ セットのグラウンド トゥルース ラベル (labels) および予測異常スコア (scores) として指定します。真陽性異常イメージの labels 値は true であるため、3 番目の入力引数を true として指定します。関数 anomalyThreshold は、検出器用の最適なしきい値と、rocmetrics (Deep Learning Toolbox)オブジェクトとして格納される受信者動作特性 (ROC) 曲線を返します。

[thresh,roc] = anomalyThreshold(labels,scores,true);

異常検出器の Threshold プロパティを最適な値に設定します。

detector.Threshold = thresh;

rocmetrics のオブジェクト関数plot (Deep Learning Toolbox)を使用して ROC をプロットします。ROC 曲線は、しきい値が取り得る範囲に対する分類器のパフォーマンスを示します。ROC 曲線上の各点は、異なるしきい値を使用してキャリブレーション セット イメージを分類する場合の偽陽性率 (x 座標) と真陽性率 (y 座標) を表します。青い実線は ROC 曲線を表します。赤い破線は、50% の成功率に対応するスキルなしの分類器を表します。ROC 曲線下面積 (AUC) メトリクスは分類器のパフォーマンスを示し、完全な分類器に対応する最大 ROC AUC は 1.0 となります。

plot(roc)
title("ROC AUC: "+ roc.AUC)

分類モデルの評価

テスト セット内の各イメージを正常または異常のいずれかに分類します。

testSetOutputLabels = classify(detector,dsTest);

各テスト イメージのグラウンド トゥルース ラベルを取得します。

testSetTargetLabels = dsTest.UnderlyingDatastores{1}.Labels;

関数evaluateAnomalyDetectionを使用してパフォーマンス メトリクスを計算し、異常検出器を評価します。関数は、テスト データ セット用の検出器の精度、適合率、感度、特異度を評価するいくつかのメトリクスを計算します。

metrics = evaluateAnomalyDetection(testSetOutputLabels,testSetTargetLabels,anomalyClasses);
Evaluating anomaly detection results
------------------------------------
* Finalizing... Done.
* Data set metrics:

    GlobalAccuracy    MeanAccuracy    Precision    Recall     Specificity    F1Score    FalsePositiveRate    FalseNegativeRate
    ______________    ____________    _________    _______    ___________    _______    _________________    _________________

       0.96923          0.97778           1        0.95556         1         0.97727            0                0.044444     

metricsConfusionMatrix プロパティには、テスト セットの混同行列が含まれます。混同行列を抽出し、混同プロットを表示します。この例の分類モデルは非常に正確で、予測の偽陽性と偽陰性の割合はわずかです。

M = metrics.ConfusionMatrix{:,:};
confusionchart(M,["Normal","Anomaly"])
acc = sum(diag(M)) / sum(M,"all");
title("Accuracy: "+acc)

dirtchip など、複数の異常クラス ラベルをこの例で指定した場合、関数 evaluateAnomalyDetection はデータ セット全体および各異常クラスについて、メトリクスを計算します。クラスごとのメトリクスは、anomalyDetectionMetricsオブジェクト metricsClassMetrics プロパティで返されます。

metrics.ClassMetrics
ans=2×2 table
               Accuracy    AccuracyPerSubClass
               ________    ___________________

    Normal           1         {1×1 table}    
    Anomaly    0.95556         {2×1 table}    

metrics.ClassMetrics(2,"AccuracyPerSubClass").AccuracyPerSubClass{1}
ans=2×1 table
            AccuracyPerSubClass
            ___________________

    chip          0.84375      
    dirt          0.99029      

分類判定の説明

異常検出器によって予測された異常値のヒートマップを使用して、イメージが正常または異常に分類される理由を説明することができます。このアプローチは、偽陰性と偽陽性のパターンを識別するのに役立ちます。これらのパターンを使用して、学習データのクラス均衡化を高めたり、ネットワーク パフォーマンスを改善したりするための戦略を特定できます。

異常値のヒート マップの表示範囲の計算

正常イメージと異常イメージを含むキャリブレーション セット全体で観察される異常スコアの範囲を表す表示範囲を計算します。イメージ間で同じ表示範囲を使用することで、各イメージをそれらの最小値と最大値でスケーリングする場合よりも簡単にイメージを比較できます。この例では、表示範囲をすべてのヒートマップに適用します。

minMapVal = inf;
maxMapVal = -inf;
reset(dsCal)
while hasdata(dsCal)
    img = read(dsCal);
    map = anomalyMap(detector,img{1});
    minMapVal = min(min(map,[],"all"),minMapVal);
    maxMapVal = max(max(map,[],"all"),maxMapVal);
end
displayRange = [minMapVal,maxMapVal];

異常イメージのヒートマップの表示

異常であると正しく分類されたイメージを選択します。次の結果は真陽性分類です。イメージを表示します。

testSetAnomalyLabels = testSetTargetLabels ~= "normal";
idxTruePositive = find(testSetAnomalyLabels' & testSetOutputLabels,1,"last");
dsExample = subset(dsTest,idxTruePositive);
img = read(dsExample);
img = img{1};
map = anomalyMap(detector,img);
imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))

通常イメージのヒートマップの表示

正常なイメージであると正しく分類されたイメージを選択して表示します。次の結果は真陰性分類です。

idxTrueNegative = find(~(testSetAnomalyLabels' | testSetOutputLabels));
dsExample = subset(dsTest,idxTrueNegative);
img = read(dsExample);
img = img{1};
map = anomalyMap(detector,img);
imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))

偽陰性イメージのヒートマップの表示

偽陰性は、ネットワークが正常として分類するもののうち、錠剤に欠陥異常があるイメージです。ネットワークからの説明を使用して、誤分類についての洞察を取得します。

テスト セットからすべての偽陰性イメージを見つけます。関数 transform を使用して、偽陰性イメージのヒートマップ オーバーレイを取得します。変換の演算は無名関数によって指定されます。無名関数では、関数anomalyMapOverlayを適用してテスト セット内の各偽陰性のヒートマップ オーバーレイを取得します。

falseNegativeIdx = find(testSetAnomalyLabels' & ~testSetOutputLabels);
if ~isempty(falseNegativeIdx)
    fnExamples = subset(dsTest,falseNegativeIdx);
    fnExamplesWithHeatmapOverlays = transform(fnExamples,@(x) {...
        anomalyMapOverlay(x{1},anomalyMap(detector,x{1}), ...
        MapRange=displayRange,Blend="equal")});
    fnExamples = readall(fnExamples);
    fnExamples = fnExamples(:,1);
    fnExamplesWithHeatmapOverlays = readall(fnExamplesWithHeatmapOverlays);
    montage(fnExamples)
    montage(fnExamplesWithHeatmapOverlays)
else
    disp("No false negatives detected.")
end

偽陽性イメージのヒートマップの表示

偽陽性は、ネットワークが異常として分類するもののうち、錠剤に欠陥異常がないイメージです。テスト セット内のすべての偽陽性を見つけます。ネットワークからの説明を使用して、誤分類についての洞察を取得します。たとえば、異常スコアがイメージの背景に局在化している場合は、前処理の際に背景の抑制処理を検討できます。

falsePositiveIdx = find(~testSetAnomalyLabels' & testSetOutputLabels);
if ~isempty(falsePositiveIdx)
    fpExamples = subset(dsTest,falsePositiveIdx);
    fpExamplesWithHeatmapOverlays = transform(fpExamples,@(x) { ...
        anomalyMapOverlay(x{1},anomalyMap(detector,x{1}), ...
        MapRange=displayRange,Blend="equal")});
    fpExamples = readall(fpExamples);
    fpExamples = fpExamples(:,1);
    fpExamplesWithHeatmapOverlays = readall(fpExamplesWithHeatmapOverlays);
    montage(fpExamples)
    montage(fpExamplesWithHeatmapOverlays)
else
    disp("No false positives detected.")
end
No false positives detected.

サポート関数

補助関数 addLabelData は、data 内のラベル情報の one-hot 符号化表現を作成します。

function [data,info] = addLabelData(data,info)
    if info.Label == categorical("normal")
        onehotencoding = 0;
    else
        onehotencoding = 1;
    end
    data = {data,onehotencoding};
end

参考文献

[1] Liznerski, Philipp, Lukas Ruff, Robert A. Vandermeulen, Billy Joe Franks, Marius Kloft, and Klaus-Robert Müller. "Explainable Deep One-Class Classification." Preprint, submitted March 18, 2021. https://arxiv.org/abs/2007.01760.

[2] Ruff, Lukas, Robert A. Vandermeulen, Billy Joe Franks, Klaus-Robert Müller, and Marius Kloft. "Rethinking Assumptions in Deep Anomaly Detection." Preprint, submitted May 30, 2020. https://arxiv.org/abs/2006.00339.

[3] Simonyan, Karen, and Andrew Zisserman."Very Deep Convolutional Networks for Large-Scale Image Recognition." Preprint, submitted April 10, 2015. https://arxiv.org/abs/1409.1556.

[4] ImageNet. https://www.image-net.org.

参考

| | | | | | | | | (Deep Learning Toolbox) | (Deep Learning Toolbox)

関連する例

詳細