Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用したセグメンテーション マップからのイメージの生成

この例では、pix2pixHD 条件付き敵対的生成ネットワーク (CGAN) を使用して、セマンティック セグメンテーション マップからシーンの合成イメージを生成する方法を説明します。

pix2pixHD [1] は 2 つのネットワークで構成されています。これらのネットワークは、両者の性能を最大化するために同時に学習されます。

  1. ジェネレーターは、符号化器-復号化器スタイルのニューラル ネットワークで、セマンティック セグメンテーション マップからシーン イメージを生成します。CGAN のネットワークは、ディスクリミネーターが本物であると誤分類するようなシーン イメージを生成するように、ジェネレーターに学習させます。

  2. ディスクリミネーターは完全畳み込みニューラル ネットワークで、生成されたシーン イメージとそれに対応する実イメージを比較し、偽物と本物を分類するよう試みます。CGAN のネットワークは、生成されたイメージと実イメージを正しく区別できるようにディスクリミネーターに学習させます。

学習中、ジェネレーターとディスクリミネーターのネットワークは互いに競い合います。学習は、どちらのネットワークもさらに向上させることができなくなったときに収束します。

データの読み込み

この例では、学習用に Cambridge University の CamVid データ セット [2] を使用します。このデータ セットは、運転中に得られた路上レベルでのビューが含まれる 701 個のイメージ コレクションです。データ セットは、車、歩行者、道路を含む 32 個のセマンティック クラスについてピクセルのラベルを提供します。

次の URL から CamVid データ セットをダウンロードします。ダウンロード時間はお使いのインターネット接続によって異なります。

imageURL = "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/files/701_StillsRaw_full.zip";
labelURL = "http://web4.cs.ucl.ac.uk/staff/g.brostow/MotionSegRecData/data/LabeledApproved_full.zip";

dataDir = fullfile(tempdir,"CamVid"); 
downloadCamVidData(dataDir,imageURL,labelURL);
imgDir = fullfile(dataDir,"images","701_StillsRaw_full");
labelDir = fullfile(dataDir,"labels");

学習用データの準備

CamVid データ セット内のイメージを保存するimageDatastoreを作成します。

imds = imageDatastore(imgDir);
imageSize = [576 768];

補助関数 defineCamVid32ClassesAndPixelLabelIDs を使用して、CamVid データ セットに含まれる 32 個のクラスのクラス名とピクセル ラベル ID を定義します。補助関数 camvid32ColorMap を使用して、CamVid データ セットの標準カラーマップを取得します。この補助関数は、この例にサポート ファイルとして添付されています。

numClasses = 32;
[classes,labelIDs] = defineCamVid32ClassesAndPixelLabelIDs;
cmap = camvid32ColorMap;

ピクセル ラベル イメージを保存するpixelLabelDatastoreを作成します。

pxds = pixelLabelDatastore(labelDir,classes,labelIDs);

ピクセル ラベル イメージおよび対応するグラウンド トゥルース シーン イメージをプレビューします。関数label2rgbを使用して categorical ラベルを RGB カラーに変換し、ピクセル ラベル イメージとグラウンド トゥルース イメージをモンタージュに表示します。

im = preview(imds);
px = preview(pxds);
px = label2rgb(px,cmap);
montage({px,im})

補助関数 partitionCamVidForPix2PixHD を使用してデータを学習セットとテスト セットに分割します。この関数は、この例にサポート ファイルとして添付されています。補助関数は、データを 648 個の学習ファイルと 32 個のテスト ファイルに分割します。

[imdsTrain,imdsTest,pxdsTrain,pxdsTest] = partitionCamVidForPix2PixHD(imds,pxds,classes,labelIDs);

関数combineを使用し、ピクセル ラベル イメージとグラウンド トゥルース シーン イメージを組み合わせて単一のデータストアにします。

dsTrain = combine(pxdsTrain,imdsTrain);

関数transformを、補助関数 preprocessCamVidForPix2PixHD によって指定されたカスタム前処理演算と共に使用して、学習データを拡張します。この補助関数は、この例にサポート ファイルとして添付されています。

