Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

シャム ネットワークの学習とイメージの比較

この例では、シャム ネットワークに学習させて類似した手書き文字のイメージを特定する方法を説明します。

シャム ネットワークは深層学習ネットワークの一種で、同じアーキテクチャをもち、同じパラメーターと重みを共有する、2 つ以上の同一のサブネットワークを使用します。シャム ネットワークは通常、比較可能な 2 つの事例の関係性を見つけることに関わるタスクに使用されます。シャム ネットワークの一般的な使用例には、顔認識、シグネチャ検証 [1]、またはパラフレーズ識別 [2] などがあります。学習中に重みを共有することで学習させるパラメーターが少なくなり、比較的少量の学習データで良い結果が得られるため、シャム ネットワークはこれらのタスクにおいて良好に動作します。

シャム ネットワークは特に、クラスの数が多く、各クラスの観測値が少ない場合に役立ちます。このような場合、イメージをこれらのクラスに分類するよう深層畳み込みニューラル ネットワークに学習させるだけの十分なデータがありません。代わりに、シャム ネットワークによって 2 つのイメージが同じクラスに属するかどうかを判定することができます。

この例では、Omniglot データセット [3] を使用してシャム ネットワークに学習させ、手書き文字のイメージを比較します [4]。Omniglot データセットには 50 個のアルファベットの文字セットが含まれ、そのうち 30 個が学習用に、20 個がテスト用に分けられています。それぞれのアルファベットには、Ojibwe (カナダ先住民文字) の 14 個の文字から Tifinagh の 55 個の文字まで、数多くの文字が含まれています。そして、それぞれの文字には 20 個の手書きの観測値があります。この例では、2 つの手書きの観測値が同じ文字の異なるインスタンスであるかを特定するよう、ネットワークに学習させます。

シャム ネットワークを使用すると、次元削減によって類似のイメージを特定することもできます。例については、次元削減のためのシャム ネットワークの学習を参照してください。

学習データの読み込みと前処理

Omniglot 学習データセットをダウンロードして解凍します。

url = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"images_background.zip");

dataFolderTrain = fullfile(downloadFolder,'images_background');
if ~exist(dataFolderTrain,"dir")
    disp("Downloading Omniglot training data (4.5 MB)...")
    websave(filename,url);
    unzip(filename,downloadFolder);
end
disp("Training data downloaded.")
Training data downloaded.

関数 imageDatastore を使用して学習データをイメージ データストアとして読み込みます。ファイル名からラベルを抽出して Labels プロパティを設定し、ラベルを手動で指定します。

imdsTrain = imageDatastore(dataFolderTrain, ...
    'IncludeSubfolders',true, ...
    'LabelSource','none');

files = imdsTrain.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),'_');
imdsTrain.Labels = categorical(labels);

Omniglot 学習データセットは 30 個のアルファベットのモノクロの手書き文字で構成され、それぞれの文字について 20 個の観測値をもちます。イメージのサイズは 105 x 105 x 1 で、各ピクセルの値は 0 から 1 の範囲内です。

ランダムに選択したイメージを表示します。

idxs = randperm(numel(imdsTrain.Files),8);

for i = 1:numel(idxs)
    subplot(4,2,i)
    imshow(readimage(imdsTrain,idxs(i)))
    title(imdsTrain.Labels(idxs(i)), "Interpreter","none");
end

類似イメージと非類似イメージのペアの作成

ネットワークに学習させるには、データを類似イメージまたは非類似イメージのペアにグループ化しなければなりません。ここで、類似イメージは同じ文字の異なる手書きのインスタンスで、ラベルが同じであるのに対し、非類似イメージは異なる文字で、ラベルが異なります。関数 getSiameseBatch (この例のサポート関数の節で定義) は、類似イメージまたは非類似イメージのランダム化されたペアである pairImage1 pairImage2 を作成します。また、この関数は、イメージのペアが互いに類似か非類似かを識別するラベル pairLabel を返します。イメージの類似ペアの場合は pairLabel = 1、非類似ペアの場合は pairLabel = 0 になります。

一例として、5 つのイメージのペアをもつ、典型的な小さいセットを作成します。

batchSize = 10;
[pairImage1,pairImage2,pairLabel] = getSiameseBatch(imdsTrain,batchSize);

生成されたイメージのペアを表示します。

