Main Content

最大および最小の活性化イメージを使用したイメージ分類の可視化

この例では、データセットを使用して何が深層ニューラル ネットワークのチャネルを活性化するか発見する方法を説明します。これにより、ニューラル ネットワークの仕組みを理解し、学習データ セットの潜在的な問題を診断することができます。

この例では、食品のデータセットを転移学習させた GoogLeNet を使用して、いくつかのシンプルな可視化手法を取り上げます。分類器を最大あるいは最小に活性化させるイメージを確認することにより、ニューラル ネットワークが誤った分類をする理由を見つけることができます。

データの読み込みと前処理

イメージをイメージ データストアとして読み込みます。この小さいデータセットには合計 978 の観測値と、食品の 9 つのクラスが含まれています。

このデータを学習セット、検証セット、およびテスト セットに分割して、GoogLeNet を使用した転移学習の準備をします。データ セットのイメージをいくつか選択して表示します。

rng default
dataDir = fullfile(tempdir,"Food Dataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir,"dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.
imds = imageDatastore(dataDir, ...
    "IncludeSubfolders",true,"LabelSource","foldernames");
[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.6,0.2);

rnd = randperm(numel(imds.Files),9);
for i = 1:numel(rnd)
subplot(3,3,i)
imshow(imread(imds.Files{rnd(i)}))
label = imds.Labels(rnd(i));
title(label,"Interpreter","none")
end

食品イメージを分類するネットワークの学習

事前学習済みの GoogLeNet ネットワークを使用して、9 つのタイプの食品を分類するよう再学習させます。Deep Learning Toolbox™ Model for GoogLeNet Network サポート パッケージをインストールしていない場合、ダウンロード用リンクが表示されます。

別の事前学習済みのネットワークを試すには、この例を MATLAB® で開き、googlenet より高速なネットワークである squeezenet のような別のネットワークを選択します。使用可能なすべてのネットワークについては、事前学習済みの深層ニューラル ネットワークを参照してください。

net = googlenet;

ネットワークの Layers プロパティの最初の要素はイメージ入力層です。この層にはサイズが 224 x 224 x 3 の入力イメージが必要です。ここで、3 はカラー チャネルの数です。

inputSize = net.Layers(1).InputSize;

ネットワーク アーキテクチャ

ネットワークの畳み込み層は、入力イメージを分類するために、最後の学習可能な層と最終分類層が使用するイメージの特徴を抽出します。GoogLeNet のこれらの 2 つの層 'loss3-classifier' および 'output' は、ネットワークによって抽出された特徴を組み合わせてクラス確率、損失値、および予測ラベルにまとめる方法に関する情報を含んでいます。新しいイメージを分類するために事前学習済みのネットワークに学習させるには、これら 2 つの層を新しいデータ セットに適合させた新しい層に置き換えます。

学習済みのネットワークから層グラフを抽出します。

lgraph = layerGraph(net);

ほとんどのネットワークでは、学習可能な重みを持つ最後の層は全結合層です。この全結合層を、新しいデータ セットのクラスの数 (この例では 9) と同じ数の出力をもつ新しい全結合層に置き換えます。

numClasses = numel(categories(imdsTrain.Labels));

newfclayer = fullyConnectedLayer(numClasses,...
    'Name','new_fc',...
    'WeightLearnRateFactor',10,...
    'BiasLearnRateFactor',10);
lgraph = replaceLayer(lgraph,net.Layers(end-2).Name,newfclayer);

分類層はネットワークの出力クラスを指定します。分類層をクラス ラベルがない新しい分類層に置き換えます。trainNetwork は、学習時に層の出力クラスを自動的に設定します。

newclasslayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,net.Layers(end).Name,newclasslayer);

ネットワークの学習

ネットワークにはサイズが 224 x 224 x 3 の入力イメージが必要ですが、イメージ データストアにあるイメージのサイズは異なります。拡張イメージ データストアを使用して学習イメージのサイズを自動的に変更します。学習イメージに対して実行する追加の拡張演算として、学習イメージを縦軸に沿ってランダムに反転させる演算、最大 30 ピクセルだけランダムに平行移動させる演算、および水平方向と垂直方向に最大 10% スケールアップする演算を指定します。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。

pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

他のデータ拡張を実行せずに検証イメージのサイズを自動的に変更するには、追加の前処理演算を指定せずに拡張イメージ データストアを使用します。

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