関数 preprocessCamVidForPix2PixHD は以下の操作を実行します。

  1. グラウンド トゥルース データを [-1, 1] の範囲にスケーリングします。この範囲は、ジェネレーター ネットワークの最終的なtanhLayer (Deep Learning Toolbox)の範囲と一致します。

  2. 双三次と最近傍のダウンサンプリングをそれぞれ使用し、ネットワークの出力サイズ (576 x 768 ピクセル) に合うようにイメージとラベルのサイズを変更します。

  3. 関数onehotencode (Deep Learning Toolbox)を使用し、単一チャネルのセグメンテーション マップを、32 チャネルの one-hot 符号化されたセグメンテーション マップに変換します。

  4. イメージとピクセル ラベルのペアを水平方向にランダムに反転します。

dsTrain = transform(dsTrain,@(x) preprocessCamVidForPix2PixHD(x,imageSize));

one-hot 符号化されたセグメンテーション マップのチャネルをモンタージュにプレビューします。各チャネルは、一意のクラスのピクセルに対応する one-hot マップを表します。

map = preview(dsTrain);
montage(map{1},"Size",[4 8],"Bordersize",5,"BackgroundColor","b")

ジェネレーター ネットワークの構成

深さ方向に one-hot 符号化されたセグメンテーション マップからシーン イメージを生成する pix2pixHD ジェネレーター ネットワークを定義します。この入力は、元のセグメンテーション マップと同じ高さと幅をもち、クラスと同じ数のチャネルをもちます。

generatorInputSize = [imageSize numClasses];

関数pix2pixHDGlobalGeneratorを使用して pix2pixHD ジェネレーター ネットワークを作成します。

dlnetGenerator = pix2pixHDGlobalGenerator(generatorInputSize);

ネットワーク アーキテクチャを表示します。

analyzeNetwork(dlnetGenerator)

この例で示しているのは 576 x 768 ピクセルのサイズのイメージを生成するために pix2pixHD グローバル ジェネレーターを使用する方法であることに注意してください。1152 x 1536 ピクセルまたはそれを超えるような高解像度でイメージを生成するローカル エンハンサー ネットワークを作成するには、関数addPix2PixHDLocalEnhancerを使用できます。ローカル エンハンサー ネットワークは、非常に高い解像度において細かいレベルの詳細を生成するのに役立ちます。

ディスクリミネーター ネットワークの構成

入力イメージを本物 (1) または偽物 (0) のいずれかに分類する PatchGAN ディスクリミネーター ネットワークを定義します。この例では、2 つのディスクリミネーター ネットワークを異なる入力スケールで使用します。これは、マルチスケール ディスクリミネーターとも呼ばれます。最初のスケールのサイズはイメージ サイズと同じで、2 番目のスケールのサイズはイメージ サイズの半分です。

ディスクリミネーターへの入力は、one-hot 符号化されたセグメンテーション マップと分類されるシーン イメージを深さ方向に連結したものです。ラベル付けされたクラスの数とイメージのカラー チャネル数の合計として、ディスクリミネーターに入力するチャネルの数を指定します。

numImageChannels = 3;
numChannelsDiscriminator = numClasses + numImageChannels;

最初のディスクリミネーターの入力サイズを指定します。関数patchGANDiscriminatorを使用して、インスタンスの正規化を使用する PatchGAN ディスクリミネーターを作成します。

discriminatorInputSizeScale1 = [imageSize numChannelsDiscriminator];
dlnetDiscriminatorScale1 = patchGANDiscriminator(discriminatorInputSizeScale1,NormalizationLayer="instance");

2 番目のディスクリミネーターの入力サイズをイメージ サイズの半分として指定した後、2 番目の PatchGAN ディスクリミネーターを作成します。

discriminatorInputSizeScale2 = [floor(imageSize)./2 numChannelsDiscriminator];
dlnetDiscriminatorScale2 = patchGANDiscriminator(discriminatorInputSizeScale2,NormalizationLayer="instance");

ネットワークを可視化します。