for i = 1:batchSize    
    if pairLabel(i) == 1
        s = "similar";
    else
        s = "dissimilar";
    end
    subplot(2,5,i)   
    imshow([pairImage1(:,:,:,i) pairImage2(:,:,:,i)]);
    title(s)
end

この例では、学習ループの反復ごとに、イメージのペア 180 個から成る新しいバッチが作成されます。これにより、類似ペアと非類似ペアの比率がほぼ等しい大量のランダムなイメージのペアでネットワークに学習させることができます。

ネットワーク アーキテクチャの定義

シャム ネットワークのアーキテクチャを次の図に示します。

2 つのイメージを比較するために、各イメージが、重みを共有する 2 つの同一サブネットワークの一方を通過します。サブネットワークは、105 x 105 x 1 の各イメージを 4096 次元の特徴ベクトルに変換します。同じクラスのイメージの 4096 次元表現は類似しています。各サブネットワークの出力特徴ベクトルは、減算により組み合わされ、その結果が単一出力をもつ fullyconnect 演算に渡されます。sigmoid 演算は、この値を 0 から 1 の間の確率に変換します。この値は、イメージが類似か非類似かのネットワークの予測を示します。学習中のネットワークの更新には、ネットワーク予測と真のラベルの間におけるバイナリ交差エントロピー損失が使用されます。

この例では、2 つの同一サブネットワークが dlnetwork オブジェクトとして定義されます。最終の fullyconnect および sigmoid 演算は、サブネットワークの出力に対する関数演算として実行されます。

105 x 105 x 1 のイメージを受け取り、サイズが 4096 の特徴ベクトルを出力する一連の層としてサブネットワークを作成します。

convolution2dLayer オブジェクトでは、狭い正規分布を使用して重みとバイアスを初期化します。

maxPooling2dLayer オブジェクトでは、ストライドを 2 に設定します。

最終の fullyConnectedLayer オブジェクトでは、出力サイズを 4096 に指定し、狭い正規分布を使用して重みとバイアスを初期化します。

layers = [
    imageInputLayer([105 105 1],'Name','input1','Normalization','none')
    convolution2dLayer(10,64,'Name','conv1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu1')
    maxPooling2dLayer(2,'Stride',2,'Name','maxpool1')
    convolution2dLayer(7,128,'Name','conv2','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu2')
    maxPooling2dLayer(2,'Stride',2,'Name','maxpool2')
    convolution2dLayer(4,128,'Name','conv3','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu3')
    maxPooling2dLayer(2,'Stride',2,'Name','maxpool3')
    convolution2dLayer(5,256,'Name','conv4','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')
    reluLayer('Name','relu4')
    fullyConnectedLayer(4096,'Name','fc1','WeightsInitializer','narrow-normal','BiasInitializer','narrow-normal')];

lgraph = layerGraph(layers);

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。

dlnet = dlnetwork(lgraph);

最終の fullyconnect 演算の重みを作成します。標準偏差 0.01 の狭い正規分布からランダムに選択してサンプリングし、重みを初期化します。

fcWeights = dlarray(0.01*randn(1,4096));
fcBias = dlarray(0.01*randn(1,1));

fcParams = struct(...
    "FcWeights",fcWeights,...
    "FcBias",fcBias);

ネットワークを使用するために、2 つのサブネットワークと減算、fullyconnect, 演算および sigmoid 演算をどのように組み合わせるかを定義する関数 forwardSiamese (この例のサポート関数の節で定義) を作成します。関数 forwardSiamese は、ネットワーク、fullyconnect 演算のパラメーターを含む構造体、および 2 つの学習イメージを受け取ります。関数 forwardSiamese は 2 つのイメージの類似度についての予測を出力します。

モデル勾配関数の定義

関数 modelGradients (この例のサポート関数の節で定義) を作成します。関数 modelGradients は、シャム ネットワーク dlnetfullyconnect 演算のためのパラメーター構造体、入力データのミニバッチ X1 および X2 とそのラベル pairLabels を受け取ります。関数は損失値と、ネットワークの学習可能なパラメーターについての損失の勾配を返します。

シャム ネットワークの目的は、2 つの入力 X1X2 を区別することです。ネットワーク出力は 0 から 1 の間の確率です。値が 0 に近いほどイメージは非類似と予測され、1 に近いほどイメージは類似と予測されたことを示します。損失は、予測スコアと真のラベル値の間のバイナリ交差エントロピーで表されます。

loss=-tlog(y)-(1-t)log(1-y),

ここで、真のラベル t は 0 または 1 を取り、y は予測されたラベルです。

