Main Content

tsne を使用したネットワークの動作の表示

この例では、関数 tsne を使用して学習済みネットワークの活性化を表示する方法を説明します。この表示は、ネットワークの動作を理解するのに役立ちます。

Statistics and Machine Learning Toolbox™ の関数 tsne (Statistics and Machine Learning Toolbox) は、t 分布型確率的近傍埋め込み法 (t-SNE) [1] を実装します。この手法では、高次元のデータ (層のネットワーク活性化など) が 2 次元にマッピングされます。この手法では距離を保持しようとする非線形のマッピングが使用されます。t-SNE を使用してネットワーク活性化を可視化することにより、ネットワークがどのように応答するかを理解できます。

t-SNE を使用して、入力データがネットワーク層を通過するときに、深層学習ネットワークが入力データの表現をどのように変化させるかを可視化できます。t-SNE を使用すると、入力データの問題を検出したり、ネットワークによって正しく分類されていない観測値を把握することもできます。

たとえば、t-SNE でソフトマックス層の多次元の活性化を減らして、同様の構造をもつ 2 次元表現にすることができます。結果の t-SNE プロットの密なクラスターは、ネットワークによって通常正しく分類されるクラスに対応しています。可視化を使用すると、誤ったクラスターに表示されている点を検出できます。これは、ネットワークによって正しく分類されていない観測値を示します。観測値に誤ったラベルが付けられる場合があり、また、そのクラスの他の観測値に似ていることが原因で、ネットワークによって別のクラスのインスタンスであると予測される場合があります。t-SNE によるソフトマックス活性化の減少ではこれらの活性化のみが使用され、基となる観測値は使用されないことに注意してください。

データセットのダウンロード

この例では、食品イメージのサンプルというデータセットを使用します。このデータセットには 9 クラスの 978 枚の食品の写真が含まれており、サイズは約 77 MB です。補助関数 downloadExampleFoodImagesData を呼び出して、データセットを一時ディレクトリにダウンロードします。この補助関数のコードは、この例の終わりに示します。

dataDir = fullfile(tempdir, "ExampleFoodImageDataset");
url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip";

if ~exist(dataDir, "dir")
    mkdir(dataDir);
end

downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset...
This can take several minutes to download...
Download finished...
Unzipping file...
Unzipping finished...
Done.

食品イメージを分類するネットワークの学習

データセットに含まれる食品のイメージを分類するように、SqueezeNet 事前学習済みネットワークを変更します。ImageNet の 1,000 個のクラスに対する 1,000 個のフィルターを備えた最終畳み込み層を、9 個のフィルターのみを備えた新しい畳み込み層に置き換えます。各フィルターは食品の 1 つのタイプに対応しています。

lgraph = layerGraph(squeezenet());
lgraph = lgraph.replaceLayer("ClassificationLayer_predictions",...
    classificationLayer("Name", "ClassificationLayer_predictions"));

newConv =  convolution2dLayer([14 14], 9, "Name", "conv", "Padding", "same");
lgraph = lgraph.replaceLayer("conv10", newConv);

イメージ データへのパスが含まれる imageDatastore を作成します。データの 65% を学習に使用し、残りを検証に使用して、データストアを学習セットと検証セットに分割します。データセットが非常に小さいため、過適合が深刻な問題になります。過適合を最小限に抑えるため、ランダムな反転とスケーリングで学習セットを拡張します。

imds = imageDatastore(dataDir, ...
    "IncludeSubfolders", true, "LabelSource", "foldernames");

aug = imageDataAugmenter("RandXReflection", true, ...
    "RandYReflection", true, ...
    "RandXScale", [0.8 1.2], ...
    "RandYScale", [0.8 1.2]);

trainingFraction = 0.65;
[trainImds,valImds] = splitEachLabel(imds, trainingFraction);

augImdsTrain = augmentedImageDatastore([227 227], trainImds, ...
    'DataAugmentation', aug);
augImdsVal = augmentedImageDatastore([227 227], valImds);

学習オプションを作成し、ネットワークに学習させます。SqueezeNet は簡単に学習させることができる小さいネットワークです。GPU または CPU で学習させることができますが、この例では CPU で学習させます。

