Main Content

次元削減のためのシャム ネットワークの学習

この例では、次元削減を使用して手書きの数字を比較するシャム ネットワークの学習方法を説明します。

シャム ネットワークは深層学習ネットワークの一種で、同じアーキテクチャをもち、同じパラメーターと重みを共有する、2 つ以上の同一のサブネットワークを使用します。シャム ネットワークは通常、比較可能な 2 つの事例の関係性を見つけることに関わるタスクに使用されます。シャム ネットワークの一般的な使用例には、顔認識、シグネチャ検証 [1]、またはパラフレーズ識別 [2] などがあります。学習中に重みを共有することで学習させるパラメーターが少なくなり、比較的少量の学習データで良い結果が得られるため、シャム ネットワークはこれらのタスクにおいて良好に動作します。

シャム ネットワークは特に、クラスの数が多く、各クラスの観測値が少ない場合に役立ちます。このような場合、イメージをこれらのクラスに分類するよう深層畳み込みニューラル ネットワークに学習させるだけの十分なデータがありません。代わりに、シャム ネットワークによって 2 つのイメージが同じクラスに属するかどうかを判定することができます。ネットワークは、学習データの次元を削減し、距離ベースのコスト関数を使用してクラス間を区別することでこれを行います。

この例ではシャム ネットワークを使用して、手書きの数字のイメージ コレクションについて次元削減を行います。シャム アーキテクチャは、同じクラスのイメージを低次元空間の近傍点にマッピングすることで次元削減を行います。その後、低次元特徴の表現を使用して、テスト イメージに最も類似したイメージをデータセットから抽出します。この例の学習データは、サイズ 28 x 28 x 1 のイメージで、初期の特徴次元は 784 です。シャム ネットワークは、入力イメージの次元を 2 つの特徴に削減して、同一ラベルのイメージに対して類似の低次元特徴を出力するように学習を行います。

シャム ネットワークを使用すると、類似のイメージを直接比較して識別することもできます。例については、シャム ネットワークの学習とイメージの比較を参照してください。

学習データの読み込みと前処理

手書きの数字のイメージで構成される学習データを読み込みます。関数 digitTrain4DArrayData は、数字のイメージとそのラベルを読み込みます。

[XTrain,TTrain] = digitTrain4DArrayData;

XTrain は、サイズが 28 x 28 のシングル チャネル イメージ 5000 個を含む、28 x 28 x 1 x 5000 の配列です。各ピクセルは 0 から 1 の間の値です。TTrain は、各観測値のラベルが含まれる categorical ベクトルで、ラベルは手書きの数字に対応する 0 から 9 の数字です。

ランダムに選択したイメージを表示します。

perm = randperm(numel(TTrain),9);
imshow(imtile(XTrain(:,:,:,perm),ThumbnailSize=[100 100]));

類似イメージと非類似イメージのペアの作成

ネットワークに学習させるには、データを類似イメージまたは非類似イメージのペアにグループ化しなければなりません。ここで、類似イメージは同じラベルをもつイメージ、非類似イメージは異なるラベルをもつイメージとして定義されます。関数 getSiameseBatch (この例のサポート関数の節で定義) は、類似イメージまたは非類似イメージのランダム化されたペアである pairImage1 pairImage2 を作成します。また、この関数は、イメージのペアが互いに類似か非類似かを識別するラベル pairLabel を返します。イメージの類似ペアの場合は pairLabel = 1、非類似ペアの場合は pairLabel = 0 になります。

一例として、5 つのイメージのペアをもつ、典型的な小さいセットを作成します。

batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getSiameseBatch(XTrain,TTrain,batchSize);

生成されたイメージのペアを表示します。

figure
tiledlayout("flow")
for i = 1:batchSize
    nexttile
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    title(s)
end

この例では、学習ループの反復ごとに、イメージのペア 180 個から成る新しいバッチが作成されます。これにより、類似ペアと非類似ペアの比率がほぼ等しい大量のランダムなイメージのペアでネットワークに学習させることができます。

ネットワーク アーキテクチャの定義

シャム ネットワークのアーキテクチャを次の図に示します。

この例では、2 つの同一のサブネットワークが ReLU 層をもつ一連の全結合層として定義されます。28 x 28 x 1 のイメージを受け取り、低次元特徴表現として使用される 2 つの特徴ベクトルを出力するネットワークを作成します。ネットワークは、入力イメージの次元を 2 に削減します。これにより、初期の 784 の次元よりもプロットと可視化を簡単に行えるようになります。

