Main Content

ディープ ネットワーク デザイナーを使用した Image-to-Image 回帰ネットワークの構築

この例では、ディープ ネットワーク デザイナーを使用して、超解像用の image-to-image 回帰ネットワークを構築する方法を示します。

空間解像度は、デジタル イメージを構築するのに使用されるピクセル数です。空間解像度が高いイメージは、それを構成するピクセルの数も多いため、より詳細な情報が含まれることになります。超解像は、低解像度イメージを入力として受け取り、高解像度イメージにアップスケーリングするプロセスです。イメージ データを使用する際、データのサイズを抑えるために、情報の損失と引き換えに空間解像度を下げる場合があります。この損失情報を復元するために、深層学習ネットワークの学習を行い、イメージで損失した詳細情報を予測することができます。この例では、7 x 7 ピクセルに圧縮されたイメージから 28 x 28 ピクセルのイメージを復元します。

データの読み込み

この例では、手書き数字の合成グレースケール イメージ 10,000 個で構成された数字データ セットを使用します。イメージはそれぞれ 28 x 28 x 1 ピクセルです。

データを読み込み、イメージ データストアを作成します。

unzip("DigitsData.zip")

imds = imageDatastore("DigitsData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

学習前に関数shuffleを使用してデータをシャッフルします。

imds = shuffle(imds);

関数splitEachLabelを使用して、このイメージ データストアを学習用、検証用、テスト用のイメージを含む 3 つのイメージ データストアに分割します。

[imdsTrain,imdsVal,imdsTest] = splitEachLabel(imds,0.7,0.1,0.1,"randomized");

各イメージ内のデータを [0,1] の範囲に正規化します。正規化は、勾配降下を使用したネットワークの学習の安定化と高速化に有効です。データが適切にスケーリングされていない場合、学習中に損失が NaN になり、ネットワーク パラメーターが発散する可能性があります。

imdsTrain = transform(imdsTrain,@(x)rescale(x));
imdsVal = transform(imdsVal,@(x)rescale(x));
imdsTest = transform(imdsTest,@(x)rescale(x));

学習データの生成

アップサンプリングされた低解像度イメージと、それに対応する高解像度イメージで構成されるイメージのペアを生成することにより、学習データ セットを作成します。

image-to-image 回帰を実行するようにネットワークの学習を行うには、イメージは、入力と応答で構成される、サイズが同じイメージのイメージのペアとなる必要があります。各イメージを 7 x 7 ピクセルにダウンサンプリングしてから 28 x 28 ピクセルにアップサンプリングすることにより、学習データを生成します。変換されたイメージと元のイメージのペアを使用することで、ネットワークは 2 つの異なる解像度間でマッピングする方法を学習できます。

補助関数 upsampLowRes を使用して入力データを生成します。この補助関数は imresize を使用して低解像度イメージを生成します。

imdsInputTrain = transform(imdsTrain,@upsampLowRes);
imdsInputVal= transform(imdsVal,@upsampLowRes);
imdsInputTest = transform(imdsTest,@upsampLowRes);

関数combineを使用し、低解像度と高解像度のイメージを組み合わせて単一のデータストアにします。関数 combine の出力はCombinedDatastoreオブジェクトです。

dsTrain = combine(imdsInputTrain,imdsTrain);
dsVal = combine(imdsInputVal,imdsVal);
dsTest = combine(imdsInputTest,imdsTest);

ネットワーク アーキテクチャの作成

Computer Vision Toolbox™ から関数 unet を使用して、ネットワーク アーキテクチャを作成します。この関数は、image-to-image 回帰に簡単に適応可能な、セマンティック セグメンテーションに適したネットワークを提供します。

入力サイズが 28 x 28 x 1 ピクセルのネットワークを作成します。

layers = unet([28,28,1],2,EncoderDepth=2);

ディープ ネットワーク デザイナーを使用して、image-to-image 回帰用にネットワークを編集します。

deepNetworkDesigner(layers);

ソフトマックス層を削除します。

最終畳み込み層を選択します。層のプロパティのロックを解除して、新しいタスクに適応できるようにします。[プロパティ] ペインの下部で、[層のロックを解除] をクリックします。表示される警告ダイアログで、[ロックの強制解除] をクリックします。この層をタスクに適応させるには、[NumFilters]1 に設定します。

ネットワークの学習の準備が整っていることを確認するには、[解析] をクリックします。ネットワーク アナライザーによってエラーや警告が報告されていないため、ネットワークの学習の準備は整っています。ネットワークをワークスペースにエクスポートするには、[エクスポート] をクリックします。アプリはネットワークを変数 net_1 としてエクスポートします。

学習オプションの指定

学習オプションを指定します。

  • Adam 最適化を使用して学習させます。

  • 学習を 15 エポック行います。

  • 検証データを使用してネットワークを検証します。

  • 学習の進行状況をプロットに表示します。

  • 詳細出力を無効にします。

options = trainingOptions("adam", ...
    MaxEpochs=15, ...
    ValidationData=dsVal, ...
    Plots="training-progress", ...
    Verbose=false);

ニューラル ネットワークの学習

関数trainnetを使用して、Image-to-Image 回帰ネットワークに学習させます。回帰タスクの場合は、平均二乗誤差損失を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

net = trainnet(dsTrain,net_1,"mse",options);

ネットワークのテスト

テスト データを使用して、ネットワーク性能を評価します。

関数minibatchpredictを使用して予測を行います。既定では、関数 minibatchpredict は利用可能な GPU がある場合にそれを使用します。minibatchpredict を使用すれば、学習セットに含まれていなかった低解像度の入力イメージから高解像度イメージをネットワークが生成できるかどうかをテストできます。

ypred = minibatchpredict(net,dsTest);

for i = 1:8
    I(1:2,i) = read(dsTest);
    I(3,i) = {ypred(:,:,:,i)};
end

入力イメージ、予測されたイメージ、応答イメージを比較します。

subplot(1,3,1)
imshow(imtile(I(1,:),GridSize=[8,1]))
title("Input")
subplot(1,3,2)
imshow(imtile(I(3,:),GridSize=[8,1]))
title("Prediction")
subplot(1,3,3)
imshow(imtile(I(2,:),GridSize=[8,1]))
title("Response")

ネットワークは、低解像度の入力から高解像度イメージを正常に生成しています。

この例のネットワークは非常に単純で、数字データ セットに合せて高度に調整されています。日常的なイメージを対象とした、より複雑な image-to-image 回帰ネットワークの作成方法を示す例については、深層学習を使用したイメージの高解像度化を参照してください。

サポート関数

function dataOut = upsampLowRes(dataIn)
temp = dataIn;
temp = imresize(temp,[7,7],method="bilinear");
dataOut = {imresize(temp,[28,28],method="bilinear")};
end

参考

|

関連するトピック