このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
説明可能 FCDD ネットワークを使用したイメージ異常の検出
この例では、1 クラスの完全畳み込みデータ記述 (Fully Convolutional Data Description: FCDD) 異常検出ネットワークを使用して、錠剤イメージの欠陥を検出する方法を示します。
異常検出の重要な目標は、学習済みネットワークがイメージを異常として分類する理由を人間の観察者が理解できるようにすることです。FCDD では "説明可能な分類" を行うことができます。これは、ニューラル ネットワークが分類決定にどのように到達したかを正当化する情報でクラス予測を補足するものです [1]。FCDD ネットワークは、各ピクセルの異常を示す確率のヒートマップを返します。分類器は、異常スコア ヒートマップの平均値に基づいて、イメージを正常または異常としてラベル付けします。
分類データ セット用の錠剤イメージのダウンロード
この例では、PillQC データ セットを使用します。データ セットには 3 つのクラスのイメージ (欠陥のない normal
イメージ、錠剤に欠けのある chip
イメージ、および汚れのある dirt
イメージ) が格納されています。データ セットには、149 個の normal
イメージ、43 個の chip
イメージ、および 138 個の dirt
イメージがあります。このデータ セットのサイズは 3.57 MB です。
dataDir
をデータ セットの目的の場所として設定します。補助関数 downloadPillQCData
を使用してデータ セットをダウンロードします。この関数は、この例にサポート ファイルとして添付されています。関数は ZIP ファイルをダウンロードし、データを chip
、dirt
、および normal
の各サブディレクトリに解凍します。
dataDir = fullfile(tempdir,"PillDefects");
downloadPillQCData(dataDir)
次のイメージは、各クラスのイメージの例を示しています。左は欠陥のない正常な錠剤、中央は汚れが付着した錠剤、右は欠けのある錠剤です。このデータ セット内のイメージには影、焦点のブレ、および背景色のバリエーションが含まれていますが、この例で使用したアプローチは、これらのイメージ取得アーティファクトに対してロバストです。
データの読み込みと前処理
イメージ データを読み取って管理する imageDatastore
を作成します。ディレクトリの名前に従って、各イメージに chip
、dirt
、または normal
のラベルを付けます。
imageDir = fullfile(dataDir,"pillQC-main","images"); imds = imageDatastore(imageDir,IncludeSubfolders=true,LabelSource="foldernames");
学習セット、キャリブレーション セット、テスト セットへのデータの分割
関数splitAnomalyData
(Computer Vision Toolbox)を使用して、学習セット、キャリブレーション セット、およびテスト セットを作成します。この例では、Outlier Exposure (OE) を使用する FCDD アプローチを実装します。OE は、正常なイメージを主として、少数の異常なイメージを加えた学習データで構成されます。主に正常なシーンのみのサンプルを使用して学習させているにもかかわらず、モデルは正常なシーンと異常なシーンを区別する方法を学習します。
正常なイメージの 50% と各異常クラスのわずかな割合 (5%) を学習データ セットに割り当てます。正常なイメージの 10% と各異常クラスの 20% をキャリブレーション セットに割り当てます。残りのイメージをテスト セットに割り当てます。
normalTrainRatio = 0.5; anomalyTrainRatio = 0.05; normalCalRatio = 0.10; anomalyCalRatio = 0.20; normalTestRatio = 1 - (normalTrainRatio + normalCalRatio); anomalyTestRatio = 1 - (anomalyTrainRatio + anomalyCalRatio); anomalyClasses = ["chip","dirt"]; [imdsTrain,imdsCal,imdsTest] = splitAnomalyData(imds,anomalyClasses, ... NormalLabelsRatio=[normalTrainRatio normalCalRatio normalTestRatio], ... AnomalyLabelsRatio=[anomalyTrainRatio anomalyCalRatio anomalyTestRatio]);
Splitting anomaly dataset ------------------------- * Finalizing... Done. * Number of files and proportions per class in all the datasets: Input Train Validation Test NumFiles Ratio NumFiles Ratio NumFiles Ratio NumFiles Ratio ___________________ ____________________ ___________________ ___________________ chip 43 0.1303 2 0.02381 9 0.17647 32 0.1641 dirt 138 0.41818 7 0.083333 28 0.54902 103 0.52821 normal 149 0.45152 75 0.89286 14 0.27451 60 0.30769
さらに学習データを 2 つのデータストアに分割し、1 つは正常データのみを含み、もう 1 つは異常データのみを含むようにします。
[imdsNormalTrain,imdsAnomalyTrain] = splitAnomalyData(imdsTrain,anomalyClasses, ...
NormalLabelsRatio=[1 0 0],AnomalyLabelsRatio=[0 1 0],Verbose=false);
学習データの拡張
関数transform
を補助関数 augmentDataForPillAnomalyDetector
によって指定されたカスタム前処理演算と共に使用して、学習データを拡張します。補助関数は、この例にサポート ファイルとして添付されています。
関数 augmentDataForPillAnomalyDetector
は、各入力イメージに 90 度の回転と水平方向および垂直方向の反転をランダムに適用します。
imdsNormalTrain = transform(imdsNormalTrain,@augmentDataForPillAnomalyDetector); imdsAnomalyTrain = transform(imdsAnomalyTrain,@augmentDataForPillAnomalyDetector);
関数 transform
を補助関数 addLabelData
で指定された演算と共に使用して、バイナリ ラベルをキャリブレーション データ セットとテスト データ セットに追加します。補助関数はこの例の最後で定義されており、normal
クラスのイメージにバイナリ ラベル 0
を割り当て、chip
クラスまたは dirt
クラスのイメージにバイナリ ラベル 1
を割り当てます。
dsCal = transform(imdsCal,@addLabelData,IncludeInfo=true); dsTest = transform(imdsTest,@addLabelData,IncludeInfo=true);
拡張された 9 つの学習イメージのサンプルを可視化します。
exampleData = readall(subset(imdsNormalTrain,1:9)); montage(exampleData(:,1));
FCDD モデルの作成
この例では、Fully Convolutional Data Description (FCDD) モデル [1] を使用します。FCDD の基本的な考え方は、入力イメージ内の各領域に異常なコンテンツが含まれる確率を記述する異常スコア マップを生成するようにネットワークに学習させることです。
関数pretrainedEncoderNetwork
は、事前学習済みのバックボーンとして使用するために、ImageNet で事前に学習させた Inception-v3 ネットワークの最初の 3 つのダウンサンプリング ステージを返します。
backbone = pretrainedEncoderNetwork("inceptionv3",3);
Inception-v3 バックボーンで関数fcddAnomalyDetector
(Computer Vision Toolbox)を使用して、FCDD 異常検出器ネットワークを作成します。
net = fcddAnomalyDetector(backbone);
ネットワークの学習または事前学習済みネットワークのダウンロード
既定では、この例は補助関数 downloadTrainedNetwork
を使用して、FCDD 異常検出器の事前学習済みバージョンをダウンロードします。補助関数は、この例にサポート ファイルとして添付されています。事前学習済みのネットワークを使用することで、学習の完了を待たずに例全体を実行できます。
ネットワークに学習させるには、次のコードで変数 doTraining
を true
に設定します。フィールドに値を入力して、学習に使用するエポック数 numEpochs
を指定します。関数trainFCDDAnomalyDetector
(Computer Vision Toolbox)を使用して、モデルに学習させます。
可能であれば、1 つ以上の GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。学習には NVIDIA Titan RTX™ で約 3 分を要します。
doTraining = false; numEpochs = 200; if doTraining options = trainingOptions("adam", ... Shuffle="every-epoch",... MaxEpochs=numEpochs,InitialLearnRate=1e-4, ... MiniBatchSize=32,... BatchNormalizationStatistics="moving"); detector = trainFCDDAnomalyDetector(imdsNormalTrain,imdsAnomalyTrain,net,options); modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss")); save(fullfile(dataDir,"trainedPillAnomalyDetector-"+modelDateTime+".mat"),"detector"); else trainedPillAnomalyDetectorNet_url = "https://ssd.mathworks.com/supportfiles/"+ ... "vision/data/trainedFCDDPillAnomalyDetectorSpkg.zip"; downloadTrainedNetwork(trainedPillAnomalyDetectorNet_url,dataDir); load(fullfile(dataDir,"folderForSupportFilesInceptionModel", ... "trainedPillFCDDNet.mat")); end
異常のしきい値の設定
異常検出器の異常スコアしきい値を選択します。異常検出器は、スコアがしきい値を上回るか下回るかに基づいてイメージを分類します。この例では、正常なイメージと異常なイメージの両方を含むキャリブレーション データ セットを使用してしきい値を選定します。
キャリブレーション セット内の各イメージについて、平均異常スコアとグラウンド トゥルース ラベルを取得します。
scores = predict(detector,dsCal);
labels = imdsCal.Labels ~= "normal";
正常クラスと異常クラスの平均異常スコアのヒストグラムをプロットします。モデルが予測した異常スコアによって、分布がきれいに分かれます。
numBins = 20; [~,edges] = histcounts(scores,numBins); figure hold on hNormal = histogram(scores(labels==0),edges); hAnomaly = histogram(scores(labels==1),edges); hold off legend([hNormal,hAnomaly],"Normal","Anomaly") xlabel("Mean Anomaly Score") ylabel("Counts")
関数anomalyThreshold
(Computer Vision Toolbox)を使用して、異常しきい値の最適値を計算します。最初の 2 つの入力引数を、キャリブレーション データ セットのグラウンド トゥルース ラベル (labels
) および予測異常スコア (scores
) として指定します。真陽性異常イメージの labels
値は true
であるため、3 番目の入力引数を true
として指定します。関数 anomalyThreshold
は、検出器用の最適なしきい値と、rocmetrics
(Deep Learning Toolbox)オブジェクトとして格納される受信者動作特性 (ROC) 曲線を返します。
[thresh,roc] = anomalyThreshold(labels,scores,true);
異常検出器の Threshold
プロパティを最適な値に設定します。
detector.Threshold = thresh;
rocmetrics
のオブジェクト関数plot
(Deep Learning Toolbox)を使用して ROC をプロットします。ROC 曲線は、しきい値が取り得る範囲に対する分類器のパフォーマンスを示します。ROC 曲線上の各点は、異なるしきい値を使用してキャリブレーション セット イメージを分類する場合の偽陽性率 (x 座標) と真陽性率 (y 座標) を表します。青い実線は ROC 曲線を表します。赤い破線は、50% の成功率に対応するスキルなしの分類器を表します。ROC 曲線下面積 (AUC) メトリクスは分類器のパフォーマンスを示し、完全な分類器に対応する最大 ROC AUC は 1.0 となります。
plot(roc)
title("ROC AUC: "+ roc.AUC)
分類モデルの評価
テスト セット内の各イメージを正常または異常のいずれかに分類します。
testSetOutputLabels = classify(detector,dsTest);
各テスト イメージのグラウンド トゥルース ラベルを取得します。
testSetTargetLabels = dsTest.UnderlyingDatastores{1}.Labels;
関数evaluateAnomalyDetection
(Computer Vision Toolbox)を使用してパフォーマンス メトリクスを計算し、異常検出器を評価します。関数は、テスト データ セット用の検出器の精度、適合率、感度、特異度を評価するいくつかのメトリクスを計算します。
metrics = evaluateAnomalyDetection(testSetOutputLabels,testSetTargetLabels,anomalyClasses);
Evaluating anomaly detection results ------------------------------------ * Finalizing... Done. * Data set metrics: GlobalAccuracy MeanAccuracy Precision Recall Specificity F1Score FalsePositiveRate FalseNegativeRate ______________ ____________ _________ _______ ___________ _______ _________________ _________________ 0.96923 0.97778 1 0.95556 1 0.97727 0 0.044444
metrics
の ConfusionMatrix
プロパティには、テスト セットの混同行列が含まれます。混同行列を抽出し、混同プロットを表示します。この例の分類モデルは非常に正確で、予測の偽陽性と偽陰性の割合はわずかです。
M = metrics.ConfusionMatrix{:,:}; confusionchart(M,["Normal","Anomaly"]) acc = sum(diag(M)) / sum(M,"all"); title("Accuracy: "+acc)
dirt
や chip
など、複数の異常クラス ラベルをこの例で指定した場合、関数 evaluateAnomalyDetection
はデータ セット全体および各異常クラスについて、メトリクスを計算します。クラスごとのメトリクスは、anomalyDetectionMetrics
(Computer Vision Toolbox)オブジェクト metrics
の ClassMetrics
プロパティで返されます。
metrics.ClassMetrics
ans=2×2 table
Accuracy AccuracyPerSubClass
________ ___________________
Normal 1 {1×1 table}
Anomaly 0.95556 {2×1 table}
metrics.ClassMetrics(2,"AccuracyPerSubClass").AccuracyPerSubClass{1}
ans=2×1 table
AccuracyPerSubClass
___________________
chip 0.84375
dirt 0.99029
分類判定の説明
異常検出器によって予測された異常値のヒートマップを使用して、イメージが正常または異常に分類される理由を説明することができます。このアプローチは、偽陰性と偽陽性のパターンを識別するのに役立ちます。これらのパターンを使用して、学習データのクラス均衡化を高めたり、ネットワーク パフォーマンスを改善したりするための戦略を特定できます。
異常値のヒート マップの表示範囲の計算
正常イメージと異常イメージを含むキャリブレーション セット全体で観察される異常スコアの範囲を表す表示範囲を計算します。イメージ間で同じ表示範囲を使用することで、各イメージをそれらの最小値と最大値でスケーリングする場合よりも簡単にイメージを比較できます。この例では、表示範囲をすべてのヒートマップに適用します。
minMapVal = inf; maxMapVal = -inf; reset(dsCal) while hasdata(dsCal) img = read(dsCal); map = anomalyMap(detector,img{1}); minMapVal = min(min(map,[],"all"),minMapVal); maxMapVal = max(max(map,[],"all"),maxMapVal); end displayRange = [minMapVal,maxMapVal];
異常イメージのヒートマップの表示
異常であると正しく分類されたイメージを選択します。次の結果は真陽性分類です。イメージを表示します。
testSetAnomalyLabels = testSetTargetLabels ~= "normal"; idxTruePositive = find(testSetAnomalyLabels' & testSetOutputLabels,1,"last"); dsExample = subset(dsTest,idxTruePositive); img = read(dsExample); img = img{1}; map = anomalyMap(detector,img); imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))
通常イメージのヒートマップの表示
正常なイメージであると正しく分類されたイメージを選択して表示します。次の結果は真陰性分類です。
idxTrueNegative = find(~(testSetAnomalyLabels' | testSetOutputLabels));
dsExample = subset(dsTest,idxTrueNegative);
img = read(dsExample);
img = img{1};
map = anomalyMap(detector,img);
imshow(anomalyMapOverlay(img,map,MapRange=displayRange,Blend="equal"))
偽陰性イメージのヒートマップの表示
偽陰性は、ネットワークが正常として分類するもののうち、錠剤に欠陥異常があるイメージです。ネットワークからの説明を使用して、誤分類についての洞察を取得します。
テスト セットからすべての偽陰性イメージを見つけます。関数 transform
を使用して、偽陰性イメージのヒートマップ オーバーレイを取得します。変換の演算は無名関数によって指定されます。無名関数では、関数anomalyMapOverlay
(Computer Vision Toolbox)を適用してテスト セット内の各偽陰性のヒートマップ オーバーレイを取得します。
falseNegativeIdx = find(testSetAnomalyLabels' & ~testSetOutputLabels); if ~isempty(falseNegativeIdx) fnExamples = subset(dsTest,falseNegativeIdx); fnExamplesWithHeatmapOverlays = transform(fnExamples,@(x) {... anomalyMapOverlay(x{1},anomalyMap(detector,x{1}), ... MapRange=displayRange,Blend="equal")}); fnExamples = readall(fnExamples); fnExamples = fnExamples(:,1); fnExamplesWithHeatmapOverlays = readall(fnExamplesWithHeatmapOverlays); montage(fnExamples) montage(fnExamplesWithHeatmapOverlays) else disp("No false negatives detected.") end
偽陽性イメージのヒートマップの表示
偽陽性は、ネットワークが異常として分類するもののうち、錠剤に欠陥異常がないイメージです。テスト セット内のすべての偽陽性を見つけます。ネットワークからの説明を使用して、誤分類についての洞察を取得します。たとえば、異常スコアがイメージの背景に局在化している場合は、前処理の際に背景の抑制処理を検討できます。
falsePositiveIdx = find(~testSetAnomalyLabels' & testSetOutputLabels); if ~isempty(falsePositiveIdx) fpExamples = subset(dsTest,falsePositiveIdx); fpExamplesWithHeatmapOverlays = transform(fpExamples,@(x) { ... anomalyMapOverlay(x{1},anomalyMap(detector,x{1}), ... MapRange=displayRange,Blend="equal")}); fpExamples = readall(fpExamples); fpExamples = fpExamples(:,1); fpExamplesWithHeatmapOverlays = readall(fpExamplesWithHeatmapOverlays); montage(fpExamples) montage(fpExamplesWithHeatmapOverlays) else disp("No false positives detected.") end
No false positives detected.
サポート関数
補助関数 addLabelData
は、data
内のラベル情報の one-hot 符号化表現を作成します。
function [data,info] = addLabelData(data,info) if info.Label == categorical("normal") onehotencoding = 0; else onehotencoding = 1; end data = {data,onehotencoding}; end
参考文献
[1] Liznerski, Philipp, Lukas Ruff, Robert A. Vandermeulen, Billy Joe Franks, Marius Kloft, and Klaus-Robert Müller. "Explainable Deep One-Class Classification." Preprint, submitted March 18, 2021. https://arxiv.org/abs/2007.01760.
[2] Ruff, Lukas, Robert A. Vandermeulen, Billy Joe Franks, Klaus-Robert Müller, and Marius Kloft. "Rethinking Assumptions in Deep Anomaly Detection." Preprint, submitted May 30, 2020. https://arxiv.org/abs/2006.00339.
[3] Simonyan, Karen, and Andrew Zisserman."Very Deep Convolutional Networks for Large-Scale Image Recognition." Preprint, submitted April 10, 2015. https://arxiv.org/abs/1409.1556.
[4] ImageNet. https://www.image-net.org.
参考
transform
| pretrainedEncoderNetwork
| fcddAnomalyDetector
(Computer Vision Toolbox) | trainFCDDAnomalyDetector
(Computer Vision Toolbox) | predict
(Computer Vision Toolbox) | anomalyThreshold
(Computer Vision Toolbox) | anomalyMapOverlay
(Computer Vision Toolbox) | evaluateAnomalyDetection
(Computer Vision Toolbox) | anomalyDetectionMetrics
(Computer Vision Toolbox) | rocmetrics
(Deep Learning Toolbox) | confusionchart
(Deep Learning Toolbox)
関連する例
- Classify Defects on Wafer Maps Using Deep Learning
- Detect Image Anomalies Using Pretrained ResNet-18 Feature Embeddings
詳細
- Getting Started with Anomaly Detection Using Deep Learning (Computer Vision Toolbox)
- 深層学習用のデータストア (Deep Learning Toolbox)