最初の 2 つの全結合層では、出力サイズに 1024 を指定し、He の重みの初期化子を使用します。

最終の全結合層では、出力サイズに 2 を指定し、He の重みの初期化子を使用します。

layers = [
    imageInputLayer([28 28],Normalization="none")
    fullyConnectedLayer(1024,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(1024,WeightsInitializer="he")
    reluLayer
    fullyConnectedLayer(2,WeightsInitializer="he")];

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層配列を dlnetwork オブジェクトに変換します。

net = dlnetwork(layers);

モデル損失関数の定義

関数 modelLoss (この例のサポート関数の節で定義) を作成します。関数 modelLoss は、シャム dlnetwork オブジェクト net と、ミニバッチ入力データ X1 および X2 とそのラベル pairLabels を受け取ります。関数は損失値と、ネットワークの学習可能なパラメーターについての損失の勾配を返します。

シャム ネットワークの目的は、各イメージについて、類似イメージの場合は類似し、非類似イメージの場合は明確に異なるような特徴ベクトルを出力することです。このようにして、ネットワークは 2 つの入力を区別することができます。

最後の全結合層の出力と、pairImage1 および pairImage2 からの特徴ベクトル features1 および features1 との間の対比損失をそれぞれ求めます。ペアの対比損失は [3] で与えられます。

loss=12yd2+12(1-y)max(margin-d,0)2,

ここで、y はペア ラベルの値 (類似イメージの場合は y=1、非類似イメージの場合は y=0) で、d は 2 つの特徴ベクトル f1f2 のユークリッド距離 d=f1-f22 です。

margin パラメーターは制約のために使用されます。ペア内の 2 つのイメージが非類似の場合は、両者の距離が少なくとも margin でなければならず、そうでないと損失が発生します。

対比損失には項が 2 つありますが、与えられたイメージ ペアについて非ゼロになり得るのは、どちらか 1 つだけです。類似イメージの場合は、第 1 項が非ゼロになることができ、イメージの特徴 f1f2 の間の距離を減らすことで最小化されます。非類似イメージの場合は、第 2 項が非ゼロになることができ、イメージの特徴間の距離を増やすことで、少なくとも margin の距離まで最小化されます。margin の値が小さいほど、損失が発生する前に非類似ペアがどれだけ近くなり得るかの制約が緩くなります。

学習オプションの指定

学習中に使用する margin の値を指定します。

margin = 0.3;

学習中に使用するオプションを指定します。3000 回反復して学習させます。

numIterations = 3000;
miniBatchSize = 180;

Adam 最適化のオプションを指定します。

  • 学習率を 0.0001 に設定します。

  • 最後の平均勾配と最後の平均 2 乗勾配減衰率を [] に初期化します。

  • 勾配の減衰係数を 0.9 に、2 乗勾配の減衰係数を 0.99 に設定。

learningRate = 1e-4;
trailingAvg = [];
trailingAvgSq = [];
gradDecay = 0.9;
gradDecaySq = 0.99;

学習損失の進行状況をプロットするためのプロット パラメーターを初期化します。

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

モデルの学習

カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。

それぞれの反復で次を行います。

  • イメージ ペアのバッチの作成の節で定義されている関数 getSiameseBatch を使用して、イメージ ペアとラベルのバッチを抽出。

  • 基となる型が singledlarray オブジェクトにイメージ データを変換し、次元ラベルを "SSCB" (spatial、spatial、channel、batch) に指定。

  • GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

start = tic;

% Loop over mini-batches.
for iteration = 1:numIterations

    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabels] = getSiameseBatch(XTrain,TTrain,miniBatchSize);

    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % "SSCB" (spatial, spatial, channel, batch) for image data
    X1 = dlarray(single(X1),"SSCB");
    X2 = dlarray(single(X2),"SSCB");

    % If training on a GPU, then convert data to gpuArray.
    if canUseGPU
        X1 = gpuArray(X1);
        X2 = gpuArray(X2);
    end

    % Evaluate the model loss and gradients using dlfeval and the modelLoss
    % function listed at the end of the example.
    [loss,gradients] = dlfeval(@modelLoss,net,X1,X2,pairLabels,margin);

    % Update the Siamese network parameters.
    [net.Learnables,trailingAvg,trailingAvgSq] = ...
        adamupdate(net.Learnables,gradients, ...
        trailingAvg,trailingAvgSq,iteration,learningRate,gradDecay,gradDecaySq);

    % Update the training loss progress plot.
    D = duration(0,0,toc(start),Format="hh:mm:ss");
    loss = double(loss);
    addpoints(lineLossTrain,iteration,loss)
    title("Elapsed: " + string(D))

    drawnow
