Main Content

高速スタイル転送ネットワークの学習

この例では、あるイメージのスタイルを転送して第 2 のイメージにするネットワークの学習方法を示します。これは、[1] で定義されるアーキテクチャに基づいています。

この例は深層学習を使用したニューラル スタイル転送に似ていますが、スタイル イメージ S でネットワークに学習させることで、より高速に動作します。入力イメージ X をネットワークにフォワード パスしさえすれば、スタイル化されたイメージ Y を取得できるからです。

学習アルゴリズムの大まかなブロック線図は下のようになります。ここでは、入力イメージ X、変換後のイメージ Y、スタイル イメージ S という 3 つのイメージを使用して、損失を計算します。

損失関数は、事前学習済みのネットワーク VGG-16 を使用してイメージから特徴を抽出することに注意してください。その実装および数学的定義は、この例のスタイル転送損失の節で確認できます。

学習データの読み込み

https://cocodataset.org/#download から、[2014 Train images] をクリックし、COCO 2014 の学習イメージとキャプションをダウンロードして解凍します。imageFolder で指定したフォルダーにデータを保存します。イメージを imageFolder に解凍します。COCO 2014 は Coco Consortium によって収集されたものです。

COCO データセットを格納するディレクトリを作成します。

imageFolder = fullfile(tempdir,"coco");
if ~exist(imageFolder,'dir')
    mkdir(imageFolder);
end

COCO イメージを含むイメージ データストアを作成します。

imds = imageDatastore(imageFolder,'IncludeSubfolders',true);

学習の実行には時間がかかることがあります。結果として得られるネットワークの精度を犠牲にして学習時間を短縮する場合は、fraction を小さな値に設定することにより、イメージ データストアのサブセットを選択します。

fraction = 1;
numObservations = numel(imds.Files);
imds = subset(imds,1:floor(numObservations*fraction));

イメージのサイズを変更して、それらすべてを RGB に変換するには、拡張イメージ データストアを作成します。

augimds = augmentedImageDatastore([256 256],imds,'ColorPreprocessing',"gray2rgb");

スタイル イメージを読み取ります。

styleImage = imread('starryNight.jpg');
styleImage = imresize(styleImage,[256 256]);

選択したスタイル イメージを表示します。

figure
imshow(styleImage)
title("Style Image")

イメージ変換ネットワークの定義

イメージ変換ネットワークを定義します。これは image-to-image ネットワークです。このネットワークは、次の 3 つの部分で構成されます。

  1. ネットワークの 1 番目の部分は、サイズ [256 x 256 x 3] の RGB イメージを入力として受け取り、それをサイズ [64 x 64 x 128] の特徴マップにダウンサンプリングします。

  2. ネットワークの 2 番目の部分は、サポート関数 residualBlock. で定義された 5 つの同じ残差ブロックで構成されます。

  3. ネットワークの 3 番目の部分は、特徴マップを元のサイズのイメージにアップサンプリングし、変換後のイメージを返します。この最後の部分では upsampleLayer を使用します。これは、この例にサポート ファイルとして添付されているカスタム層です。

