Main Content

深層学習を使用したニューラル スタイル転送

この例では、事前学習済みの VGG-19 ネットワークを使用して、あるイメージのスタイル外観を別のイメージのシーン コンテンツに適用する方法を説明します。

データの読み込み

スタイル イメージとコンテンツ イメージを読み込みます。この例では、ファン・ゴッホの有名な絵画 "星月夜" をスタイル イメージとして使用し、灯台の写真をコンテンツ イメージとして使用します。

styleImage = im2double(imread("starryNight.jpg"));
contentImage = imread("lighthouse.png");

スタイル イメージとコンテンツ イメージをモンタージュとして表示します。

imshow(imtile({styleImage,contentImage},BackgroundColor="w"));

特徴抽出ネットワークの読み込み

この例では、事前学習済みの VGG-19 深層ニューラル ネットワークに変更を加えたものを使用して、コンテンツ イメージとスタイル イメージの特徴をさまざまな層で抽出します。この多層にわたる特徴は、コンテンツとスタイルの損失を計算するのに使用されます。このネットワークは、組み合わされた損失を使用して、スタイル化された転送イメージを生成します。

事前学習済みの VGG-19 ネットワークを取得するには、vgg19 をインストールします。必要なサポート パッケージがインストールされていない場合、ダウンロード用リンクが表示されます。

net = vgg19;

VGG-19 ネットワークを特徴抽出に適したネットワークにするには、全結合層をネットワークからすべて削除します。

lastFeatureLayerIdx = 38;
layers = net.Layers;
layers = layers(1:lastFeatureLayerIdx);

VGG-19 ネットワークの最大プーリング層は、フェージング効果を引き起こします。フェージング効果を減らして勾配フローを増やすには、すべての最大プーリング層を平均プーリング層に置き換えます [1]

for l = 1:lastFeatureLayerIdx
    layer = layers(l);
    if isa(layer,"nnet.cnn.layer.MaxPooling2DLayer")
        layers(l) = averagePooling2dLayer(layer.PoolSize,Stride=layer.Stride,Name=layer.Name);
    end
end

変更を加えた層を使用して、層グラフを作成します。

lgraph = layerGraph(layers);

特徴抽出ネットワークをプロットで可視化します。

plot(lgraph)
title("Feature Extraction Network")

カスタム学習ループを使用してネットワークに学習させ、自動微分を有効にするには、層グラフを dlnetwork オブジェクトに変換します。

dlnet = dlnetwork(lgraph);

データの前処理

より高速に処理できるように、スタイル イメージとコンテンツ イメージのサイズを小さくします。

imageSize = [384,512];
styleImg = imresize(styleImage,imageSize);
contentImg = imresize(contentImage,imageSize);

事前学習済みの VGG-19 ネットワークは、チャネル単位の平均値を減算したイメージで分類を実行します。イメージ入力層 (ネットワークの最初の層) から、チャネル単位の平均値を取得します。

imgInputLayer = lgraph.Layers(1);
meanVggNet = imgInputLayer.Mean(1,1,:);

チャネル単位の平均値は、ピクセル値の範囲が [0, 255] である浮動小数点データ型のイメージに適しています。スタイル イメージとコンテンツ イメージを、範囲が [0, 255] であるデータ型 single に変換します。その後、スタイル イメージとコンテンツ イメージからチャネル単位の平均値を減算します。

styleImg = rescale(single(styleImg),0,255) - meanVggNet;
contentImg = rescale(single(contentImg),0,255) - meanVggNet;

転送イメージの初期化

転送イメージは、スタイル転送の結果として得られる出力イメージです。この転送イメージは、スタイル イメージ、コンテンツ イメージ、または任意のランダム イメージを使用して初期化できます。スタイル イメージまたはコンテンツ イメージを使用して初期化すると、スタイル転送処理にバイアスがかけられ、転送イメージは入力イメージに似たものになります。それに対し、ホワイト ノイズを使用して初期化すると、バイアスが除去されますが、スタイル化されたイメージに収束するまで時間がかかります。この例では、より優れたスタイル化とより高速な収束を実現するため、コンテンツ イメージとホワイト ノイズ イメージの重み付けされた組み合わせとして出力転送イメージを初期化します。

noiseRatio = 0.7;
randImage = randi([-20,20],[imageSize 3]);
transferImage = noiseRatio.*randImage + (1-noiseRatio).*contentImg;

損失関数とスタイル転送パラメーターの定義

コンテンツ損失

コンテンツ損失の目的は、転送イメージの特徴をコンテンツ イメージの特徴に一致させることです。コンテンツ損失は、各コンテンツ特徴層におけるコンテンツ イメージの特徴と転送イメージの特徴との間の平均二乗差として計算されます [1]Yˆ は転送イメージについて予測された特徴マップで、Y はコンテンツ イメージについて予測された特徴マップです。Wcl は、lth 番目の層におけるコンテンツ層の重みです。H,W,C はそれぞれ、特徴マップの高さ、幅、チャネル数です。

Lcontent=lWcl×1HWCi,j(Yˆi,jl-Yi,jl)2