analyzeNetwork(dlnetDiscriminatorScale1);
analyzeNetwork(dlnetDiscriminatorScale2);

モデル勾配と損失関数の定義

補助関数 modelGradients は、ジェネレーターとディスクリミネーターの勾配と敵対的損失を計算します。この関数は、ジェネレーターの特徴マッチング損失と VGG 損失も計算します。この関数は、この例のサポート関数の節で定義されています。

ジェネレーターの損失

ジェネレーターの目的は、ディスクリミネーターが本物 (1) と分類するようなイメージを生成することです。ジェネレーターの損失は、3 つの損失で構成されています。

  • 敵対的損失は、1 のベクトルと、生成されたイメージにおけるディスクリミネーターの予測との間の二乗差として計算されます。Yˆgenerated は、ジェネレーターによって生成されたイメージについてのディスクリミネーターの予測です。この損失は、この例のサポート関数の節で定義されている補助関数 pix2pixhdAdversarialLoss の一部を使用して実装されます。

lossAdversarialGenerator=(1-Yˆgenerated)2

  • 特徴マッチング損失は、実際の特徴マップと生成された特徴マップ (ディスクリミネーター ネットワークからの予測として取得) との距離 L1 にペナルティを課します。T は、ディスクリミネーターの特徴層の総数です。YrealYˆgenerated はそれぞれ、グラウンド トゥルース イメージと生成されたイメージです。この損失は、この例のサポート関数の節で定義された補助関数 pix2pixhdFeatureMatchingLoss を使用して実装されます。

lossFeatureMatching=i=1T||Yreal-Yˆgenerated||1

  • 知覚的損失は、実際の特徴マップと生成された特徴マップ (特徴抽出ネットワークからの予測として取得) との距離 L1 にペナルティを課します。T は特徴層の総数です。YVggRealYˆVggGenerated はそれぞれ、グラウンド トゥルース イメージと生成されたイメージのネットワーク予測です。この損失は、この例のサポート関数の節で定義された補助関数 pix2pixhdVggLoss を使用して実装されます。特徴抽出ネットワークは、特徴抽出ネットワークの読み込みで作成されます。

lossVgg=i=1T||YVggReal-YˆVggGenerated||1

ジェネレーター全体の損失は、3 つの損失の重み付き和です。λ1λ2λ3 はそれぞれ、敵対的損失、特徴マッチング損失、知覚的損失の重み係数です。

lossGenerator=λ1*lossAdversarialGenerator+λ2*lossFeatureMatching+λ3*lossPerceptual

ジェネレーターの敵対的損失と特徴マッチング損失は、2 つの異なるスケールで計算されることに注意してください。

ディスクリミネーターの損失

ディスクリミネーターの目的は、グラウンド トゥルース イメージと生成されたイメージを正しく区別することです。ディスクリミネーターの損失は、2 つの要素の和です。

  • 1 のベクトルと、実イメージについてディスクリミネーターが行った予測との間の二乗差

  • 0 のベクトルと、生成されたイメージについてディスクリミネーターが行った予測との間の二乗差

lossDiscriminator=(1-Yreal)2+(0-Yˆgenerated)2

ディスクリミネーターの損失は、この例のサポート関数の節で定義された補助関数 pix2pixhdAdversarialLoss の一部を使用して実装されます。ディスクリミネーターの敵対的損失は、2 つの異なるディスクリミネーターのスケールで計算されることに注意してください。

事前学習済み特徴抽出ネットワークの読み込み

この例では、実イメージと生成されたイメージの特徴をさまざまな層で抽出できるように、事前学習済みの VGG-19 深層ニューラル ネットワークを変更します。この多層にわたる特徴は、ジェネレーターの知覚的損失を計算するのに使用されます。

事前学習済みの VGG-19 ネットワークを取得するには、vgg19 (Deep Learning Toolbox)をインストールします。必要なサポート パッケージがインストールされていない場合、ダウンロード用リンクが表示されます。

netVGG = vgg19;

特徴抽出ネットワークの構成

ディープ ネットワーク デザイナー (Deep Learning Toolbox)アプリを使用して、ネットワーク アーキテクチャを可視化します。