layers = [
    
    % First part.
    imageInputLayer([256 256 3],Normalization="none")
    
    convolution2dLayer([9 9],32,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    convolution2dLayer([3 3],64,Stride=2,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    convolution2dLayer([3 3],128,Stride=2,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer(Name="relu_3")
    
    % Second part. 
    residualBlock("1")
    residualBlock("2")
    residualBlock("3")
    residualBlock("4")
    residualBlock("5")
    
    % Third part.
    upsampleLayer
    convolution2dLayer([3 3],64,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    upsampleLayer
    convolution2dLayer([3 3],32,Padding="same")
    groupNormalizationLayer("channel-wise")
    reluLayer
    
    convolution2dLayer(9,3,Padding="same")];

lgraph = layerGraph(layers);

残差ブロック内で欠損結合を追加します。

lgraph = connectLayers(lgraph,"relu_3","add_1/in2");
lgraph = connectLayers(lgraph,"add_1","add_2/in2");
lgraph = connectLayers(lgraph,"add_2","add_3/in2");
lgraph = connectLayers(lgraph,"add_3","add_4/in2");
lgraph = connectLayers(lgraph,"add_4","add_5/in2");

イメージ変換ネットワークをプロットで可視化します。

figure
plot(lgraph)
title("Transform Network")

層グラフから dlnetwork オブジェクトを作成します。

netTransform = dlnetwork(lgraph);

スタイル損失ネットワーク

この例では、事前学習済みの VGG-16 深層ニューラル ネットワークを使用して、コンテンツ イメージとスタイル イメージの特徴をさまざまな層で抽出します。この多層にわたる特徴は、コンテンツとスタイルの損失を計算するのに使用されます。

事前学習済みの VGG-16 ネットワークを取得するには、関数 vgg16 を使用します。必要なサポート パッケージがインストールされていない場合、ダウンロード用リンクが表示されます。

netLoss = vgg16;

損失の計算に不可欠な特徴を抽出するために必要となるのは、最初の 24 個の層のみです。抽出し、層グラフに変換します。

lossLayers = netLoss.Layers(1:24);
lgraph = layerGraph(lossLayers);

dlnetwork に変換します。

netLoss = dlnetwork(lgraph);

モデル損失関数の定義

この例のモデル損失関数の節にリストされている関数 modelLoss を作成します。この関数は、損失ネットワーク、イメージ変換ネットワーク、入力イメージのミニバッチ、スタイル イメージのグラム行列を含む配列、コンテンツ損失に関連する重み、およびスタイル損失に関連する重みを入力として受け取ります。この関数は、合計損失、コンテンツに関連する損失、スタイルに関連する損失、イメージ変換の学習可能パラメーターについての合計損失の勾配、イメージ変換ネットワークの状態、および変換後のイメージを返します。

学習オプションの指定

[1] と同様に、ミニバッチ サイズを 4 として、学習を 2 エポック行います。

numEpochs = 2;
miniBatchSize = 4;

拡張イメージ データストアの読み取りサイズをミニバッチのサイズに設定します。

augimds.MiniBatchSize = miniBatchSize;

ADAM 最適化のオプションを指定します。学習率は 0.001 に、勾配の減衰係数は 0.01 に、2 乗勾配の減衰係数は 0.999 に指定します。

learnRate = 0.001;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

合計損失の計算において、スタイル損失に設定する重みとコンテンツ損失に設定する重みを指定します。

コンテンツ損失とスタイル損失のちょうど良いバランスを見つけるには、異なる重みの組み合わせを使った実験が必要になる場合もあることに注意してください。

weightContent = 1e-4;
weightStyle = 3e-8; 

学習の進捗状況のプロット周波数を選択します。これにより、プロットの更新間隔における反復回数を指定します。

plotFrequency = 10;

モデルの学習

学習中に損失を計算できるようにするには、スタール イメージのグラム行列を計算します。

スタイル イメージを dlarray に変換します。

S = dlarray(single(styleImage),"SSC");

グラム行列を計算するには、スタイル イメージを VGG-16 ネットワークに供給し、4 つの異なる層で活性化を抽出します。

[SActivations1,SActivations2,SActivations3,SActivations4] = forward(netLoss,S, ...
    Outputs=["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

サポート関数 createGramMatrix を使用して、活性化の各セットについてグラム行列を計算します。

SGram{1} = createGramMatrix(SActivations1);
SGram{2} = createGramMatrix(SActivations2);
SGram{3} = createGramMatrix(SActivations3);
SGram{4} = createGramMatrix(SActivations4);

学習プロットは次の 2 つの Figure で構成されます。

  1. 学習中の損失のプロットを示す Figure

  2. イメージ変換ネットワークの入力イメージおよび出力イメージを含む Figure

学習プロットを初期化します。サポート関数 initializeFigures. における初期化の詳細をチェックできます。この関数が返すのは、損失をプロットしたときの軸 ax1、検証イメージをプロットしたときの軸 ax2、コンテンツ損失を含むアニメーション化された線 lineLossContent、スタイル損失を含むアニメーション化された線 lineLossStyle 、合計損失を含むアニメーション化された線 lineLossTotal です。

[ax1,ax2,lineLossContent,lineLossStyle,lineLossTotal] = initializeStyleTransferPlots;

ADAM オプティマイザーの平均勾配および平均 2 乗勾配のハイパーパラメーターを初期化します。

averageGrad = [];
averageSqGrad = [];

学習の合計反復回数を計算します。

numIterations = floor(augimds.NumObservations*numEpochs/miniBatchSize);

学習前に反復回数およびタイマーを初期化します。

iteration = 0;
start = tic;

モデルに学習させます。GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。実行には時間がかかることがあります。

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        iteration = iteration + 1;
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Extract the images from data store into a cell array.
        images = data{:,1};
        
        % Concatenate the images along the 4th dimension.
        X = cat(4,images{:});
        X = single(X);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % "SSCB" (spatial, spatial, channel, batch).
        X = dlarray(X,"SSCB");
        
        % If training on a GPU, then convert data to gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end
        
        % Evaluate the model loss, gradients, and the network state using
        % dlfeval and the modelLoss function listed at the end of the
        % example.
        [loss,lossContent,lossStyle,gradients,state,Y] = dlfeval(@modelLoss, ...
            netLoss,netTransform,X,SGram,weightContent,weightStyle);
        
        netTransform.State = state;
        
        % Update the network parameters.
        [netTransform,averageGrad,averageSqGrad] = ...
            adamupdate(netTransform,gradients,averageGrad,averageSqGrad,iteration,...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
              
        % Every plotFequency iterations, plot the training progress.
        if mod(iteration,plotFrequency) == 0
            addpoints(lineLossTotal,iteration,double(loss))
            addpoints(lineLossContent,iteration,double(lossContent))
            addpoints(lineLossStyle,iteration,double(lossStyle))
            
            % Use the first image of the mini-batch as a validation image.
            XV = X(:,:,:,1);
            % Use the transformed validation image computed previously.
            YV = Y(:,:,:,1);
            
            % To use the function imshow, convert to uint8.
            validationImage = uint8(gather(extractdata(XV)));
            transformedValidationImage = uint8(gather(extractdata(YV)));
            
            % Plot the input image and the output image and increase size
            imshow(imtile({validationImage,transformedValidationImage}),Parent=ax2);
        end
        
        % Display time elapsed since start of training and training completion percentage.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        completionPercentage = round(iteration/numIterations*100,2);
        title(ax1,"Epoch: " + i + ", Iteration: " + iteration +" of "+ numIterations + "(" + completionPercentage + "%)" +", Elapsed: " + string(D))
        drawnow
        
    end
end

イメージのスタイル化

学習が終了すれば、選択したイメージに対してイメージ変換を使用できます。

変換するイメージを読み込みます。

imFilename = "peppers.png";
im = imread(imFilename);

入力イメージのサイズをイメージ変換の入力の次元に変更します。

im = imresize(im,[256,256]);

それを dlarray. に変換します。

X = dlarray(single(im),"SSCB");

利用可能な場合に GPU を使用するには、gpuArray に変換します。

if canUseGPU
    X = gpuArray(X);
end

スタイルをイメージに適用するには、関数 predict. を使用してイメージ変換にフォワード パスします。

Y = predict(netTransform,X);

イメージを範囲 [0 255] に再スケーリングします。まず、関数 tanh を使用して、Y を範囲 [-1 1] に再スケーリングします。次に、出力をシフトおよびスケーリングして、[0 255] の範囲に再スケーリングします。

Y = 255*(tanh(Y)+1)/2;

プロット用に Y を準備します。関数 extractdata を使用して dlarray. からデータを抽出します。関数 gather を使用して、Y を GPU からローカル ワークスペースに転送します。

Y = uint8(gather(extractdata(Y)));

入力イメージ (左) をスタイル化されたイメージ (右) の隣に表示します。

figure
m = imtile({im,Y});
imshow(m)

モデル損失関数

関数 modelLoss は、損失ネットワーク netLoss、イメージ変換ネットワーク netTransform、入力イメージ X のミニバッチ、スタイル イメージのグラム行列を含む配列 SGram、コンテンツ損失に関連する重み contentWeight、スタイル損失に関連する重み styleWeight を入力として受け取ります。この関数は、合計損失、コンテンツに関連する損失 lossContent、スタイルに関連する損失 lossStyle、イメージ変換の学習可能なパラメーターについての合計損失の勾配 gradients、イメージ変換ネットワークの状態 state、および変換後のイメージ Y を返します。

function [loss,lossContent,lossStyle,gradients,state,Y] = ...
    modelLoss(netLoss,netTransform,X,SGram,contentWeight,styleWeight)

[Y,state] = forward(netTransform,X);

Y = 255*(tanh(Y)+1)/2;

[loss,lossContent,lossStyle] = styleTransferLoss(netLoss,Y,X,SGram,contentWeight,styleWeight);

gradients = dlgradient(loss,netTransform.Learnables);

end

スタイル転送損失

関数 styleTransferLoss は、損失ネットワーク netLoss、入力イメージ X, のミニバッチ、変換後のイメージ Y のミニバッチ、スタイル イメージのグラム行列を含む配列 SGram、コンテンツとスタイルに関連するそれぞれの重み contentWeight および styleWeight, を入力として受け取ります。また、合計損失 loss と、個々のコンポーネントであるコンテンツ損失 lossContent およびスタイル損失 lossStyle. を返します。

コンテンツ損失は、入力イメージ X と出力イメージ Y の間における空間構造体の差異がどの程度かを示す測定値です。

一方、スタイル損失は、スタイル イメージ S と出力イメージ Y の間におけるスタイル外観の差異がどの程度かを示します。

以下のグラフは、styleTransferLoss が合計損失を計算するために実装するアルゴリズムを説明しています。

まず、この関数は入力イメージ X、変換後のイメージ Y、スタイル イメージ S を事前学習済みのネットワーク VGG-16 に渡します。この事前学習済みのネットワークは、これらのイメージからいくつかの特徴を抽出します。その後、アルゴリズムは、入力イメージ X と出力イメージ Y の空間的な特徴を使用して、コンテンツ損失を計算します。さらに、出力イメージ Y とスタイル イメージ S のスタイルの特徴を使用して、スタイル損失を計算します。最後に、コンテンツ損失とスタイル損失を加算して、合計損失を求めます。

コンテンツ損失

ミニバッチ内の各イメージについて、コンテンツ損失関数は、元のイメージの特徴と、層 relu3_3 によって出力された変換後のイメージの特徴を比較します。具体的には、次のように、活性化間の平均二乗誤差を計算して、ミニバッチの平均損失を返します。

lossContent=1Nn=1Nmean([ϕ(Xn)-ϕ(Yn)]2),

ここで、X は入力イメージを格納し、Y は変換後のイメージを格納し、N はミニバッチ サイズであり、ϕ() は層 relu3_3. で抽出された活性化を表します。

スタイル損失

スタイル損失を計算するには、ミニバッチに含まれる単一のイメージのそれぞれに対して、次を行います。

  1. relu1_2relu2_2relu3_3relu4_3 で活性化を抽出します。

  2. 4 つの活性化 ϕj のそれぞれについて、グラム行列 G(ϕj) を計算します。

  3. 対応するグラム行列間の二乗差を計算します。

  4. 前の手順で得られた各層 j について、4 つの出力を合計します。

ミニバッチ全体のスタイル損失を求めるには、次のように、ミニバッチ内の各イメージ n についてスタイル損失の平均を計算します。

lossStyle=1Nn=1Nj=14[G(ϕj(Xn))-G(ϕj(S))]2,

ここで、j は層のインデックス、G() はグラム行列 です。

合計損失

function [loss,lossContent,lossStyle] = styleTransferLoss(netLoss,Y,X, ...
    SGram,weightContent,weightStyle)

% Extract activations.
YActivations = cell(1,4);
[YActivations{1},YActivations{2},YActivations{3},YActivations{4}] = ...
    forward(netLoss,Y,'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

XActivations = forward(netLoss,X,'Outputs','relu3_3');

% Calculate the mean square error between activations.
lossContent = mean((YActivations{3} - XActivations).^2,'all');

% Add up the losses for all the four activations.
lossStyle = 0;
for j = 1:4
    G = createGramMatrix(YActivations{j});
    lossStyle = lossStyle + sum((G - SGram{j}).^2,'all');
end

% Average the loss over the mini-batch.
miniBatchSize = size(X,4);
lossStyle = lossStyle/miniBatchSize;

% Apply weights.
lossContent = weightContent * lossContent;
lossStyle = weightStyle * lossStyle;

% Calculate the total loss.
loss = lossContent + lossStyle;

end

残差ブロック

関数 residualBlock は、6 つの層の配列を返します。これは畳み込み層、インスタンス正規化層、ReLu 層、加算層で構成されます。groupNormalizationLayer('channel-wise') は単純にインスタンス正規化層であることに注意してください。

function layers = residualBlock(name)

layers = [    
    convolution2dLayer([3 3], 128,Padding="same",Name="convRes_"+name+"_1")
    groupNormalizationLayer("channel-wise",Name="normRes_"+name+"_1")
    reluLayer(Name="reluRes_"+name+"_1")
    convolution2dLayer([3 3],128,Padding="same",Name="convRes_"+name+"_2")
    groupNormalizationLayer("channel-wise",Name="normRes_"+name+"_2")
    additionLayer(2,Name="add_"+name)];

end

グラム行列

関数 createGramMatrix は、単一の層の活性化を入力として受け取り、ミニバッチ内の各イメージについてスタイル表現を返します。. 入力は、サイズ [H, W, C, N] の特徴マップとなり、ここで、H は高さ、W は幅、C はチャネルの数、N はミニバッチ サイズです。この関数は、サイズ [C,C,N] の配列 G を出力します。各部分配列 G(:,:,k) は、ミニバッチ内の kth イメージに対応するグラム行列です。グラム行列の各エントリ G(i,j,k) は、チャネル cicj の間の相関を表します。なぜなら、チャネル ci 内の各エントリは、次のように、チャネル cj 内の対応する位置にあるエントリを乗算するためです。

G(i,j,k)=1C×H×Wh=1Hw=1Wϕk(h,w,ci)ϕk(h,w,cj),

ここで、ϕk は、ミニバッチ内の kth イメージの活性化です。

グラム行列には、一緒に活性化する特徴についての情報が含まれますが、イメージのどこに特徴が表れるかについての情報は含まれません。これは、高さと幅の総和は、空間構造体についての情報を失うためです。損失関数は、この行列をイメージのスタイル表現として使用します。

function G = createGramMatrix(activations)

[h,w,numChannels] = size(activations,1:3);

features = reshape(activations,h*w,numChannels,[]);
featuresT = permute(features,[2 1 3]);

G = dlmtimes(featuresT,features) / (h*w*numChannels);

end

参考文献

  1. Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. "Perceptual losses for real-time style transfer and super-resolution." European conference on computer vision. Springer, Cham, 2016.

参考

| | | | | |

関連するトピック