学習オプションを指定します。InitialLearnRate を小さい値に設定して、まだ凍結されていない転移層での学習速度を下げます。上記の手順では、最後の学習可能な層の学習率係数を大きくして、新しい最後の層での学習時間を短縮しています。この学習率設定の組み合わせによって、新しい層では高速に学習が行われ、中間層では学習速度が低下し、凍結された初期の層では学習が行われません。

学習するエポック数を指定します。転移学習の実行時には、同じエポック数の学習を行う必要はありません。エポックとは、学習データセット全体の完全な学習サイクルのことです。ミニバッチのサイズと検証データを指定します。エポックごとに 1 回、検定精度を計算します。

miniBatchSize = 10;
valFrequency = floor(numel(augimdsTrain.Files)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ...
    'MaxEpochs',4, ...
    'InitialLearnRate',3e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',valFrequency, ...
    'Verbose',false, ...
    'Plots','training-progress');

学習データを使用してネットワークに学習させます。既定では、使用可能な GPU がある場合、trainNetwork は GPU を使用します。これには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。そうでない場合、trainNetwork は CPU を使用します。trainingOptions の名前と値のペアの引数 'ExecutionEnvironment' を使用して、実行環境を指定することもできます。このデータセットは小さいため、学習は短時間で終了します。実際にこの例を実行してネットワークに学習させた場合、学習過程に含まれるランダム性のため、異なる結果が得られ誤分類をする可能性があります。

net = trainNetwork(augimdsTrain,lgraph,options);

テスト イメージの分類

微調整したネットワークを使用してテスト イメージを分類し、分類精度を計算します。

augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);
[predictedClasses,predictedScores] = classify(net,augimdsTest);

accuracy = mean(predictedClasses == imdsTest.Labels)
accuracy = 0.8418

テスト セットの混同行列

テスト セットの予測の混同行列をプロットします。これは、ネットワークで最も問題の原因になっているクラスを強調表示します。

figure;
confusionchart(imdsTest.Labels,predictedClasses,'Normalization',"row-normalized");

混同行列は、グリーク サラダ、刺身、ホット ドッグ、寿司といったいくつかのクラスに対するネットワーク性能が低いことを示しています。これらのクラスはデータ セット内で少数しか存在しないため、そのことがネットワーク性能に影響を与えている可能性があります。ネットワーク性能が上がらない原因をより深く理解するため、これらのクラスのいずれかを調査します。

figure();
histogram(imdsValidation.Labels);
ax = gca();
ax.XAxis.TickLabelInterpreter = "none";

分類の調査

ネットワークによる寿司クラスの分類について調査します。

最も寿司らしい寿司

最初に、ネットワークの寿司クラスを最も強く活性化する寿司のイメージを探します。これは、"ネットワークはどのイメージを最も寿司らしいと考えるか" という質問の答えになります。

最大に活性化するイメージ、すなわち、全結合層の "寿司" ニューロンを強く活性化する入力イメージをプロットします。次の Figure は上位 4 つのイメージをクラス スコアの降順に示しています。

chosenClass = "sushi";
classIdx = find(net.Layers(end).Classes == chosenClass);

numImgsToShow = 4;

[sortedScores,imgIdx] = findMaxActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);

figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)

寿司クラスの手がかりの可視化

ネットワークは寿司について正しい部分に着目しているでしょうか。ネットワークの寿司クラスで最大に活性化されたイメージは、相互に類似しており、多くの円形のものが互いに密集しています。

ネットワークはこのような種類の寿司を上手く分類しています。しかし、これが正しいことを検証し、ネットワークの判定理由をより正確に理解するには、Grad-CAM のような可視化手法を使用します。Grad-CAM の使用の詳細については、Grad-CAM での深層学習による判定の理由の解明を参照してください。

拡張イメージ データストアから最初のサイズ変更されたイメージを読み取り、gradCAM を使用して Grad-CAM による可視化をプロットします。

imageNumber = 1;

observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

Grad-CAM マップを見ると、ネットワークがイメージ内の寿司に焦点を当てていることがわかります。ただし、ネットワークが皿やテーブルの一部を見ていることもわかります。

2 番目のイメージでは、寿司の集まりが左側に、単独の寿司が右側にあります。ネットワークが何に焦点を合てるかを確認するには、2 番目のイメージを読み取って Grad-CAM をプロットします。

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