deepNetworkDesigner(netVGG)

VGG-19 ネットワークを特徴抽出に適したネットワークにするには、"pool5" までの層を保持し、全結合層をネットワークからすべて削除します。結果として得られるネットワークは、完全畳み込みネットワークです。

netVGG = layerGraph(netVGG.Layers(1:38));

正規化を行わない新しいイメージ入力層を作成します。元のイメージ入力層を新しい層に置き換えます。

inp = imageInputLayer([imageSize 3],Normalization="None",Name="Input");
netVGG = replaceLayer(netVGG,"input",inp);
netVGG = dlnetwork(netVGG);

ネットワークの学習

学習オプションの指定

Adam 最適化のオプションを指定します。学習を 60 エポック行います。ジェネレーターとディスクリミネーターのネットワークで同じオプションを指定します。

  • 同じ学習率 0.0002 を指定します。

  • 最後の平均勾配と最後の平均 2 乗勾配減衰率を [] に初期化します。

  • 勾配の減衰係数に 0.5、2 乗勾配の減衰係数に 0.999 を使用します。

  • 学習用のミニバッチ サイズとして 1 を使用します。

numEpochs = 60;
learningRate = 0.0002;
trailingAvgGenerator = [];
trailingAvgSqGenerator = [];
trailingAvgDiscriminatorScale1 = [];
trailingAvgSqDiscriminatorScale1 = [];
trailingAvgDiscriminatorScale2 = [];
trailingAvgSqDiscriminatorScale2 = [];
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;
miniBatchSize = 1;

カスタム学習ループで観測値のミニバッチを管理するminibatchqueue (Deep Learning Toolbox)オブジェクトを作成します。また、minibatchqueue オブジェクトは、深層学習アプリケーションで自動微分を可能にするdlarray (Deep Learning Toolbox)オブジェクトにデータをキャストします。

ミニバッチのデータ抽出形式を SSCB (spatial、spatial、channel、batch) として指定します。DispatchInBackground の名前と値のペアの引数をcanUseGPUによって返されるブール値として設定します。サポートされている GPU を計算に使用できる場合、minibatchqueue オブジェクトは、学習中に並列プールのバックグラウンドでミニバッチを前処理します。

mbqTrain = minibatchqueue(dsTrain,MiniBatchSiz=miniBatchSize, ...
   MiniBatchFormat="SSCB",DispatchInBackground=canUseGPU);

この例では既定で、補助関数 downloadTrainedPix2PixHDNet を使用して、CamVid データ セット用に事前学習済みバージョンの pix2pixHD ジェネレーター ネットワークをダウンロードします。この補助関数は、この例にサポート ファイルとして添付されています。この事前学習済みのネットワークを使用することで、学習の完了を待たずに例全体を実行できます。

ネットワークに学習させるには、次のコードで変数 doTrainingtrue に設定します。カスタム学習ループでモデルに学習させます。それぞれの反復で次を行います。

  • 関数next (Deep Learning Toolbox)を使用して、現在のミニバッチ データを読み取ります。

  • 関数dlfeval (Deep Learning Toolbox)と補助関数 modelGradients を使用してモデルの勾配を評価します。

  • 関数adamupdate (Deep Learning Toolbox)を使用してネットワーク パラメーターを更新します。

  • 反復のたびに学習の進行状況プロットを更新し、計算されたさまざまな損失を表示します。

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

学習には NVIDIA™ Titan RTX で約 22 時間を要します。ご使用の GPU ハードウェアによっては、さらに長い時間がかかる可能性もあります。GPU デバイスのメモリが少ない場合は、例の学習データの前処理の節で、変数 imageSize を [480 640] に指定して入力イメージのサイズを縮小します。