opts = trainingOptions("adam", ...
    "InitialLearnRate", 1e-4, ...
    "MaxEpochs", 30, ...
    "ValidationData", augImdsVal, ...
    "Verbose", false,...
    "Plots", "training-progress", ...
    "ExecutionEnvironment","cpu",...
    "MiniBatchSize",128);
rng default
net = trainNetwork(augImdsTrain, lgraph, opts);

検証データの分類

ネットワークを使用して、検証セットのイメージを分類します。ネットワークが妥当な精度で新しいデータを分類することを確認するため、真のラベルと予測ラベルの混同行列をプロットします。

figure();
YPred = classify(net,augImdsVal);
confusionchart(valImds.Labels,YPred,'ColumnSummary',"column-normalized")

このネットワークはさまざまなイメージを適切に分類しています。このネットワークは寿司のイメージについて問題があると考えられます。多くを寿司として分類していますが、一部をピザやハンバーガーとして分類しているためです。ホットドッグ クラスに分類されているイメージはありません。

複数の層の活性化の計算

引き続きネットワーク性能を解析するため、初期の最大プーリング層、最終畳み込み層、および最終ソフトマックス層でデータセットのすべての観測値について活性化を計算します。活性化を N 行 M 列の行列として出力します。ここで、N は観測値の数、M は活性化の次元数です。M は空間次元とチャネルの次元の積になります。各行は観測値で、各列は次元です。食品データセットには 9 個のクラスがあるため、ソフトマックス層では M = 9 になります。行列の各行には 9 個の要素が含まれます。これは観測値が食品の 9 個の各クラスに属する確率に対応します。

earlyLayerName = "pool1";
finalConvLayerName = "conv";
softmaxLayerName = "prob";
pool1Activations = activations(net,...
    augImdsVal,earlyLayerName,"OutputAs","rows");
finalConvActivations = activations(net,...
    augImdsVal,finalConvLayerName,"OutputAs","rows");
softmaxActivations = activations(net,...
    augImdsVal,softmaxLayerName,"OutputAs","rows");

分類のあいまいさ

ソフトマックス活性化を使用して、間違っている可能性が最も高いイメージ分類を計算します。分類の "あいまいさ" を、最も高い確率に対する 2 番目に高い確率の比率として定義します。分類のあいまいさは、0 (ほぼ確実に特定のクラスに分類される) と 1 (最も有力なクラスに分類される確率と 2 番目のクラスに分類される確率がほぼ同じ) の間になります。あいまいさが 1 に近い場合、そのネットワークでは特定のイメージが属するクラスを明確に判定できないことを意味します。このようなあいまいさは、類似した観測値をもつ 2 つのクラスがあり、ネットワークがこれらの違いを学習できないことが原因で発生します。または、特定の観測値に複数のクラスの要素が含まれ、どの分類が正しいかネットワークが判定できないことが原因で、あいまいさが発生する場合もあります。あいまいさが低いことは、必ずしも分類が正しいことを意味しません。ネットワークによってあるクラスの確率が高いと判定された場合でも、分類が間違っていることがあります。

[R,RI] = maxk(softmaxActivations,2,2);
ambiguity = R(:,2)./R(:,1);

最もあいまいなイメージを見つけます。

[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");

あいまいなイメージの最も確率の高いクラスと真のクラスを表示します。

classList = unique(valImds.Labels);
top10Idx = ambiguityIdx(1:10);
top10Ambiguity = ambiguity(1:10);
mostLikely = classList(RI(ambiguityIdx,1));
secondLikely = classList(RI(ambiguityIdx,2));
table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10),valImds.Labels(ambiguityIdx(1:10)),...
    'VariableNames',["Image #","Ambiguity","Likeliest","Second","True Class"])