学習オプションの指定

学習中に使用するオプションを指定します。10000 回反復して学習させます。

numIterations = 10000;
miniBatchSize = 180;

ADAM 最適化のオプションを指定します。

  • 学習率を 0.00006 に設定します。

  • 移動平均勾配と移動平均 2 乗勾配の減衰率について、dlnetfcParams のどちらも [] に初期化。

  • 勾配の減衰係数を 0.9 に、2 乗勾配の減衰係数を 0.99 に設定。

learningRate = 6e-5;
trailingAvgSubnet = [];
trailingAvgSqSubnet = [];
trailingAvgParams = [];
trailingAvgSqParams = [];
gradDecay = 0.9;
gradDecaySq = 0.99;

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。GPU が使用可能で、関連するデータを GPU に配置できるかどうかを自動的に検出するには、executionEnvironment の値を "auto" に設定します。GPU がない場合や、学習に GPU を使用しない場合は、executionEnvironment の値を "cpu" に設定します。学習に必ず GPU を使用するには、executionEnvironment の値を "gpu" に設定します。

executionEnvironment = "auto";

学習の進行状況を監視するには、それぞれの反復の後で学習の損失をプロットできます。"training-progress" を含む変数 plots を作成します。学習の進行状況をプロットしない場合は、この値を "none" に設定します。

plots = "training-progress";

学習損失の進行状況をプロットするためのプロット パラメーターを初期化します。

plotRatio = 16/9;

if plots == "training-progress"
    trainingPlot = figure;
    trainingPlot.Position(3) = plotRatio*trainingPlot.Position(4);
    trainingPlot.Visible = 'on';
    
    trainingPlotAxes = gca;
    
    lineLossTrain = animatedline(trainingPlotAxes);
    xlabel(trainingPlotAxes,"Iteration")
    ylabel(trainingPlotAxes,"Loss")
    title(trainingPlotAxes,"Loss During Training")
end

モデルの学習

カスタム学習ループを使用してモデルに学習させます。学習データ全体をループ処理し、各反復でネットワーク パラメーターを更新します。

それぞれの反復で次を行います。

  • イメージ ペアのバッチの作成の節で定義されている関数 getSiameseBatch を使用して、イメージ ペアとラベルのバッチを抽出。

  • 基となる型が singledlarray オブジェクトにデータを変換し、イメージ データを次元ラベル 'SSCB' (spatial、spatial、channel、batch) に、ラベルを 'CB' (channel、batch) に指定。

  • GPU で学習する場合、データを gpuArray オブジェクトに変換。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配を評価します。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

% Loop over mini-batches.
for iteration = 1:numIterations
    
    % Extract mini-batch of image pairs and pair labels
    [X1,X2,pairLabels] = getSiameseBatch(imdsTrain,miniBatchSize);
    
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data
    dlX1 = dlarray(single(X1),'SSCB');
    dlX2 = dlarray(single(X2),'SSCB');
    
    % If training on a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
        dlX1 = gpuArray(dlX1);
        dlX2 = gpuArray(dlX2);
    end  
    
    % Evaluate the model gradients and the generator state using
    % dlfeval and the modelGradients function listed at the end of the
    % example.
    [gradientsSubnet, gradientsParams,loss] = dlfeval(@modelGradients,dlnet,fcParams,dlX1,dlX2,pairLabels);
    lossValue = double(gather(extractdata(loss)));
    
    % Update the Siamese subnetwork parameters.
    [dlnet,trailingAvgSubnet,trailingAvgSqSubnet] = ...
        adamupdate(dlnet,gradientsSubnet, ...
        trailingAvgSubnet,trailingAvgSqSubnet,iteration,learningRate,gradDecay,gradDecaySq);
    
    % Update the fullyconnect parameters.
    [fcParams,trailingAvgParams,trailingAvgSqParams] = ...
        adamupdate(fcParams,gradientsParams, ...
        trailingAvgParams,trailingAvgSqParams,iteration,learningRate,gradDecay,gradDecaySq);
      
    % Update the training loss progress plot.
    if plots == "training-progress"
        addpoints(lineLossTrain,iteration,lossValue);
    end
    drawnow;
end

ネットワークの精度の評価

Omniglot テスト データセットをダウンロードして解凍します。

url = 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip';
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'images_evaluation.zip');

