tsne
を使用したネットワークの動作の表示
この例では、関数 tsne
を使用して学習済みネットワークの活性化を表示する方法を説明します。この表示は、ネットワークの動作を理解するのに役立ちます。
Statistics and Machine Learning Toolbox™ の関数 tsne
(Statistics and Machine Learning Toolbox) は、t 分布型確率的近傍埋め込み法 (t-SNE) [1] を実装します。この手法では、高次元のデータ (層のネットワーク活性化など) が 2 次元にマッピングされます。この手法では距離を保持しようとする非線形のマッピングが使用されます。t-SNE を使用してネットワーク活性化を可視化することにより、ネットワークがどのように応答するかを理解できます。
t-SNE を使用して、入力データがネットワーク層を通過するときに、深層学習ネットワークが入力データの表現をどのように変化させるかを可視化できます。t-SNE を使用すると、入力データの問題を検出したり、ネットワークによって正しく分類されていない観測値を把握することもできます。
たとえば、t-SNE でソフトマックス層の多次元の活性化を減らして、同様の構造をもつ 2 次元表現にすることができます。結果の t-SNE プロットの密なクラスターは、ネットワークによって通常正しく分類されるクラスに対応しています。可視化を使用すると、誤ったクラスターに表示されている点を検出できます。これは、ネットワークによって正しく分類されていない観測値を示します。観測値に誤ったラベルが付けられる場合があり、また、そのクラスの他の観測値に似ていることが原因で、ネットワークによって別のクラスのインスタンスであると予測される場合があります。t-SNE によるソフトマックス活性化の減少ではこれらの活性化のみが使用され、基となる観測値は使用されないことに注意してください。
データ セットのダウンロード
この例では、食品イメージのサンプルというデータ セットを使用します。このデータ セットには 9 クラスの 978 個の食品の写真が含まれており、サイズは約 77 MB です。補助関数 downloadExampleFoodImagesData
を呼び出して、データ セットを一時ディレクトリにダウンロードします。この補助関数のコードは、この例の終わりに示します。
dataDir = fullfile(tempdir,"ExampleFoodImageDataset"); url = "https://www.mathworks.com/supportfiles/nnet/data/ExampleFoodImageDataset.zip"; if ~exist(dataDir,"dir") mkdir(dataDir); end downloadExampleFoodImagesData(url,dataDir);
Downloading MathWorks Example Food Image dataset... This can take several minutes to download... Download finished... Unzipping file... Unzipping finished... Done.
食品イメージを分類するネットワークの学習
イメージ データへのパスが含まれる imageDatastore
を作成します。データの 65% を学習に使用し、残りを検証に使用して、データストアを学習セットと検証セットに分割します。データ セットが非常に小さいため、過適合が深刻な問題になります。過適合を最小限に抑えるため、ランダムな反転とスケーリングで学習セットを拡張します。
imds = imageDatastore(dataDir, ... IncludeSubfolders=true,LabelSource="foldernames"); aug = imageDataAugmenter(RandXReflection=true, ... RandYReflection=true, ... RandXScale=[0.8 1.2], ... RandYScale=[0.8 1.2]); trainingFraction = 0.65; [trainImds,valImds] = splitEachLabel(imds,trainingFraction); augImdsTrain = augmentedImageDatastore([227 227],trainImds, ... DataAugmentation=aug); augImdsVal = augmentedImageDatastore([227 227],valImds); classNames = categories(trainImds.Labels); numClasses = numel(classNames);
事前学習済みの SqueezeNet ネットワークと対応するクラス名を読み込みます。これには、Deep Learning Toolbox™ Model for SqueezeNet Network サポート パッケージが必要です。このサポート パッケージがインストールされていない場合、ソフトウェアによってダウンロード用リンクが表示されます。使用可能なすべてのネットワークについては、事前学習済みの深層ニューラル ネットワークを参照してください。新しいデータの再学習の準備ができているニューラル ネットワークを返すには、クラスの数も指定します。
net = imagePretrainedNetwork("squeezenet",NumClasses=numClasses);
学習オプションを作成します。
opts = trainingOptions("adam", ... InitialLearnRate=1e-4, ... MaxEpochs=50, ... ValidationData=augImdsVal, ... Verbose=false,... Plots="training-progress", ... MiniBatchSize=128,... Metrics="accuracy"); rng default
関数 trainnet
を使用してニューラル ネットワークに学習させます。分類には、クロスエントロピー損失関数を使用します。既定では、関数 trainnet
は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainnet
は CPU を使用します。
net = trainnet(augImdsTrain,net,"crossentropy",opts);
検証データの分類
テスト イメージを分類します。複数の観測値を使用して予測を行うには、関数 minibatchpredict
を使用します。予測スコアをラベルに変換するには、関数 scores2label
を使用します。関数 minibatchpredict
は利用可能な GPU がある場合に自動的にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。
figure();
scores = minibatchpredict(net,augImdsVal);
YPred = scores2label(scores,classNames);
confusionchart(valImds.Labels,YPred,ColumnSummary="column-normalized")
このネットワークはさまざまなイメージを適切に分類しています。このネットワークは寿司のイメージについて問題があると考えられます。多くを寿司として分類していますが、一部をピザやハンバーガーとして分類しているためです。ホットドッグ クラスに分類されているイメージはありません。
複数の層の活性化の計算
引き続きネットワーク性能を解析するため、初期の最大プーリング層、最終畳み込み層、および最終ソフトマックス層でデータ セットのすべての観測値について活性化を計算します。活性化を N 行 M 列の行列として出力します。ここで、N は観測値の数、M は活性化の次元数です。M は空間次元とチャネルの次元の積になります。各行は観測値で、各列は次元です。食品データ セットには 9 個のクラスがあるため、ソフトマックス層では M = 9 になります。行列の各行には 9 個の要素が含まれます。これは観測値が食品の 9 個の各クラスに属する確率に対応します。
earlyLayerName = "pool1"; finalConvLayerName = "conv10"; softmaxLayerName = "prob_flatten"; pool1Activations = minibatchpredict(net,... augImdsVal,Outputs=earlyLayerName); finalConvActivations = minibatchpredict(net,... augImdsVal,Outputs=finalConvLayerName); softmaxActivations = minibatchpredict(net,... augImdsVal,Outputs=softmaxLayerName);
2 次元で活性化行列を形状変更します。t-SNE 関数には 2 次元の入力が必要です。
numValObservations = augImdsVal.numobservations; pool1Activations = reshape(pool1Activations,numValObservations,[]); finalConvActivations = reshape(finalConvActivations,numValObservations,[]); softmaxActivations = reshape(softmaxActivations,numValObservations,[]);
分類のあいまいさ
ソフトマックス活性化を使用して、間違っている可能性が最も高いイメージ分類を計算します。分類の "あいまいさ" を、最も高い確率に対する 2 番目に高い確率の比率として定義します。分類のあいまいさは、0 (ほぼ確実に特定のクラスに分類される) と 1 (最も有力なクラスに分類される確率と 2 番目のクラスに分類される確率がほぼ同じ) の間になります。あいまいさが 1 に近い場合、そのネットワークでは特定のイメージが属するクラスを明確に判定できないことを意味します。このようなあいまいさは、類似した観測値をもつ 2 つのクラスがあり、ネットワークがこれらの違いを学習できないことが原因で発生します。または、特定の観測値に複数のクラスの要素が含まれ、どの分類が正しいかネットワークが判定できないことが原因で、あいまいさが発生する場合もあります。あいまいさが低いことは、必ずしも分類が正しいことを意味しません。ネットワークによってあるクラスの確率が高いと判定された場合でも、分類が間違っていることがあります。
[R,RI] = maxk(softmaxActivations,2,2); ambiguity = R(:,2)./R(:,1);
最もあいまいなイメージを見つけます。
[ambiguity,ambiguityIdx] = sort(ambiguity,"descend");
あいまいなイメージの最も確率の高いクラスと真のクラスを表示します。
classList = unique(valImds.Labels); top10Idx = ambiguityIdx(1:10); top10Ambiguity = ambiguity(1:10); mostLikely = classList(RI(ambiguityIdx,1)); secondLikely = classList(RI(ambiguityIdx,2)); table(top10Idx,top10Ambiguity,mostLikely(1:10),secondLikely(1:10),valImds.Labels(ambiguityIdx(1:10)),... VariableNames=["Image #","Ambiguity","Likeliest","Second","True Class"])
ans=10×5 table
Image # Ambiguity Likeliest Second True Class
_______ _________ ___________ _____________ ____________
180 0.96684 sashimi pizza pizza
26 0.95833 sushi pizza french_fries
291 0.92104 sashimi hamburger sashimi
286 0.91898 pizza hot_dog sashimi
322 0.88626 hot_dog sushi sushi
89 0.87801 hamburger french_fries hamburger
242 0.8727 greek_salad sushi pizza
84 0.85619 greek_salad sushi greek_salad
245 0.85266 greek_salad pizza pizza
145 0.84436 sushi caprese_salad hamburger
ネットワークは、イメージ 27 がハンバーガーまたはピザである可能性が高いと予測しています。しかし実際には、このイメージはフライド ポテトです。イメージを表示して、このような誤分類が発生した理由を確認します。
v = 27; figure(); imshow(valImds.Files{v}); title(sprintf("Observation: %i\n" + ... "Actual: %s. Predicted: %s", v, ... string(valImds.Labels(v)),string(YPred(v))), ... Interpreter="none");
イメージには複数の独立した領域が含まれており、そのいくつかによってネットワークに混乱が生じている可能性があります。
t-SNE を使用したデータの 2 次元表現の計算
初期の最大プーリング層、最終畳み込み層、および最終ソフトマックス層について、ネットワーク データの低次元表現を計算します。関数 tsne
を使用して、活性化データの次元を M から 2 に減らします。活性化の次元が大きくなるほど、t-SNE の計算にかかる時間は長くなります。そのため、活性化の次元が 200,704 の初期の最大プーリング層は、最終ソフトマックス層より計算に時間がかかります。t-SNE の結果の再現性について乱数シードを設定します。
rng default
pool1tsne = tsne(pool1Activations);
finalConvtsne = tsne(finalConvActivations);
softmaxtsne = tsne(softmaxActivations);
初期の層と後の層でのネットワークの動作の比較
t-SNE 手法では距離の保持が試みられるため、高次元表現で近接している点は、低次元表現でも近接しています。混同行列からわかるように、ネットワークは異なるクラスへの分類では効果を発揮します。そのため、シーザー サラダとカプレーゼ サラダなど、意味的に類似している (または同じタイプの) イメージは、ソフトマックス活性化の空間で近接しています。t-SNE ではこの近さが 2 次元表現で取得されるため、9 次元のソフトマックス スコアより理解しやすく、プロットが容易になります。
初期の層は、エッジや色など、低レベルの特徴に対して作用する傾向があります。深い層は、ピザとホットドッグの違いなど、よりセマンティックな意味をもつ高レベルの特徴で学習しています。そのため、初期の層からの活性化はクラスによるクラスタリングを示しません。ピクセル単位で類似する (たとえば、どちらも緑のピクセルを多く含んでいる) 2 つのイメージは、セマンティックなコンテンツに関係なく、活性化の高次元空間で近接しています。後の層からの活性化は、同じクラスの点を一緒にクラスタリングする傾向があります。この動作はソフトマックス層で最も顕著であり、2 次元の t-SNE 表現で維持されます。
関数 gscatter
を使用して、初期の最大プーリング層、最終畳み込み層、および最終ソフトマックス層について t-SNE データをプロットします。初期の最大プーリング層の活性化は、同じクラスのイメージ間のクラスタリングを示していないことがわかります。最終畳み込み層の活性化はクラスによるクラスタリングがある程度行われていますが、ソフトマックス活性化ほどではありません。それぞれの色は各クラスの観測値に対応しています。
doLegend = "off"; markerSize = 7; figure; subplot(1,3,1); gscatter(pool1tsne(:,1),pool1tsne(:,2),valImds.Labels, ... [],'.',markerSize,doLegend); title("Max pooling \newline activations"); subplot(1,3,2); gscatter(finalConvtsne(:,1),finalConvtsne(:,2),valImds.Labels, ... [],'.',markerSize,doLegend); title("Final conv \newline activations"); subplot(1,3,3); gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels, ... [],'.',markerSize,doLegend); title("Softmax \newline activations");
t-SNE プロットでの観測値の確認
各クラスのラベル付けの凡例を含む、ソフトマックス活性化の大きいプロットを作成します。t-SNE プロットから、事後確率分布の構造についてより多くのことがわかります。
たとえば、このプロットにはフライド ポテトの観測値の独立した個別のクラスターが表示されていますが、刺身と寿司のクラスターは適切に解決されていません。このプロットは、混同行列と同様に、フライド ポテトのクラスの予測でネットワークの精度がより高くなることを示しています。
numClasses = length(classList); colors = lines(numClasses); h = figure; gscatter(softmaxtsne(:,1),softmaxtsne(:,2),valImds.Labels,colors); l = legend; l.Interpreter = "none"; l.Location = "bestoutside";
t-SNE を使用して、ネットワークによって誤分類されるイメージとその理由を判定できます。間違った観測値は多くの場合、周囲のクラスターに対する誤った色の孤立点として表示されています。たとえば、誤分類されたハンバーガーのイメージはフライド ポテトの領域に非常に近接しています (オレンジのクラスターの中心に最も近い緑の点)。この点は観測値 99 です。t-SNE プロットでこの観測値を円で囲み、imshow
でイメージを表示します。
obs =99; figure(h) hold on; hs = scatter(softmaxtsne(obs, 1), softmaxtsne(obs, 2), ... 'black',LineWidth=1.5); l.String{end} = "Hamburger"; hold off; figure(); imshow(valImds.Files{obs}); title(sprintf("Observation: %i\n" + ... "Actual: %s. Predicted: %s",obs, ... string(valImds.Labels(obs)),string(YPred(obs))), ... Interpreter="none");
イメージに複数のタイプの食品が含まれている場合、ネットワークに混乱が生じる可能性があります。この場合、前景の食品がハンバーガーであるにもかかわらず、ネットワークはイメージをフライド ポテトとして分類しています。イメージのエッジに表示されているフライド ポテトによって混乱が生じています。
同様に、(この例で前述した) あいまいなイメージ 27 には複数の領域があります。このフライド ポテトのイメージのあいまいな側面が強調表示された t-SNE プロットを確認します。
obs =27; figure(h) hold on; h = scatter(softmaxtsne(obs, 1), softmaxtsne(obs, 2), ... 'k','d',LineWidth=1.5); l.String{end} = "French Fries"; hold off;
このイメージは、適切に定義されたクラスターでプロットに表示されていません。これは、分類が間違っている可能性が高いことを示します。このイメージはフライド ポテトのクラスターからは遠く、ハンバーガーのクラスターの近くにあります。
誤分類される "理由" をその他の情報で示さなければなりません。この情報には通常、イメージのコンテンツに基づく仮説を使用します。その後、その他のデータを使用するか、またはイメージのどの空間領域がネットワークによる分類で重要であるかを示すツールを使用して、仮説をテストします。例については、occlusionSensitivity
とGrad-CAM での深層学習による判定の理由の解明を参照してください。
参考文献
[1] van der Maaten, Laurens, and Geoffrey Hinton. "Visualizing Data using t-SNE." Journal of Machine Learning Research 9, 2008, pp. 2579–2605.
補助関数
function downloadExampleFoodImagesData(url, dataDir) % Download the Example Food Image data set, containing 978 images of % different types of food split into 9 classes. % Copyright 2019 The MathWorks, Inc. fileName = "ExampleFoodImageDataset.zip"; fileFullPath = fullfile(dataDir, fileName); % Download the .zip file into a temporary directory. if ~exist(fileFullPath, "file") fprintf("Downloading MathWorks Example Food Image dataset...\n"); fprintf("This can take several minutes to download...\n"); websave(fileFullPath, url); fprintf("Download finished...\n"); else fprintf("Skipping download, file already exists...\n"); end % Unzip the file. % % Check if the file has already been unzipped by checking for the presence % of one of the class directories. exampleFolderFullPath = fullfile(dataDir, "pizza"); if ~exist(exampleFolderFullPath, "dir") fprintf("Unzipping file...\n"); unzip(fileFullPath, dataDir); fprintf("Unzipping finished...\n"); else fprintf("Skipping unzipping, file already unzipped...\n"); end fprintf("Done.\n"); end
参考
imagePretrainedNetwork
| dlnetwork
| trainingOptions
| trainnet
| predict
| forward
| occlusionSensitivity
| minibatchpredict
| scores2label
| tsne
(Statistics and Machine Learning Toolbox)