ネットワークは、このイメージを寿司の集まりとして認識するため、寿司として分類します。ただし、単独で 1 つの寿司を分類できるでしょうか。これをテストするには、寿司が 1 つだけある写真を見てみます。

img = imread(strcat(tempdir,"Food Dataset/sushi/sushi_18.jpg"));
img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true);

[label,score] = classify(net,img);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
sgtitle(string(label)+" (score: "+ max(score)+")")

ネットワークは単独の寿司を正確に分類できています。しかし、ネットワークは寿司全体を 1 つの塊として見ているのではなく、寿司の上部やキュウリの集合に焦点を当てていることが GradCAM からわかります。

Grad-CAM の可視化手法を、小さな材料の集まりを含まない単独の寿司に対して実行します。

img = imread("crop__sushi34-copy.jpg");
img = imresize(img,net.Layers(1).InputSize(1:2),"Method","bilinear","AntiAliasing",true);

[label,score] = classify(net,img);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (score: "+ max(score)+")")

この場合、可視化手法によって、ネットワークの性能が劣る理由が明らかになります。寿司のイメージをハンバーガーとして誤って分類しています。

この問題を解決するには、ネットワークの学習過程でさらに多くの単独の寿司のイメージを提供する必要があります。

最も寿司らしくない寿司

今度は、寿司クラスについてネットワークを最も活性化しない寿司のイメージを探します。これは、"ネットワークはどのイメージを最も寿司らしくないと考えるか" という質問の答えになります。

ネットワークの性能が劣るイメージを見つけ、その判定に関する洞察が得られるため、これは有用です。

chosenClass = "sushi";
numImgsToShow = 9;

[sortedScores,imgIdx] = findMinActivatingImages(imdsTest,chosenClass,predictedScores,numImgsToShow);

figure
plotImages(imdsTest,imgIdx,sortedScores,predictedClasses,numImgsToShow)

刺身に誤分類された寿司の調査

なぜネットワークは寿司を刺身として分類するのでしょうか。ネットワークは、9 つのイメージのうち 3 つを刺身として分類しています。それらのイメージのうちのいくつか (イメージ 4 や 9 など) には実際に刺身が含まれています。これは、実はネットワークが誤分類したわけではないことを意味します。これらのイメージには誤ったラベルが付いています。

ネットワークが何に焦点を当てているかを確認するには、これらのイメージの 1 つに対して Grad-CAM 手法を実行します。

imageNumber = 4;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

予想どおり、ネットワークは寿司ではなく刺身に焦点を当てています。

ピザに誤分類された寿司の調査

なぜネットワークは寿司をピザに分類したのでしょうか。ネットワークは、これらのイメージのうち 4 つを寿司ではなくピザとして分類しています。イメージ 1 を見ると、このイメージにはカラフルなトッピングが含まれており、これがネットワークを混乱させている可能性があります。

ネットワークがイメージのどの部分を見ているかを確認するには、これらのイメージの 1 つに対して Grad-CAM 手法を実行します。

imageNumber = 1;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

ネットワークはトッピングに強く焦点を当てています。ネットワークがピザとトッピングのある寿司を区別できるようにするには、トッピングのある寿司の学習イメージをさらに追加します。ネットワークは皿にも焦点を当てています。これは、ネットワークが特定の食品を特定のタイプの皿と関連付けて学習しており、寿司のイメージを見る際も皿にハイライトを当てていることが原因となっている可能性があります。ネットワーク性能を改善するには、さまざまなタイプの皿に乗っている食品のサンプルを追加して学習させます。

ハンバーガーに誤分類された寿司の調査

なぜネットワークは寿司をハンバーガーとして分類するのでしょうか。ネットワークが何に焦点を当てているかを確認するには、誤分類されたイメージに対して Grad-CAM 手法を実行します。

imageNumber = 2;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")")

ネットワークは、イメージ内の花に焦点を当てています。カラフルな紫の花と茶色の茎がネットワークを混乱させ、このイメージをハンバーガーと認識させています。

フライドポテトに誤分類された寿司の調査

なぜネットワークは寿司をフライドポテトとして分類するのでしょうか。ネットワークは、3 番目のイメージを寿司ではなくフライドポテトとして分類しています。この寿司には黄色いトッピングが盛られているため、ネットワークはこの色をフライドポテトと関連付けている可能性があります。

このイメージに対して Grad-CAM を実行します。

imageNumber = 3;
observation = augimdsTest.readByIndex(imgIdx(imageNumber));
img = observation.input{1};