dataFolderTest = fullfile(downloadFolder,'images_evaluation');
if ~exist(dataFolderTest,'dir')
    disp('Downloading Omniglot test data (3.2 MB)...')
    websave(filename,url);
    unzip(filename,downloadFolder);
end
disp("Test data downloaded.")
Test data downloaded.

関数 imageDatastore を使用してテスト データをイメージ データストアとして読み込みます。ファイル名からラベルを抽出して Labels プロパティを設定し、ラベルを手動で指定します。

imdsTest = imageDatastore(dataFolderTest, ...
    'IncludeSubfolders',true, ...
    'LabelSource','none');    

files = imdsTest.Files;
parts = split(files,filesep);
labels = join(parts(:,(end-2):(end-1)),'_');
imdsTest.Labels = categorical(labels);

テスト データセットには、ネットワークが学習済みのアルファベットとは異なる 20 個のアルファベットが含まれています。合計で、テスト データセットには 659 個の異なるクラスが存在します。

numClasses = numel(unique(imdsTest.Labels))
numClasses = 659

ネットワークの精度を計算するには、テスト ペアとして 5 つのランダムなミニバッチのセットを作成します。関数 predictSiamese (この例のサポート関数の節で定義) を使用して、ネットワークの予測の評価とミニバッチにおける平均精度の計算を行います。

accuracy = zeros(1,5);
accuracyBatchSize = 150;

for i = 1:5
    
    % Extract mini-batch of image pairs and pair labels
    [XAcc1,XAcc2,pairLabelsAcc] = getSiameseBatch(imdsTest,accuracyBatchSize);
    
    % Convert mini-batch of data to dlarray. Specify the dimension labels
    % 'SSCB' (spatial, spatial, channel, batch) for image data.
    dlXAcc1 = dlarray(single(XAcc1),'SSCB');
    dlXAcc2 = dlarray(single(XAcc2),'SSCB');
    
    % If using a GPU, then convert data to gpuArray.
    if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
       dlXAcc1 = gpuArray(dlXAcc1);
       dlXAcc2 = gpuArray(dlXAcc2);
    end    
    
    % Evaluate predictions using trained network
    dlY = predictSiamese(dlnet,fcParams,dlXAcc1,dlXAcc2);
   
    % Convert predictions to binary 0 or 1
    Y = gather(extractdata(dlY));
    Y = round(Y);
    
    % Compute average accuracy for the minibatch
    accuracy(i) = sum(Y == pairLabelsAcc)/accuracyBatchSize;
end

% Compute accuracy over all minibatches
averageAccuracy = mean(accuracy)*100
averageAccuracy = 88.6667

イメージのテスト セットと予測の表示

ネットワークが類似ペアと非類似ペアを正しく識別しているかを視覚的にチェックするには、テスト用にイメージのペアの小さいバッチを作成します。関数 predictSiamese を使用して各テスト ペアの予測を取得します。イメージのペアを、予測、確率スコア、および予測の正誤を示すラベルと共に表示します。

testBatchSize = 10;

[XTest1,XTest2,pairLabelsTest] = getSiameseBatch(imdsTest,testBatchSize);
    
% Convert test batch of data to dlarray. Specify the dimension labels
% 'SSCB' (spatial, spatial, channel, batch) for image data and 'CB' 
% (channel, batch) for labels
dlXTest1 = dlarray(single(XTest1),'SSCB');
dlXTest2 = dlarray(single(XTest2),'SSCB');

% If using a GPU, then convert data to gpuArray
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
   dlXTest1 = gpuArray(dlXTest1);
   dlXTest2 = gpuArray(dlXTest2);
end

% Calculate the predicted probability
dlYScore = predictSiamese(dlnet,fcParams,dlXTest1,dlXTest2);
YScore = gather(extractdata(dlYScore));

% Convert predictions to binary 0 or 1
YPred = round(YScore);    

% Extract data to plot
XTest1 = extractdata(dlXTest1);
XTest2 = extractdata(dlXTest2);

% Plot images with predicted label and predicted score
testingPlot = figure;
testingPlot.Position(3) = plotRatio*testingPlot.Position(4);
testingPlot.Visible = 'on';
    
for i = 1:numel(pairLabelsTest)
     
    if YPred(i) == 1
        predLabel = "similar";
    else
        predLabel = "dissimilar" ;
    end
    
    if pairLabelsTest(i) == YPred(i)
        testStr = "\bf\color{darkgreen}Correct\rm\newline";
        
    else
        testStr = "\bf\color{red}Incorrect\rm\newline";
    end
    
    subplot(2,5,i)        
    imshow([XTest1(:,:,:,i) XTest2(:,:,:,i)]);        
    
    title(testStr + "\color{black}Predicted: " + predLabel + "\newlineScore: " + YScore(i)); 
