次元削減のためのシャム ネットワークの学習
この例では、次元削減を使用して手書きの数字を比較するシャム ネットワークの学習方法を説明します。
シャム ネットワークは深層学習ネットワークの一種で、同じアーキテクチャをもち、同じパラメーターと重みを共有する、2 つ以上の同一のサブネットワークを使用します。シャム ネットワークは通常、比較可能な 2 つの事例の関係性を見つけることに関わるタスクに使用されます。シャム ネットワークの一般的な使用例には、顔認識、シグネチャ検証 [1]、またはパラフレーズ識別 [2] などがあります。学習中に重みを共有することで学習させるパラメーターが少なくなり、比較的少量の学習データで良い結果が得られるため、シャム ネットワークはこれらのタスクにおいて良好に動作します。
シャム ネットワークは特に、クラスの数が多く、各クラスの観測値が少ない場合に役立ちます。このような場合、イメージをこれらのクラスに分類するよう深層畳み込みニューラル ネットワークに学習させるだけの十分なデータがありません。代わりに、シャム ネットワークによって 2 つのイメージが同じクラスに属するかどうかを判定することができます。ネットワークは、学習データの次元を削減し、距離ベースのコスト関数を使用してクラス間を区別することでこれを行います。
この例ではシャム ネットワークを使用して、手書きの数字のイメージ コレクションについて次元削減を行います。シャム アーキテクチャは、同じクラスのイメージを低次元空間の近傍点にマッピングすることで次元削減を行います。その後、低次元特徴の表現を使用して、テスト イメージに最も類似したイメージをデータセットから抽出します。この例の学習データは、サイズ 28 x 28 x 1 のイメージで、初期の特徴次元は 784 です。シャム ネットワークは、入力イメージの次元を 2 つの特徴に削減して、同一ラベルのイメージに対して類似の低次元特徴を出力するように学習を行います。
シャム ネットワークを使用すると、類似のイメージを直接比較して識別することもできます。例については、シャム ネットワークの学習とイメージの比較を参照してください。
学習データの読み込みと前処理
手書きの数字のイメージで構成される学習データを読み込みます。関数 digitTrain4DArrayData
は、数字のイメージとそのラベルを読み込みます。
[XTrain,TTrain] = digitTrain4DArrayData;
XTrain
は、サイズが 28 x 28 のシングル チャネル イメージ 5000 個を含む、28 x 28 x 1 x 5000 の配列です。各ピクセルは 0
から 1
の間の値です。TTrain
は、各観測値のラベルが含まれる categorical ベクトルで、ラベルは手書きの数字に対応する 0 から 9 の数字です。
ランダムに選択したイメージを表示します。
perm = randperm(numel(TTrain),9); imshow(imtile(XTrain(:,:,:,perm),ThumbnailSize=[100 100]));
類似イメージと非類似イメージのペアの作成
ネットワークに学習させるには、データを類似イメージまたは非類似イメージのペアにグループ化しなければなりません。ここで、類似イメージは同じラベルをもつイメージ、非類似イメージは異なるラベルをもつイメージとして定義されます。関数 getSiameseBatch
(この例のサポート関数の節で定義) は、類似イメージまたは非類似イメージのランダム化されたペアである pairImage1
と pairImage2
を作成します。また、この関数は、イメージのペアが互いに類似か非類似かを識別するラベル pairLabel
を返します。イメージの類似ペアの場合は pairLabel = 1
、非類似ペアの場合は pairLabel = 0
になります。
一例として、5 つのイメージのペアをもつ、典型的な小さいセットを作成します。
batchSize = 10; [pairImage1,pairImage2,pairLabel] = getSiameseBatch(XTrain,TTrain,batchSize);
生成されたイメージのペアを表示します。
figure tiledlayout("flow") for i = 1:batchSize nexttile imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]); if pairLabel(i) == 1 s = "similar"; else s = "dissimilar"; end title(s) end
この例では、学習ループの反復ごとに、イメージのペア 180 個から成る新しいバッチが作成されます。これにより、類似ペアと非類似ペアの比率がほぼ等しい大量のランダムなイメージのペアでネットワークに学習させることができます。
ネットワーク アーキテクチャの定義
シャム ネットワークのアーキテクチャを次の図に示します。
この例では、2 つの同一のサブネットワークが ReLU 層をもつ一連の全結合層として定義されます。28 x 28 x 1 のイメージを受け取り、低次元特徴表現として使用される 2 つの特徴ベクトルを出力するネットワークを作成します。ネットワークは、入力イメージの次元を 2 に削減します。これにより、初期の 784 の次元よりもプロットと可視化を簡単に行えるようになります。
最初の 2 つの全結合層では、出力サイズに 1024 を指定し、He の重みの初期化子を使用します。
最終の全結合層では、出力サイズに 2 を指定し、He の重みの初期化子を使用します。
layers = [ imageInputLayer([28 28],Normalization="none") fullyConnectedLayer(1024,WeightsInitializer="he") reluLayer fullyConnectedLayer(1024,WeightsInitializer="he") reluLayer fullyConnectedLayer(2,WeightsInitializer="he")];
カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層配列を dlnetwork
オブジェクトに変換します。
net = dlnetwork(layers);
モデル損失関数の定義
関数 modelLoss
(この例のサポート関数の節で定義) を作成します。関数 modelLoss
は、シャム dlnetwork
オブジェクト net
、入力データのミニバッチ X1
と X2
およびそのラベル pairLabels
を受け取ります。関数は損失値と、ネットワークの学習可能なパラメーターについての損失の勾配を返します。
シャム ネットワークの目的は、各イメージについて、類似イメージの場合は類似し、非類似イメージの場合は明確に異なるような特徴ベクトルを出力することです。このようにして、ネットワークは 2 つの入力を区別することができます。
最後の全結合層の出力と、pairImage1
および pairImage2
からの特徴ベクトル features1
および features1
との間の対比損失をそれぞれ求めます。ペアの対比損失は [3] で与えられます。
ここで、 はペア ラベルの値 (類似イメージの場合は 、非類似イメージの場合は ) で、 は 2 つの特徴ベクトル と のユークリッド距離 です。
パラメーターは制約のために使用されます。ペア内の 2 つのイメージが非類似の場合は、両者の距離が少なくとも でなければならず、そうでないと損失が発生します。
対比損失には項が 2 つありますが、与えられたイメージ ペアについて非ゼロになり得るのは、どちらか 1 つだけです。類似イメージの場合は、第 1 項が非ゼロになることができ、イメージの特徴 と の間の距離を減らすことで最小化されます。非類似イメージの場合は、第 2 項が非ゼロになることができ、イメージの特徴間の距離を増やすことで、少なくとも の距離まで最小化されます。 の値が小さいほど、損失が発生する前に非類似ペアがどれだけ近くなり得るかの制約が緩くなります。
学習オプションの指定
学習中に使用する の値を指定します。
margin = 0.3;
学習中に使用するオプションを指定します。3000 回反復して学習させます。
numIterations = 3000; miniBatchSize = 180;
Adam 最適化のオプションを指定します。
学習率を
0.0001
に設定します。最後の平均勾配と最後の平均 2 乗勾配減衰率を
[]
に初期化します。勾配の減衰係数を
0.9
に、2 乗勾配の減衰係数を0.99
に設定。
learningRate = 1e-4; trailingAvg = []; trailingAvgSq = []; gradDecay = 0.9; gradDecaySq = 0.99;
学習損失の進行状況をプロットするためのプロット パラメーターを初期化します。
figure C = colororder; lineLossTrain = animatedline(Color=C(2,:)); ylim([0 inf]) xlabel("Iteration") ylabel("Loss") grid on
モデルの学習
カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。
それぞれの反復で次を行います。
イメージ ペアのバッチの作成の節で定義されている関数
getSiameseBatch
を使用して、イメージ ペアとラベルのバッチを抽出。基となる型が
single
のdlarray
オブジェクトにイメージ データを変換し、次元ラベルを"SSCB"
(spatial、spatial、channel、batch) に指定。GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
関数
dlfeval
およびmodelLoss
を使用してモデルの損失と勾配を評価。関数
adamupdate
を使用してネットワーク パラメーターを更新。
start = tic; % Loop over mini-batches. for iteration = 1:numIterations % Extract mini-batch of image pairs and pair labels [X1,X2,pairLabels] = getSiameseBatch(XTrain,TTrain,miniBatchSize); % Convert mini-batch of data to dlarray. Specify the dimension labels % "SSCB" (spatial, spatial, channel, batch) for image data X1 = dlarray(single(X1),"SSCB"); X2 = dlarray(single(X2),"SSCB"); % If training on a GPU, then convert data to gpuArray. if canUseGPU X1 = gpuArray(X1); X2 = gpuArray(X2); end % Evaluate the model loss and gradients using dlfeval and the modelLoss % function listed at the end of the example. [loss,gradients] = dlfeval(@modelLoss,net,X1,X2,pairLabels,margin); % Update the Siamese network parameters. [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ... trailingAvg,trailingAvgSq,iteration,learningRate,gradDecay,gradDecaySq); % Update the training loss progress plot. D = duration(0,0,toc(start),Format="hh:mm:ss"); loss = double(loss); addpoints(lineLossTrain,iteration,loss) title("Elapsed: " + string(D)) drawnow end
イメージ類似度の可視化
次元削減におけるネットワークの性能を評価するために、テスト データのセットについて低次元特徴を計算してプロットします。学習データと類似した手書きの数字のイメージで構成されるテスト データを読み込みます。テスト データを dlarray
に変換し、次元ラベルを "SSCB"
(spatial、spatial、channel、batch) に指定します。GPU を使用している場合は、テスト データを gpuArray
に変換します。
[XTest,TTest] = digitTest4DArrayData; XTest = dlarray(single(XTest),"SSCB"); if canUseGPU XTest = gpuArray(XTest); end
テスト データの低次元特徴を計算します。
FTest = predict(net,XTest);
各グループについて、テスト データの最初の 2 つの低次元特徴をプロットします。
uniqueGroups = unique(TTest); colors = hsv(length(uniqueGroups)); figure hold on for k = 1:length(uniqueGroups) ind = TTest==uniqueGroups(k); plot(FTest(1,ind),FTest(2,ind),".",Color=colors(k,:)); end hold off xlabel("Feature 1") ylabel("Feature 2") title("2-D Feature Representation of Digits Images."); legend(uniqueGroups,Location="eastoutside");
学習済みネットワークを使用した類似イメージの検出
学習済みネットワークを使用して、互いに似ているイメージの集合をグループから検出することが可能です。テスト データからテスト イメージを 1 つ抽出して表示します。
testIdx = randi(5000); testImg = XTest(:,:,:,testIdx); trialImgDisp = extractdata(testImg); figure imshow(trialImgDisp,InitialMagnification=500);
テスト データを含み、抽出されたテスト イメージは含まないイメージのグループを作成します。
groupX = XTest; groupX(:,:,:,testIdx) = [];
predict
を使用してテスト イメージの低次元特徴を求めます。
trialF = predict(net,testImg);
学習済みネットワークを使用して、グループ内の各イメージについて 2 次元の低次元特徴表現を求めます。
FGroupX = predict(net,groupX);
低次元特徴表現を使用して、ユークリッド距離計量を用いてテスト イメージに最も近いイメージをグループ内から 9 個見つけます。イメージを表示します。
distances = vecnorm(extractdata(trialF - FGroupX)); [~,idx] = sort(distances); sortedImages = groupX(:,:,:,idx); sortedImages = extractdata(sortedImages); figure imshow(imtile(sortedImages(:,:,:,1:9)),InitialMagnification=500);
イメージの次元を削減することで、ネットワークは、テスト イメージに類似したイメージを識別できます。低次元特徴表現によって、ネットワークによる類似イメージと非類似イメージの識別が可能になります。シャム ネットワークは顔認識やシグネチャ認識のコンテキストでよく使用されます。たとえば、顔のイメージを入力として受け取り、データベースからよく似ている顔のセットを返すように、シャム ネットワークに学習させることができます。
サポート関数
モデル損失関数
関数 modelLoss
は、シャム dlnetwork
のオブジェクト net
、ミニバッチ入力データ X1
と X2
のペア、およびラベル pairLabels
を受け取ります。この関数は、ペアになっているイメージの低次元特徴間の対比損失と、ネットワーク内の学習可能パラメーターについての損失の勾配を返します。この例では、関数 modelLoss
についてモデル損失関数の定義の節で紹介されています。
function [loss,gradients] = modelLoss(net,X1,X2,pairLabel,margin) % The modelLoss function calculates the contrastive loss between the % paired images and returns the loss and the gradients of the loss with % respect to the network learnable parameters % Pass first half of image pairs forward through the network F1 = forward(net,X1); % Pass second set of image pairs forward through the network F2 = forward(net,X2); % Calculate contrastive loss loss = contrastiveLoss(F1,F2,pairLabel,margin); % Calculate gradients of the loss with respect to the network learnable % parameters gradients = dlgradient(loss,net.Learnables); end function loss = contrastiveLoss(F1,F2,pairLabel,margin) % The contrastiveLoss function calculates the contrastive loss between % the reduced features of the paired images % Define small value to prevent taking square root of 0 delta = 1e-6; % Find Euclidean distance metric distances = sqrt(sum((F1 - F2).^2,1) + delta); % label(i) = 1 if features1(:,i) and features2(:,i) are features % for similar images, and 0 otherwise lossSimilar = pairLabel.*(distances.^2); lossDissimilar = (1 - pairLabel).*(max(margin - distances, 0).^2); loss = 0.5*sum(lossSimilar + lossDissimilar,"all"); end
イメージ ペアのバッチの作成
次の関数は、ラベルに基づいて類似または非類似のイメージのランダム化されたペアを作成します。この例では、関数 getSiameseBatch
について類似イメージと非類似イメージのペアの作成の節で紹介されています。
function [X1,X2,pairLabels] = getSiameseBatch(X,Y,miniBatchSize) % getSiameseBatch returns a randomly selected batch of paired images. % On average, this function produces a balanced set of similar and % dissimilar pairs. pairLabels = zeros(1, miniBatchSize); imgSize = size(X(:,:,:,1)); X1 = zeros([imgSize 1 miniBatchSize]); X2 = zeros([imgSize 1 miniBatchSize]); for i = 1:miniBatchSize choice = rand(1); if choice < 0.5 [pairIdx1, pairIdx2, pairLabels(i)] = getSimilarPair(Y); else [pairIdx1, pairIdx2, pairLabels(i)] = getDissimilarPair(Y); end X1(:,:,:,i) = X(:,:,:,pairIdx1); X2(:,:,:,i) = X(:,:,:,pairIdx2); end end function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel) % getSimilarPair returns a random pair of indices for images % that are in the same class and the similar pair label = 1. % Find all unique classes. classes = unique(classLabel); % Choose a class randomly which will be used to get a similar pair. classChoice = randi(numel(classes)); % Find the indices of all the observations from the chosen class. idxs = find(classLabel==classes(classChoice)); % Randomly choose two different images from the chosen class. pairIdxChoice = randperm(numel(idxs),2); pairIdx1 = idxs(pairIdxChoice(1)); pairIdx2 = idxs(pairIdxChoice(2)); pairLabel = 1; end function [pairIdx1,pairIdx2,pairLabel] = getDissimilarPair(classLabel) % getDissimilarPair returns a random pair of indices for images % that are in different classes and the dissimilar pair label = 0. % Find all unique classes. classes = unique(classLabel); % Choose two different classes randomly which will be used to get a dissimilar pair. classesChoice = randperm(numel(classes), 2); % Find the indices of all the observations from the first and second classes. idxs1 = find(classLabel==classes(classesChoice(1))); idxs2 = find(classLabel==classes(classesChoice(2))); % Randomly choose one image from each class. pairIdx1Choice = randi(numel(idxs1)); pairIdx2Choice = randi(numel(idxs2)); pairIdx1 = idxs1(pairIdx1Choice); pairIdx2 = idxs2(pairIdx2Choice); pairLabel = 0; end
参考文献
Bromley, J., I. Guyon, Y. LeCun, E. Säckinger, and R. Shah. "Signature Verification using a "Siamese" Time Delay Neural Network." In Proceedings of the 6th International Conference on Neural Information Processing Systems (NIPS 1993), 1994, pp737-744. Available at Signature Verification using a "Siamese" Time Delay Neural Network on the NIPS Proceedings website.
Wenpeg, Y., and H Schütze. "Convolutional Neural Network for Paraphrase Identification." In Proceedings of 2015 Conference of the North American Cahapter of the ACL, 2015, pp901-911.Available at Convolutional Neural Network for Paraphrase Identification on the ACL Anthology website.
Hadsell, R., S. Chopra, and Y. LeCun. "Dimensionality Reduction by Learning an Invariant Mapping." In Proceedings of the 2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR 2006), 2006, pp1735-1742.
参考
dlarray
| dlgradient
| dlfeval
| dlnetwork
| adamupdate