doTraining = false;
if doTraining
    fig = figure;    
    
    lossPlotter = configureTrainingProgressPlotter(fig);
    iteration = 0;

    % Loop over epochs
    for epoch = 1:numEpochs
        
        % Reset and shuffle the data
        reset(mbqTrain);
        shuffle(mbqTrain);
 
        % Loop over each image
        while hasdata(mbqTrain)
            iteration = iteration + 1;
            
            % Read data from current mini-batch
            [dlInputSegMap,dlRealImage] = next(mbqTrain);
            
            % Evaluate the model gradients and the generator state using
            % dlfeval and the GANLoss function listed at the end of the
            % example
            [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = dlfeval( ...
                @modelGradients,dlInputSegMap,dlRealImage,dlnetGenerator,dlnetDiscriminatorScale1,dlnetDiscriminatorScale2,netVGG);
            
            % Update the generator parameters
            [dlnetGenerator,trailingAvgGenerator,trailingAvgSqGenerator] = adamupdate( ...
                dlnetGenerator,gradParamsG, ...
                trailingAvgGenerator,trailingAvgSqGenerator,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Update the discriminator scale1 parameters
            [dlnetDiscriminatorScale1,trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1] = adamupdate( ...
                dlnetDiscriminatorScale1,gradParamsDScale1, ...
                trailingAvgDiscriminatorScale1,trailingAvgSqDiscriminatorScale1,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Update the discriminator scale2 parameters
            [dlnetDiscriminatorScale2,trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2] = adamupdate( ...
                dlnetDiscriminatorScale2,gradParamsDScale2, ...
                trailingAvgDiscriminatorScale2,trailingAvgSqDiscriminatorScale2,iteration, ...
                learningRate,gradientDecayFactor,squaredGradientDecayFactor);
            
            % Plot and display various losses
            lossPlotter = updateTrainingProgressPlotter(lossPlotter,iteration, ...
                epoch,numEpochs,lossD,lossGGAN,lossGFM,lossGVGG);
        end
    end
    save("trainedPix2PixHDNet.mat","dlnetGenerator");
    
else    
    trainedPix2PixHDNet_url = "https://ssd.mathworks.com/supportfiles/vision/data/trainedPix2PixHDv2.zip";
    netDir = fullfile(tempdir,"CamVid");
    downloadTrainedPix2PixHDNet(trainedPix2PixHDNet_url,netDir);
    load(fullfile(netDir,"trainedPix2PixHDv2.mat"));
end

テスト データから生成されたイメージの評価

CamVid の学習イメージの数が比較的少ないため、この学習済みの Pix2PixHD ネットワークのパフォーマンスは制限されます。さらに、一部のイメージはイメージ シーケンスに属しているため、学習セット内の他のイメージと相関しています。Pix2PixHD ネットワークの有効性を向上させるには、相関のない学習イメージの数が多い別のデータ セットを使用してネットワークに学習させます。

制限があるため、この Pix2PixHD ネットワークは、一部のテスト イメージに対して他よりも現実的なイメージを生成します。結果の違いを示すために、1 番目と 3 番目のテスト イメージに対して生成されたイメージを比較します。最初のテスト イメージのカメラ アングルは、通常の学習イメージよりも道路に対して垂直に向いている珍しい有利な視点になっています。対照的に、3 番目のテスト イメージのカメラ アングルは、道路に沿って向いている典型的な有利な視点で、車線マーカーで分けられた 2 つの車線をとらえています。ネットワークでは、最初のテスト イメージよりも 3 番目のテスト イメージに対する現実的なイメージを生成するパフォーマンスが大幅に向上します。

テスト データから最初のグラウンド トゥルース シーン イメージを取得します。双三次内挿を使用してイメージのサイズを変更します。

idxToTest = 1;
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");

テスト データから対応するピクセル ラベル イメージを取得します。最近傍内挿を使用してピクセル ラベル イメージのサイズを変更します。

segMap = readimage(pxdsTest,idxToTest);
segMap = imresize(segMap,imageSize,"nearest");

関数onehotencode (Deep Learning Toolbox)を使用し、ピクセル ラベル イメージをマルチチャネルの one-hot セグメンテーション マップに変換します。

segMapOneHot = onehotencode(segMap,3,"single");

ジェネレーターにデータを入力する dlarray オブジェクトを作成します。サポートされている GPU を計算に使用できる場合は、データを gpuArray オブジェクトに変換して、GPU で推論を実行します。

dlSegMap = dlarray(segMapOneHot,"SSCB"); 
if canUseGPU
    dlSegMap = gpuArray(dlSegMap);
end

関数predict (Deep Learning Toolbox)を使用して、ジェネレーターと one-hot セグメンテーション マップからシーン イメージを生成します。

dlGeneratedImage = predict(dlnetGenerator,dlSegMap);
generatedImage = extractdata(gather(dlGeneratedImage));

ジェネレーター ネットワークの最終層は、[-1, 1] の範囲で活性化を生成します。表示のために、[0, 1] の範囲で活性化を再スケーリングします。

generatedImage = rescale(generatedImage);

表示のために、関数label2rgbを使用して categorical ラベルを RGB カラーに変換します。

coloredSegMap = label2rgb(segMap,cmap);

RGB ピクセル ラベル イメージ、生成されたシーン イメージ、グラウンド トゥルース シーン イメージをモンタージュに表示します。

figure
montage({coloredSegMap generatedImage gtImage},Size=[1 3])
title("Test Pixel Label Image " + idxToTest + " with Generated and Ground Truth Scene Images")

テスト データから 3 番目のグラウンド トゥルース シーン イメージを取得します。双三次内挿を使用してイメージのサイズを変更します。

idxToTest = 3;  
gtImage = readimage(imdsTest,idxToTest);
gtImage = imresize(gtImage,imageSize,"bicubic");

テスト データから 3 番目のピクセル ラベル イメージを取得し、対応するシーン イメージを生成するために、補助関数 evaluatePix2PixHD を使用できます。この補助関数は、この例にサポート ファイルとして添付されています。

関数 evaluatePix2PixHD は、最初のテスト イメージの評価と同じ演算を実行します。

  • テスト データからピクセル ラベル イメージを取得します。最近傍内挿を使用してピクセル ラベル イメージのサイズを変更します。

  • 関数onehotencode (Deep Learning Toolbox)を使用して、ピクセル ラベル イメージをマルチチャネルの one-hot セグメンテーション マップに変換します。

  • ジェネレーターにデータを入力するための dlarray オブジェクトを作成します。GPU での推論のために、データを gpuArray オブジェクトに変換します。

  • 関数predict (Deep Learning Toolbox)を使用して、ジェネレーターと one-hot セグメンテーション マップからシーン イメージを生成します。

  • 活性化を [0, 1] の範囲に再スケーリングします。

[generatedImage,segMap] = evaluatePix2PixHD(pxdsTest,idxToTest,imageSize,dlnetGenerator);

表示のために、関数label2rgbを使用して categorical ラベルを RGB カラーに変換します。

coloredSegMap = label2rgb(segMap,cmap);

RGB ピクセル ラベル イメージ、生成されたシーン イメージ、グラウンド トゥルース シーン イメージをモンタージュに表示します。

figure
montage({coloredSegMap generatedImage gtImage},Size=[1 3])
title("Test Pixel Label Image " + idxToTest + " with Generated and Ground Truth Scene Images")

カスタム ピクセル ラベル イメージから生成されたイメージの評価

このネットワークが、CamVid データ セット以外のピクセル ラベル イメージに対してどの程度適切に汎化するかを評価するために、カスタム ピクセル ラベル イメージからシーン イメージを生成します。この例では、イメージ ラベラーアプリを使用して作成されたピクセル ラベル イメージを使用します。このピクセル ラベル イメージは、この例にサポート ファイルとして添付されています。利用可能なグラウンド トゥルース イメージはありません。

この例の現在のディレクトリにあるピクセル ラベル イメージを読み取って処理するピクセル ラベル データストアを作成します。

cpxds = pixelLabelDatastore(pwd,classes,labelIDs);

データストアのピクセル ラベル イメージごとに、補助関数 evaluatePix2PixHD を使用してシーン イメージを生成します。

for idx = 1:length(cpxds.Files)

    % Get the pixel label image and generated scene image
    [generatedImage,segMap] = evaluatePix2PixHD(cpxds,idx,imageSize,dlnetGenerator);
    
    % For display, convert the labels from categorical labels to RGB colors
    coloredSegMap = label2rgb(segMap);
    
    % Display the pixel label image and generated scene image in a montage
    figure
    montage({coloredSegMap generatedImage})
    title("Custom Pixel Label Image " + num2str(idx) + " and Generated Scene Image")

end

サポート関数

モデル勾配関数

補助関数 modelGradients は、ジェネレーターとディスクリミネーターの勾配と敵対的損失を計算します。この関数は、ジェネレーターの特徴マッチング損失と VGG 損失も計算します。

function [gradParamsG,gradParamsDScale1,gradParamsDScale2,lossGGAN,lossGFM,lossGVGG,lossD] = modelGradients(inputSegMap,realImage,generator,discriminatorScale1,discriminatorScale2,netVGG)
              
    % Compute the image generated by the generator given the input semantic
    % map.
    generatedImage = forward(generator,inputSegMap);
    
    % Define the loss weights
    lambdaDiscriminator = 1;
    lambdaGenerator = 1;
    lambdaFeatureMatching = 5;
    lambdaVGG = 5;
    
    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,inputSegMap,realImage);
    inpDiscriminatorGenerated = cat(3,inputSegMap,generatedImage);
    
    % Compute the adversarial loss for the discriminator and the generator
    % for first scale.
    [DLossScale1,GLossScale1,realPredScale1D,fakePredScale1G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale1);
    
    % Scale the generated image, the real image, and the input semantic map to
    % half size
    resizedRealImage = dlresize(realImage,Scale=0.5,Method="linear");
    resizedGeneratedImage = dlresize(generatedImage,Scale=0.5,Method="linear");
    resizedinputSegMap = dlresize(inputSegMap,Scale=0.5,Method="nearest");
    
    % Concatenate the image to be classified and the semantic map
    inpDiscriminatorReal = cat(3,resizedinputSegMap,resizedRealImage);
    inpDiscriminatorGenerated = cat(3,resizedinputSegMap,resizedGeneratedImage);
    
    % Compute the adversarial loss for the discriminator and the generator
    % for second scale.
    [DLossScale2,GLossScale2,realPredScale2D,fakePredScale2G] = pix2pixHDAdverserialLoss(inpDiscriminatorReal,inpDiscriminatorGenerated,discriminatorScale2);
    
    % Compute the feature matching loss for first scale.
    FMLossScale1 = pix2pixHDFeatureMatchingLoss(realPredScale1D,fakePredScale1G);
    FMLossScale1 = FMLossScale1 * lambdaFeatureMatching;
    
    % Compute the feature matching loss for second scale.
    FMLossScale2 = pix2pixHDFeatureMatchingLoss(realPredScale2D,fakePredScale2G);
    FMLossScale2 = FMLossScale2 * lambdaFeatureMatching;
    
    % Compute the VGG loss
    VGGLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG);
    VGGLoss = VGGLoss * lambdaVGG;
    
    % Compute the combined generator loss
    lossGCombined = GLossScale1 + GLossScale2 + FMLossScale1 + FMLossScale2 + VGGLoss;
    lossGCombined = lossGCombined * lambdaGenerator;
    
    % Compute gradients for the generator
    gradParamsG = dlgradient(lossGCombined,generator.Learnables,RetainData=true);
    
    % Compute the combined discriminator loss
    lossDCombined = (DLossScale1 + DLossScale2)/2 * lambdaDiscriminator;
    
    % Compute gradients for the discriminator scale1
    gradParamsDScale1 = dlgradient(lossDCombined,discriminatorScale1.Learnables,RetainData=true);
    
    % Compute gradients for the discriminator scale2
    gradParamsDScale2 = dlgradient(lossDCombined,discriminatorScale2.Learnables);
    
    % Log the values for displaying later
    lossD = gather(extractdata(lossDCombined));
    lossGGAN = gather(extractdata(GLossScale1 + GLossScale2));
    lossGFM  = gather(extractdata(FMLossScale1 + FMLossScale2));
    lossGVGG = gather(extractdata(VGGLoss));
end

敵対的損失関数

補助関数 pix2pixHDAdverserialLoss は、ジェネレーターとディスクリミネーターの敵対的損失の勾配を計算します。この関数は、実イメージと合成イメージの特徴マップも返します。

function [DLoss,GLoss,realPredFtrsD,genPredFtrsD] = pix2pixHDAdverserialLoss(inpReal,inpGenerated,discriminator)

    % Discriminator layer names containing feature maps
    featureNames = {"act_top","act_mid_1","act_mid_2","act_tail","conv2d_final"};
    
    % Get the feature maps for the real image from the discriminator    
    realPredFtrsD = cell(size(featureNames));
    [realPredFtrsD{:}] = forward(discriminator,inpReal,Outputs=featureNames);
    
    % Get the feature maps for the generated image from the discriminator    
    genPredFtrsD = cell(size(featureNames));
    [genPredFtrsD{:}] = forward(discriminator,inpGenerated,Outputs=featureNames);
    
    % Get the feature map from the final layer to compute the loss
    realPredD = realPredFtrsD{end};
    genPredD = genPredFtrsD{end};
    
    % Compute the discriminator loss
    DLoss = (1 - realPredD).^2 + (genPredD).^2;
    DLoss = mean(DLoss,"all");
    
    % Compute the generator loss
    GLoss = (1 - genPredD).^2;
    GLoss = mean(GLoss,"all");
end

特徴マッチング損失関数

補助関数 pix2pixHDFeatureMatchingLoss は、実イメージと合成イメージ (ジェネレーターによって生成) との間の特徴マッチング損失を計算します。

function featureMatchingLoss = pix2pixHDFeatureMatchingLoss(realPredFtrs,genPredFtrs)

    % Number of features
    numFtrsMaps = numel(realPredFtrs);
    
    % Initialize the feature matching loss
    featureMatchingLoss = 0;
    
    for i = 1:numFtrsMaps
        % Get the feature maps of the real image
        a = extractdata(realPredFtrs{i});
        % Get the feature maps of the synthetic image
        b = genPredFtrs{i};
        
        % Compute the feature matching loss
        featureMatchingLoss = featureMatchingLoss + mean(abs(a - b),"all");
    end
end

知覚的 VGG 損失関数

補助関数 pix2pixHDVGGLoss は、実イメージと合成イメージ (ジェネレーターによって生成) との間の知覚的 VGG 損失を計算します。

function vggLoss = pix2pixHDVGGLoss(realImage,generatedImage,netVGG)

    featureWeights = [1.0/32 1.0/16 1.0/8 1.0/4 1.0];
    
    % Initialize the VGG loss
    vggLoss = 0;
    
    % Specify the names of the layers with desired feature maps
    featureNames = ["relu1_1","relu2_1","relu3_1","relu4_1","relu5_1"];
    
    % Extract the feature maps for the real image
    activReal = cell(size(featureNames));
    [activReal{:}] = forward(netVGG,realImage,Outputs=featureNames);
    
    % Extract the feature maps for the synthetic image
    activGenerated = cell(size(featureNames));
    [activGenerated{:}] = forward(netVGG,generatedImage,Outputs=featureNames);
    
    % Compute the VGG loss
    for i = 1:numel(featureNames)
        vggLoss = vggLoss + featureWeights(i)*mean(abs(activReal{i} - activGenerated{i}),"all");
    end
end

参考文献

[1] Wang, Ting-Chun, Ming-Yu Liu, Jun-Yan Zhu, Andrew Tao, Jan Kautz, and Bryan Catanzaro. "High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs." In 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, 8798–8807, 2018. https://doi.org/10.1109/CVPR.2018.00917.

[2] Brostow, Gabriel J., Julien Fauqueur, and Roberto Cipolla."Semantic Object Classes in Video: A High-Definition Ground Truth Database." Pattern Recognition Letters. Vol. 30, Issue 2, 2009, pp 88-97.

参考

(Deep Learning Toolbox) | | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | |

関連するトピック