end

ネットワークは、それらのイメージが学習データセットにまったく含まれていない場合でも、テスト イメージを比較して類似度を判定することができます。

サポート関数

学習と予測のためのモデル関数

ネットワークの学習中は関数 forwardSiamese が使用されます。この関数は、サブネットワークや fullyconnect 演算および sigmoid 演算をどのように組み合わせて完全なシャム ネットワークを形成するかを定義します。forwardSiamese はネットワーク構造と 2 つの学習イメージを受け取り、2 つのイメージの類似度についての予測を出力します。この例では、関数 forwardSiamese についてネットワーク アーキテクチャの定義の節で紹介されています。

function Y = forwardSiamese(dlnet,fcParams,dlX1,dlX2)
% forwardSiamese accepts the network and pair of training images, and returns a
% prediction of the probability of the pair being similar (closer to 1) or 
% dissimilar (closer to 0). Use forwardSiamese during training.

    % Pass the first image through the twin subnetwork
    F1 = forward(dlnet,dlX1);
    F1 = sigmoid(F1);
    
    % Pass the second image through the twin subnetwork
    F2 = forward(dlnet,dlX2);
    F2 = sigmoid(F2);
    
    % Subtract the feature vectors
    Y = abs(F1 - F2);
    
    % Pass the result through a fullyconnect operation
    Y = fullyconnect(Y,fcParams.FcWeights,fcParams.FcBias);
    
    % Convert to probability between 0 and 1.
    Y = sigmoid(Y);
end

関数 predictSiamese は学習済みネットワークを使用して 2 つのイメージの類似度に関する予測を行います。この関数は、以前に定義した関数 forwardSiamese に類似しています。ただし、predictSiamese は関数 forward ではなく関数 predict をネットワークで使用します。これは、一部の深層学習層が学習中および予測中に異なる動作をするためです。この例では、関数 predictSiamese についてネットワークの精度の評価の節で紹介されています。

function Y = predictSiamese(dlnet,fcParams,dlX1,dlX2)
% predictSiamese accepts the network and pair of images, and returns a
% prediction of the probability of the pair being similar (closer to 1)
% or dissimilar (closer to 0). Use predictSiamese during prediction.

    % Pass the first image through the twin subnetwork
    F1 = predict(dlnet,dlX1);
    F1 = sigmoid(F1);
    
    % Pass the second image through the twin subnetwork
    F2 = predict(dlnet,dlX2);
    F2 = sigmoid(F2);
    
    % Subtract the feature vectors
    Y = abs(F1 - F2);
    
    % Pass result through a fullyconnect operation
    Y = fullyconnect(Y,fcParams.FcWeights,fcParams.FcBias);
    
    % Convert to probability between 0 and 1.
    Y = sigmoid(Y);
end

モデル勾配関数

関数 modelGradients は、シャム dlnetwork オブジェクトの net、ミニバッチ入力データ X1X2 のペア、およびそれらが類似か非類似かを示すラベルを受け取ります。この関数は、ネットワーク内の学習可能なパラメーターについての損失の勾配と、予測とグラウンド トゥルースの間のバイナリ交差エントロピー損失を返します。この例において、関数 modelGradients については、モデル勾配関数の定義の節で説明されています。

function [gradientsSubnet,gradientsParams,loss] = modelGradients(dlnet,fcParams,dlX1,dlX2,pairLabels)
% The modelGradients function calculates the binary cross-entropy loss between the
% paired images and returns the loss and the gradients of the loss with respect to
% the network learnable parameters

    % Pass the image pair through the network 
    Y = forwardSiamese(dlnet,fcParams,dlX1,dlX2);
    
    % Calculate binary cross-entropy loss
    loss = binarycrossentropy(Y,pairLabels);
       
    % Calculate gradients of the loss with respect to the network learnable
    % parameters
    [gradientsSubnet,gradientsParams] = dlgradient(loss,dlnet.Learnables,fcParams);
end