end

イメージ類似度の可視化

次元削減におけるネットワークの性能を評価するために、テスト データのセットについて低次元特徴を計算してプロットします。学習データと類似した手書きの数字のイメージで構成されるテスト データを読み込みます。テスト データを dlarray に変換し、次元ラベルを "SSCB" (spatial、spatial、channel、batch) に指定します。GPU を使用している場合は、テスト データを gpuArray に変換します。

[XTest,TTest] = digitTest4DArrayData;
XTest = dlarray(single(XTest),"SSCB");

if canUseGPU
    XTest = gpuArray(XTest);
end

テスト データの低次元特徴を計算します。

FTest = predict(net,XTest);

各グループについて、テスト データの最初の 2 つの低次元特徴をプロットします。

uniqueGroups = unique(TTest);
colors = hsv(length(uniqueGroups));

figure
hold on
for k = 1:length(uniqueGroups)
    ind = TTest==uniqueGroups(k);

    plot(FTest(1,ind),FTest(2,ind),".",Color=colors(k,:));
end
hold off

xlabel("Feature 1")
ylabel("Feature 2")
title("2-D Feature Representation of Digits Images.");

legend(uniqueGroups,Location="eastoutside");

学習済みネットワークを使用した類似イメージの検出

学習済みネットワークを使用して、互いに似ているイメージの集合をグループから検出することが可能です。テスト データからテスト イメージを 1 つ抽出して表示します。

testIdx = randi(5000);
testImg = XTest(:,:,:,testIdx);

trialImgDisp = extractdata(testImg);

figure
imshow(trialImgDisp,InitialMagnification=500);

テスト データを含み、抽出されたテスト イメージは含まないイメージのグループを作成します。

groupX = XTest;
groupX(:,:,:,testIdx) = [];

predict を使用してテスト イメージの低次元特徴を求めます。

trialF = predict(net,testImg);

学習済みネットワークを使用して、グループ内の各イメージについて 2 次元の低次元特徴表現を求めます。

FGroupX = predict(net,groupX);

低次元特徴表現を使用して、ユークリッド距離計量を用いてテスト イメージに最も近いイメージをグループ内から 9 個見つけます。イメージを表示します。

distances = vecnorm(extractdata(trialF - FGroupX));
[~,idx] = sort(distances);
sortedImages = groupX(:,:,:,idx);
sortedImages = extractdata(sortedImages);

figure
imshow(imtile(sortedImages(:,:,:,1:9)),InitialMagnification=500);

イメージの次元を削減することで、ネットワークは、テスト イメージに類似したイメージを識別できます。低次元特徴表現によって、ネットワークによる類似イメージと非類似イメージの識別が可能になります。シャム ネットワークは顔認識やシグネチャ認識のコンテキストでよく使用されます。たとえば、顔のイメージを入力として受け取り、データベースからよく似ている顔のセットを返すように、シャム ネットワークに学習させることができます。

サポート関数

モデル損失関数

関数 modelLoss は、シャム dlnetwork のオブジェクト net、ミニバッチ入力データ X1X2 のペア、およびラベル pairLabels を受け取ります。この関数は、ペアになっているイメージの低次元特徴間の対比損失と、ネットワーク内の学習可能パラメーターについての損失の勾配を返します。この例では、関数 modelLoss についてモデル損失関数の定義の節で紹介されています。

function [loss,gradients] = modelLoss(net,X1,X2,pairLabel,margin)
% The modelLoss function calculates the contrastive loss between the
% paired images and returns the loss and the gradients of the loss with
% respect to the network learnable parameters

% Pass first half of image pairs forward through the network
F1 = forward(net,X1);
% Pass second set of image pairs forward through the network
F2 = forward(net,X2);

% Calculate contrastive loss
loss = contrastiveLoss(F1,F2,pairLabel,margin);