ans=10×5 table
    Image #    Ambiguity    Likeliest       Second        True Class 
    _______    _________    _________    ____________    ____________

       94        0.9879     hamburger    pizza           hamburger   
      175       0.96311     hamburger    french_fries    hot_dog     
      179       0.94939     pizza        hamburger       hot_dog     
      337       0.93426     sushi        sashimi         sushi       
      256       0.92972     sushi        pizza           pizza       
      297       0.91776     sushi        sashimi         sashimi     
      283       0.80407     pizza        sushi           pizza       
       27       0.80278     hamburger    pizza           french_fries
      302       0.79283     sashimi      sushi           sushi       
      201       0.76034     pizza        greek_salad     pizza       

ネットワークは、イメージ 27 がハンバーガーまたはピザである可能性が高いと予測しています。しかし実際には、このイメージはフライド ポテトです。イメージを表示して、このような誤分類が発生した理由を確認します。

v = 27;
figure();
imshow(valImds.Files{v});
title(sprintf("Observation: %i\n" + ...
    "Actual: %s. Predicted: %s", v, ...
    string(valImds.Labels(v)), string(YPred(v))), ...
    'Interpreter', 'none');

イメージには複数の独立した領域が含まれており、そのいくつかによってネットワークに混乱が生じている可能性があります。

t-SNE を使用したデータの 2 次元表現の計算

初期の最大プーリング層、最終畳み込み層、および最終ソフトマックス層について、ネットワーク データの低次元表現を計算します。関数 tsne を使用して、活性化データの次元を M から 2 に減らします。活性化の次元が大きくなるほど、t-SNE の計算にかかる時間は長くなります。そのため、活性化の次元が 200,704 の初期の最大プーリング層は、最終ソフトマックス層より計算に時間がかかります。t-SNE の結果の再現性について乱数シードを設定します。

rng default
pool1tsne = tsne(pool1Activations);
finalConvtsne = tsne(finalConvActivations);
softmaxtsne = tsne(softmaxActivations);

初期の層と後の層でのネットワークの動作の比較

t-SNE 手法では距離の保持が試みられるため、高次元表現で近接している点は、低次元表現でも近接しています。混同行列からわかるように、ネットワークは異なるクラスへの分類では効果を発揮します。そのため、シーザー サラダとカプレーゼ サラダなど、意味的に類似している (または同じタイプの) イメージは、ソフトマックス活性化の空間で近接しています。t-SNE ではこの近さが 2 次元表現で取得されるため、9 次元のソフトマックス スコアより理解しやすく、プロットが容易になります。

初期の層は、エッジや色など、低レベルの特徴に対して作用する傾向があります。深い層は、ピザとホットドッグの違いなど、よりセマンティックな意味をもつ高レベルの特徴で学習しています。そのため、初期の層からの活性化はクラスによるクラスタリングを示しません。ピクセル単位で類似する (たとえば、どちらも緑のピクセルを多く含んでいる) 2 つのイメージは、セマンティックなコンテンツに関係なく、活性化の高次元空間で近接しています。後の層からの活性化は、同じクラスの点を一緒にクラスタリングする傾向があります。この動作はソフトマックス層で最も顕著であり、2 次元の t-SNE 表現で維持されます。

関数 gscatter を使用して、初期の最大プーリング層、最終畳み込み層、および最終ソフトマックス層について t-SNE データをプロットします。初期の最大プーリング層の活性化は、同じクラスのイメージ間のクラスタリングを示していないことがわかります。最終畳み込み層の活性化はクラスによるクラスタリングがある程度行われていますが、ソフトマックス活性化ほどではありません。それぞれの色は各クラスの観測値に対応しています。

doLegend = 'off';
markerSize = 7;
figure;

subplot(1,3,1);
gscatter(pool1tsne(:,1),pool1tsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Max pooling activations");

subplot(1,3,2);
gscatter(finalConvtsne(:,1),finalConvtsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Final conv activations");

subplot(1,3,3);
gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels, ...
    [],'.',markerSize,doLegend);
title("Softmax activations");

t-SNE プロットでの観測値の確認

各クラスのラベル付けの凡例を含む、ソフトマックス活性化の大きいプロットを作成します。t-SNE プロットから、事後確率分布の構造についてより多くのことがわかります。

たとえば、このプロットにはフライド ポテトの観測値の独立した個別のクラスターが表示されていますが、刺身と寿司のクラスターは適切に解決されていません。このプロットは、混同行列と同様に、フライド ポテトのクラスの予測でネットワークの精度がより高くなることを示しています。

numClasses = length(classList);
colors = lines(numClasses);
h = figure;
gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels,colors);