function loss = binarycrossentropy(Y,pairLabels)
    % binarycrossentropy accepts the network's prediction Y, the true
    % label, and pairLabels, and returns the binary cross-entropy loss value.
    
    % Get precision of prediction to prevent errors due to floating
    % point precision    
    precision = underlyingType(Y);
      
    % Convert values less than floating point precision to eps.
    Y(Y < eps(precision)) = eps(precision);
    %convert values between 1-eps and 1 to 1-eps.
    Y(Y > 1 - eps(precision)) = 1 - eps(precision);
    
    % Calculate binary cross-entropy loss for each pair
    loss = -pairLabels.*log(Y) - (1 - pairLabels).*log(1 - Y);
    
    % Sum over all pairs in minibatch and normalize.
    loss = sum(loss)/numel(pairLabels);
end

イメージ ペアのバッチの作成

次の関数は、ラベルに基づいて類似または非類似のイメージのランダム化されたペアを作成します。この例では、関数 getSiameseBatch について類似イメージと非類似イメージのペアの作成の節で紹介されています。

function [X1,X2,pairLabels] = getSiameseBatch(imds,miniBatchSize)
% getSiameseBatch returns a randomly selected batch or paired images. On
% average, this function produces a balanced set of similar and dissimilar
% pairs.

    pairLabels = zeros(1,miniBatchSize);
    imgSize = size(readimage(imds,1));
    X1 = zeros([imgSize 1 miniBatchSize]);
    X2 = zeros([imgSize 1 miniBatchSize]);

    for i = 1:miniBatchSize
        choice = rand(1);
        if choice < 0.5
            [pairIdx1,pairIdx2,pairLabels(i)] = getSimilarPair(imds.Labels);
        else
            [pairIdx1,pairIdx2,pairLabels(i)] = getDissimilarPair(imds.Labels);
        end
        X1(:,:,:,i) = imds.readimage(pairIdx1);
        X2(:,:,:,i) = imds.readimage(pairIdx2);
    end
end

function [pairIdx1,pairIdx2,pairLabel] = getSimilarPair(classLabel)
% getSimilarSiamesePair returns a random pair of indices for images
% that are in the same class and the similar pair label = 1.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose a class randomly which will be used to get a similar pair.
    classChoice = randi(numel(classes));
    
    % Find the indices of all the observations from the chosen class.
    idxs = find(classLabel==classes(classChoice));
    
    % Randomly choose two different images from the chosen class.
    pairIdxChoice = randperm(numel(idxs),2);
    pairIdx1 = idxs(pairIdxChoice(1));
    pairIdx2 = idxs(pairIdxChoice(2));
    pairLabel = 1;
end

function  [pairIdx1,pairIdx2,label] = getDissimilarPair(classLabel)
% getDissimilarSiamesePair returns a random pair of indices for images
% that are in different classes and the dissimilar pair label = 0.

    % Find all unique classes.
    classes = unique(classLabel);
    
    % Choose two different classes randomly which will be used to get a dissimilar pair.
    classesChoice = randperm(numel(classes),2);
    
    % Find the indices of all the observations from the first and second classes.
    idxs1 = find(classLabel==classes(classesChoice(1)));
    idxs2 = find(classLabel==classes(classesChoice(2)));
    
    % Randomly choose one image from each class.
    pairIdx1Choice = randi(numel(idxs1));
    pairIdx2Choice = randi(numel(idxs2));
    pairIdx1 = idxs1(pairIdx1Choice);
    pairIdx2 = idxs2(pairIdx2Choice);
    label = 0;
end

参考文献

[1] Bromley, J., I. Guyon, Y. LeCunn, E. Säckinger, and R. Shah."Signature Verification using a "Siamese" Time Delay Neural Network." In Proceedings of the 6th International Conference on Neural Information Processing Systems (NIPS 1993), 1994, pp737-744. Available at Signature Verification using a "Siamese" Time Delay Neural Network on the NIPS Proceedings website.

[2] Wenpeg, Y., and H Schütze. "Convolutional Neural Network for Paraphrase Identification." In Proceedings of 2015 Conference of the North American Chapter of the ACL, 2015, pp901-911. Available at Convolutional Neural Network for Paraphrase Identification on the ACL Anthology website

[3] Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. "Human-level concept learning through probabilistic program induction." Science, 350(6266), (2015) pp1332-1338.

[4] Koch, G., Zemel, R., and Salakhutdinov, R. (2015). "Siamese neural networks for one-shot image recognition". In Proceedings of the 32nd International Conference on Machine Learning, 37 (2015). Available at Siamese Neural Networks for One-shot Image Recognition on the ICML'15 website.

参考

| | | |

関連するトピック