コンテンツ特徴抽出層の名前を指定します。これらの層から抽出された特徴は、コンテンツ損失の計算に使用されます。VGG-19 ネットワークでは、浅い層から抽出された特徴を使用するよりも、深い層から抽出された特徴を使用するほうがより効果的です。そのため、コンテンツ特徴抽出層を 4 番目の畳み込み層として指定します。

styleTransferOptions.contentFeatureLayerNames = "conv4_2";

コンテンツ特徴抽出層の重みを指定します。

styleTransferOptions.contentFeatureLayerWeights = 1;

スタイル損失

スタイル損失の目的は、転送イメージのテクスチャをスタイル イメージのテクスチャに一致させることです。イメージのスタイル表現は、グラム行列として表されます。そのため、スタイル損失は、スタイル イメージのグラム行列と転送イメージのグラム行列との間の平均二乗差として計算されます [1]ZZˆ はそれぞれ、スタイル イメージと転送イメージについて予測された特徴マップです。GZGZˆ はそれぞれ、スタイル イメージの特徴量と転送イメージの特徴量に関するグラム行列です。Wsl は、lth 番目の層におけるスタイル層の重みです。

GZˆ=i,jZˆi,j×Zˆj,i

GZ=i,jZi,j×Zj,i

Lstyle=lWsl×1(2HWC)2(GZˆl-GZl)2

スタイル特徴抽出層の名前を指定します。これらの層から抽出された特徴は、スタイル損失の計算に使用されます。

styleTransferOptions.styleFeatureLayerNames = ["conv1_1","conv2_1","conv3_1","conv4_1","conv5_1"];

スタイル特徴抽出層の重みを指定します。単純なスタイル イメージの場合は小さな重みを指定し、複雑なスタイル イメージの場合は重みを増やします。

styleTransferOptions.styleFeatureLayerWeights = [0.5,1.0,1.5,3.0,4.0];

合計損失

合計損失は、コンテンツ損失とスタイル損失が重み付けされて組み合わされたものです。αβ はそれぞれ、コンテンツ損失とスタイル損失の重み係数です。

Ltotal=α×Lcontent+β×Lstyle

コンテンツ損失とスタイル損失の重み係数 alphabeta を指定します。alphabeta の比は、約 1e-3 または 1e-4 でなければなりません [1]

styleTransferOptions.alpha = 1; 
styleTransferOptions.beta = 1e3;

学習オプションの指定

2500 回反復して学習させます。

numIterations = 2500;

Adam 最適化のオプションを指定します。より早く収束するように、学習率を 2 に設定します。出力イメージと損失を観測することで、学習率を試すことができます。最後の平均勾配と最後の平均 2 乗勾配減衰率を [] に初期化します。

learningRate = 2;
trailingAvg = [];
trailingAvgSq = [];

ネットワークの学習

スタイル イメージ、コンテンツ イメージ、および転送イメージを、基となる型が single で次元ラベルが "SSC" であるdlarrayオブジェクトに変換します。

dlStyle = dlarray(styleImg,"SSC");
dlContent = dlarray(contentImg,"SSC");
dlTransfer = dlarray(transferImage,"SSC");

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。詳細については、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。GPU で学習する場合、データを gpuArray に変換します。

if canUseGPU
    dlContent = gpuArray(dlContent);
    dlStyle = gpuArray(dlStyle);
    dlTransfer = gpuArray(dlTransfer);
end

コンテンツ イメージからコンテンツの特徴を抽出します。

numContentFeatureLayers = numel(styleTransferOptions.contentFeatureLayerNames);
contentFeatures = cell(1,numContentFeatureLayers);
[contentFeatures{:}] = forward(dlnet,dlContent,Outputs=styleTransferOptions.contentFeatureLayerNames);

スタイル イメージからスタイルの特徴を抽出します。

numStyleFeatureLayers = numel(styleTransferOptions.styleFeatureLayerNames);
styleFeatures = cell(1,numStyleFeatureLayers);
[styleFeatures{:}] = forward(dlnet,dlStyle,Outputs=styleTransferOptions.styleFeatureLayerNames);

カスタム学習ループを使用してモデルに学習させます。それぞれの反復で次を行います。

  • コンテンツ イメージ、スタイル イメージ、および転送イメージの特徴量を使用して、コンテンツ損失とスタイル損失を計算します。損失と勾配を計算するには、補助関数 imageGradients を使用します (この例のサポート関数の節で定義されています)。

  • 関数 adamupdate を使用して、転送イメージを更新します。

  • 最も優れたスタイル転送イメージを最終的な出力イメージとして選択します。

figure

minimumLoss = inf;