% Calculate gradients of the loss with respect to the network learnable
% parameters
gradients = dlgradient(loss,net.Learnables);

end

function loss = contrastiveLoss(F1,F2,pairLabel,margin)
% The contrastiveLoss function calculates the contrastive loss between
% the reduced features of the paired images

% Define small value to prevent taking square root of 0
delta = 1e-6;

% Find Euclidean distance metric
distances = sqrt(sum((F1 - F2).^2,1) + delta);

% label(i) = 1 if features1(:,i) and features2(:,i) are features
% for similar images, and 0 otherwise
lossSimilar = pairLabel.*(distances.^2);

lossDissimilar = (1 - pairLabel).*(max(margin - distances, 0).^2);

loss = 0.5*sum(lossSimilar + lossDissimilar,"all");

end

イメージ ペアのバッチの作成

次の関数は、ラベルに基づいて類似または非類似のイメージのランダム化されたペアを作成します。この例では、関数 getSiameseBatch について類似イメージと非類似イメージのペアの作成の節で紹介されています。

function [X1,X2,pairLabels] = getSiameseBatch(X,Y,miniBatchSize)
% getSiameseBatch returns a randomly selected batch of paired images.
% On average, this function produces a balanced set of similar and
% dissimilar pairs.
pairLabels = zeros(1, miniBatchSize);
imgSize = size(X(:,:,:,1));
X1 = zeros([imgSize 1 miniBatchSize]);
X2 = zeros([imgSize 1 miniBatchSize]);

for i = 1:miniBatchSize
    choice = rand(1);
    if choice < 0.5
        [pairIdx1, pairIdx2, pairLabels(i)] = getSimilarPair(Y);
    else
        [pairIdx1, pairIdx2, pairLabels(i)] = getDissimilarPair(Y);
    end
    X1(:,:,:,i) = X(:,:,:,pairIdx1);
    X2(:,:,:,i) = X(:,:,:,pairIdx2);
end

end

function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)
% getSimilarPair returns a random pair of indices for images
% that are in the same class and the similar pair label = 1.

% Find all unique classes.
classes = unique(classLabel);

% Choose a class randomly which will be used to get a similar pair.
classChoice = randi(numel(classes));

% Find the indices of all the observations from the chosen class.
idxs = find(classLabel==classes(classChoice));

% Randomly choose two different images from the chosen class.
pairIdxChoice = randperm(numel(idxs),2);
pairIdx1 = idxs(pairIdxChoice(1));
pairIdx2 = idxs(pairIdxChoice(2));
pairLabel = 1;
end

function  [pairIdx1,pairIdx2,pairLabel] = getDissimilarPair(classLabel)
% getDissimilarPair returns a random pair of indices for images
% that are in different classes and the dissimilar pair label = 0.

% Find all unique classes.
classes = unique(classLabel);

% Choose two different classes randomly which will be used to get a dissimilar pair.
classesChoice = randperm(numel(classes), 2);

% Find the indices of all the observations from the first and second classes.
idxs1 = find(classLabel==classes(classesChoice(1)));
idxs2 = find(classLabel==classes(classesChoice(2)));

% Randomly choose one image from each class.
pairIdx1Choice = randi(numel(idxs1));
pairIdx2Choice = randi(numel(idxs2));
pairIdx1 = idxs1(pairIdx1Choice);
pairIdx2 = idxs2(pairIdx2Choice);
pairLabel = 0;
end

参考文献

  1. Bromley, J., I. Guyon, Y. LeCun, E. Säckinger, and R. Shah. "Signature Verification using a "Siamese" Time Delay Neural Network." In Proceedings of the 6th International Conference on Neural Information Processing Systems (NIPS 1993), 1994, pp737-744. Available at Signature Verification using a "Siamese" Time Delay Neural Network on the NIPS Proceedings website.

  2. Wenpeg, Y., and H Schütze. "Convolutional Neural Network for Paraphrase Identification." In Proceedings of 2015 Conference of the North American Cahapter of the ACL, 2015, pp901-911.Available at Convolutional Neural Network for Paraphrase Identification on the ACL Anthology website.

  3. Hadsell, R., S. Chopra, and Y. LeCun. "Dimensionality Reduction by Learning an Invariant Mapping." In Proceedings of the 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR 2006), 2006, pp1735-1742.

参考

| | | |

関連するトピック