label = predictedClasses(imgIdx(imageNumber));
score = sortedScores(imageNumber);

gradcamMap = gradCAM(net,img,label);

figure
alpha = 0.5;
plotGradCAM(img,gradcamMap,alpha);
title(string(label)+" (sushi score: "+ max(score)+")","Interpreter","none")

ネットワークは黄色い寿司に焦点を当て、それをフライドポテトとして分類しています。ハンバーガーと同様に、通常とは異なる色が、ネットワークに寿司を誤分類させる原因となっています。

この場合にネットワークを改善するには、フライドポテト以外の黄色い食品のイメージを追加して学習させます。

まとめ

クラス スコアの大小を決めるデータ点や、ネットワークが高い信頼度で誤分類をするデータ点を調査することは、学習済みのネットワークがどのように機能しているかに関して有用な洞察が得られるシンプルな手法です。食品のデータセットの場合、この例では次のことが明確になりました。

  • テスト データには、実際は "寿司" であっても "刺身" というように、真のラベルが誤っているイメージが複数含まれる。このデータには、寿司と刺身の両方が含まれるイメージのように、不完全なラベルも含まれている。

  • ネットワークは "寿司" を "複数の集合した円形のもの" と考えている。しかし、単独の寿司も同様に区別できなければならない。

  • トッピングや通常と異なる色のある寿司や刺身は、ネットワークを混乱させる。この問題を解決するには、さらに多様な寿司と刺身がデータになければならない。

  • 性能を改善するには、少数しか存在しないクラスのイメージをさらに数多くネットワークに見せる必要がある。

補助関数

function downloadExampleFoodImagesData(url,dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir,fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath,"file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath,url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir,"pizza");
if ~exist(exampleFolderFullPath,"dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath,dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

function [sortedScores,imgIdx] = findMaxActivatingImages(imds,className,predictedScores,numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);

% Sort the scores in descending order
[sortedScores,idx] = sort(scoresForChosenClass,'descend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [sortedScores,imgIdx] = findMinActivatingImages(imds,className,predictedScores,numImgsToShow)
% Find the predicted scores of the chosen class on all the images of the chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
[scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores);

% Sort the scores in ascending order
[sortedScores,idx] = sort(scoresForChosenClass,'ascend');

% Return the indices of only the first few
imgIdx = imgsOfClassIdxs(idx(1:numImgsToShow));

end

function [scoresForChosenClass,imgsOfClassIdxs] = findScoresForChosenClass(imds,className,predictedScores)
% Find the index of className (e.g. "sushi" is the 9th class)
uniqueClasses = unique(imds.Labels);
chosenClassIdx = string(uniqueClasses) == className;

% Find the indices in imageDatastore that are images of label "className"
% (e.g. find all images of class sushi)
imgsOfClassIdxs = find(imds.Labels == className);

% Find the predicted scores of the chosen class on all the images of the
% chosen class
% (e.g. predicted scores for sushi on all the images of sushi)
scoresForChosenClass = predictedScores(imgsOfClassIdxs,chosenClassIdx);
end

function plotImages(imds,imgIdx,sortedScores,predictedClasses,numImgsToShow)

for i=1:numImgsToShow
    score = sortedScores(i);
    sortedImgIdx = imgIdx(i);
    predClass = predictedClasses(sortedImgIdx); 
    correctClass = imds.Labels(sortedImgIdx);
        
    imgPath = imds.Files{sortedImgIdx};
    
    if predClass == correctClass
        color = "\color{green}";
    else
        color = "\color{red}";
    end
    
    predClassTitle = strrep(string(predClass),'_',' ');
    correctClassTitle = strrep(string(correctClass),'_',' ');
    
    subplot(3,ceil(numImgsToShow./3),i)
    imshow(imread(imgPath));
    title("Predicted: " + color + predClassTitle + "\newline\color{black}Score: " + num2str(score) + "\newlineGround truth: " + correctClassTitle);
end

end

function plotGradCAM(img,gradcamMap,alpha)

subplot(1,2,1)
imshow(img);

h = subplot(1,2,2);
imshow(img)
hold on;
imagesc(gradcamMap,'AlphaData',alpha);

originalSize2 = get(h,'Position');

colormap jet
colorbar

set(h,'Position',originalSize2);
hold off;
end

参考

| | | | | | | |

関連する例

詳細