for iteration = 1:numIterations
    % Evaluate the transfer image gradients and state using dlfeval and the
    % imageGradients function listed at the end of the example
    [grad,losses] = dlfeval(@imageGradients,dlnet,dlTransfer,contentFeatures,styleFeatures,styleTransferOptions);
    [dlTransfer,trailingAvg,trailingAvgSq] = adamupdate(dlTransfer,grad,trailingAvg,trailingAvgSq,iteration,learningRate);
  
    if losses.totalLoss < minimumLoss
        minimumLoss = losses.totalLoss;
        dlOutput = dlTransfer;        
    end   
    
    % Display the transfer image on the first iteration and after every 50
    % iterations. The postprocessing steps are described in the "Postprocess
    % Transfer Image for Display" section of this example
    if mod(iteration,50) == 0 || (iteration == 1)
        
        transferImage = gather(extractdata(dlTransfer));
        transferImage = transferImage + meanVggNet;
        transferImage = uint8(transferImage);
        transferImage = imresize(transferImage,size(contentImage,[1 2]));
        
        image(transferImage)
        title(["Transfer Image After Iteration ",num2str(iteration)])
        axis off image
        drawnow
    end   
    
end

転送イメージを表示するための後処理

更新された転送イメージを取得します。

transferImage = gather(extractdata(dlOutput));

ネットワークで学習した平均値を転送イメージに追加します。

transferImage = transferImage + meanVggNet;

ピクセルによっては、ピクセル値がコンテンツ イメージとスタイル イメージの元の範囲 [0, 255] を超える場合があります。データ型を uint8 に変換することで、この値を [0, 255] の範囲にクリップすることができます。

transferImage = uint8(transferImage);

転送イメージのサイズを、コンテンツ イメージの元のサイズに合わせて変更します。

transferImage = imresize(transferImage,size(contentImage,[1 2]));

コンテンツ イメージ、転送イメージ、およびスタイル イメージをモンタージュに表示します。

imshow(imtile({contentImage,transferImage,styleImage}, ...
    GridSize=[1 3],BackgroundColor="w"));

サポート関数

イメージの損失と勾配の計算

補助関数 imageGradients は、コンテンツ イメージ、スタイル イメージ、および転送イメージの特徴量を使用して、損失と勾配を返します。

function [gradients,losses] = imageGradients(dlnet,dlTransfer,contentFeatures,styleFeatures,params)
 
    % Initialize transfer image feature containers
    numContentFeatureLayers = numel(params.contentFeatureLayerNames);
    numStyleFeatureLayers = numel(params.styleFeatureLayerNames);
 
    transferContentFeatures = cell(1,numContentFeatureLayers);
    transferStyleFeatures = cell(1,numStyleFeatureLayers);
 
    % Extract content features of transfer image
    [transferContentFeatures{:}] = forward(dlnet,dlTransfer,Outputs=params.contentFeatureLayerNames);
     
    % Extract style features of transfer image
    [transferStyleFeatures{:}] = forward(dlnet,dlTransfer,Outputs=params.styleFeatureLayerNames);
 
    % Calculate content loss
    cLoss = contentLoss(transferContentFeatures,contentFeatures,params.contentFeatureLayerWeights);
 
    % Calculate style loss
    sLoss = styleLoss(transferStyleFeatures,styleFeatures,params.styleFeatureLayerWeights);
 
    % Calculate final loss as weighted combination of content and style loss 
    loss = (params.alpha * cLoss) + (params.beta * sLoss);
 
    % Calculate gradient with respect to transfer image
    gradients = dlgradient(loss,dlTransfer);
    
    % Extract various losses
    losses.totalLoss = gather(extractdata(loss));
    losses.contentLoss = gather(extractdata(cLoss));
    losses.styleLoss = gather(extractdata(sLoss));
 
end

コンテンツ損失の計算

補助関数 contentLoss は、コンテンツ イメージの特徴量と転送イメージの特徴量との間の重み付け平均二乗差を計算します。

function loss = contentLoss(transferContentFeatures,contentFeatures,contentWeights)

    loss = 0;
    for i=1:numel(contentFeatures)
        temp = 0.5 .* mean((transferContentFeatures{1,i} - contentFeatures{1,i}).^2,"all");
        loss = loss + (contentWeights(i)*temp);
    end
end

スタイル損失の計算

補助関数 styleLoss は、スタイル イメージの特徴量に関するグラム行列と転送イメージの特徴量に関するグラム行列との間の重み付け平均二乗差を計算します。

function loss = styleLoss(transferStyleFeatures,styleFeatures,styleWeights)

    loss = 0;
    for i=1:numel(styleFeatures)
        
        tsf = transferStyleFeatures{1,i};
        sf = styleFeatures{1,i};    
        [h,w,c] = size(sf);
        
        gramStyle = calculateGramMatrix(sf);
        gramTransfer = calculateGramMatrix(tsf);
        sLoss = mean((gramTransfer - gramStyle).^2,"all") / ((h*w*c)^2);
        
        loss = loss + (styleWeights(i)*sLoss);
    end
end

グラム行列の計算

補助関数 styleLoss は、補助関数 calculateGramMatrix を使用して特徴マップのグラム行列を計算します。

function gramMatrix = calculateGramMatrix(featureMap)
    [H,W,C] = size(featureMap);
    reshapedFeatures = reshape(featureMap,H*W,C);
    gramMatrix = reshapedFeatures' * reshapedFeatures;
end

参考文献

[1] Leon A. Gatys, Alexander S. Ecker, and Matthias Bethge. "A Neural Algorithm of Artistic Style." Preprint, submitted September 2, 2015. https://arxiv.org/abs/1508.06576

参考

| | |

関連するトピック