l = legend;
l.Interpreter = "none";
l.Location = "bestoutside";

t-SNE を使用して、ネットワークによって誤分類されるイメージとその理由を判定できます。間違った観測値は多くの場合、周囲のクラスターに対する誤った色の孤立点として表示されています。たとえば、誤分類されたハンバーガーのイメージはフライド ポテトの領域に非常に近接しています (オレンジのクラスターの中心に最も近い緑の点)。この点は観測値 99 です。t-SNE プロットでこの観測値を円で囲み、imshow でイメージを表示します。

obs = 99;
figure(h)
hold on;
hs = scatter(softmaxtsne(obs, 1), softmaxtsne(obs, 2), ...
    'black','LineWidth',1.5);
l.String{end} = 'Hamburger';
hold off;
figure();
imshow(valImds.Files{obs});
title(sprintf("Observation: %i\n" + ...
    "Actual: %s. Predicted: %s", obs, ...
    string(valImds.Labels(obs)), string(YPred(obs))), ...
    'Interpreter', 'none');

イメージに複数のタイプの食品が含まれている場合、ネットワークに混乱が生じる可能性があります。この場合、前景の食品がハンバーガーであるにもかかわらず、ネットワークはイメージをフライド ポテトとして分類しています。イメージのエッジに表示されているフライド ポテトによって混乱が生じています。

同様に、(この例で前述した) あいまいなイメージ 27 には複数の領域があります。このフライド ポテトのイメージのあいまいな側面が強調表示された t-SNE プロットを確認します。

obs = 27;
figure(h)
hold on;
h = scatter(softmaxtsne(obs, 1), softmaxtsne(obs, 2), ...
    'k','d','LineWidth',1.5);
l.String{end} = 'French Fries';
hold off;

このイメージは、適切に定義されたクラスターでプロットに表示されていません。これは、分類が間違っている可能性が高いことを示します。このイメージはフライド ポテトのクラスターからは遠く、ハンバーガーのクラスターの近くにあります。

誤分類される "理由" をその他の情報で示さなければなりません。この情報には通常、イメージのコンテンツに基づく仮説を使用します。その後、その他のデータを使用するか、またはイメージのどの空間領域がネットワークによる分類で重要であるかを示すツールを使用して、仮説をテストします。例については、occlusionSensitivityGrad-CAM での深層学習による判定の理由の解明を参照してください。

参考文献

[1] van der Maaten, Laurens, and Geoffrey Hinton. "Visualizing Data using t-SNE." Journal of Machine Learning Research 9, 2008, pp. 2579–2605.

補助関数

function downloadExampleFoodImagesData(url, dataDir)
% Download the Example Food Image data set, containing 978 images of
% different types of food split into 9 classes.

% Copyright 2019 The MathWorks, Inc.

fileName = "ExampleFoodImageDataset.zip";
fileFullPath = fullfile(dataDir, fileName);

% Download the .zip file into a temporary directory.
if ~exist(fileFullPath, "file")
    fprintf("Downloading MathWorks Example Food Image dataset...\n");
    fprintf("This can take several minutes to download...\n");
    websave(fileFullPath, url);
    fprintf("Download finished...\n");
else
    fprintf("Skipping download, file already exists...\n");
end

% Unzip the file.
%
% Check if the file has already been unzipped by checking for the presence
% of one of the class directories.
exampleFolderFullPath = fullfile(dataDir, "pizza");
if ~exist(exampleFolderFullPath, "dir")
    fprintf("Unzipping file...\n");
    unzip(fileFullPath, dataDir);
    fprintf("Unzipping finished...\n");
else
    fprintf("Skipping unzipping, file already unzipped...\n");
end
fprintf("Done.\n");

end

参考

| | | | | | | (Statistics and Machine Learning Toolbox)

関連するトピック