Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用したカメラのデータ処理パイプラインの構築

この例では、U-Net を使用して、カメラの RAW データを見た目の良いカラー イメージに変換する方法を説明します。

デジタル一眼レフ カメラや、最新の多くの携帯電話のカメラには、カメラのセンサーから直接収集したデータを RAW ファイルとして保存する機能が備わっています。RAW データの各ピクセルは、カメラの対応する光センサーで取得した光の量に直接対応しています。このデータは、電磁スペクトルの特定の波長範囲に対する光センサーごとの感度など、カメラ ハードウェアの決められた特性に左右されます。また、露光時間、光源などのシーンの要素等、カメラの撮影設定にも左右されます。

デモザイク処理は、単一チャネルの RAW データを 3 チャネルの RGB イメージに変換するために必要な唯一の操作です。ただし、さらにイメージ処理操作を行わなければ、得られる RGB イメージの表示品質は主観的に不十分なものになってしまいます。

従来のイメージ処理パイプラインでは、ノイズ除去、線形化、ホワイト バランス、色補正、明度調整、コントラスト調整 [1] などの追加操作を組み合わせて実行します。パイプラインを設計する際の課題としては、シーンや撮影設定の違いに関係なく最終的な RGB イメージの主観的な見た目を最適化するアルゴリズムの調整があります。

RAWtoRGBintro.png

深層学習手法を使用すると、従来の処理パイプラインを構築することなく RAW から RGB への直接変換ができるようになります。たとえば、ある手法では RAW イメージを RGB に変換する際に露出不足を補正します [2]。この例では、ロー エンドの携帯電話カメラの RAW イメージを、ハイ エンドのデジタル一眼レフ カメラの品質に近い RGB イメージに変換する方法を説明します。

Zurich RAW to RGB データ セットのダウンロード

この例では、Zurich RAW to RGB データ セット [3] を使用します。このデータ セットのサイズは 22 GB です。データ セットには、サイズが 448 x 448 の、空間的にレジストレーションされた RAW 学習イメージ パッチおよび RGB 学習イメージ パッチが 48,043 ペア登録されています。このデータ セットには 2 つの異なるテスト セットが格納されています。一方のテスト セットは、サイズが 448 x 448 の、空間的にレジストレーションされた RAW イメージ パッチおよび RGB イメージ パッチ 1,204 ペアで構成されています。もう一方のテスト セットはレジストレーションされていない高解像度 RAW イメージおよび RGB イメージで構成されています。

dataDir をデータの目的の場所として指定します。

dataDir = fullfile(tempdir,"ZurichRAWToRGB");

データ セットをダウンロードするには、Zurich RAW to RGB dataset フォームを使用してアクセスを依頼します。変数 dataDir で指定されたディレクトリにデータを解凍します。解凍に成功すると、dataDir には full_resolutiontest、および train という名前の 3 つのディレクトリが格納されています。

学習、検証、およびテスト用データストアの作成

RGB イメージ パッチ学習データ用のデータストアの作成

Canon のハイエンド デジタル一眼レフ カメラを使用して取得したターゲット RGB 学習イメージ パッチを読み取るimageDatastoreを作成します。

trainImageDir = fullfile(dataDir,"train");
dsTrainRGB = imageDatastore(fullfile(trainImageDir,"canon"),ReadSize=16);

RGB 学習イメージ パッチをプレビューします。

groundTruthPatch = preview(dsTrainRGB);
imshow(groundTruthPatch)

RAW イメージ パッチ学習データ用のデータストアの作成

Huawei 製携帯電話のカメラを使用して取得した入力 RAW 学習イメージ パッチを読み取る imageDatastore を作成します。RAW イメージは 10 ビットの精度で取得し、8 ビットおよび 16 ビットの PNG ファイルとして表現します。8 ビットのファイルは範囲 [0, 255] のデータを含むパッチをコンパクトに表現します。どの RAW データに対してもスケーリングは行っていません。

dsTrainRAW = imageDatastore(fullfile(trainImageDir,"huawei_raw"),ReadSize=16);

入力 RAW 学習イメージ パッチをプレビューします。データストアはこのパッチを 8 ビット uint8 として読み取ります。これは、センサーの計測値が範囲 [0, 255] に入っているためです。学習データの 10 ビット ダイナミック レンジをシミュレートするために、イメージの強度値を 4 で除算します。イメージを拡大すると、RGGB Bayer パターンを確認できます。

inputPatch = preview(dsTrainRAW);
inputPatchRAW = inputPatch/4;
imshow(inputPatchRAW)

従来の処理パイプラインの最小限をシミュレートするために、関数demosaic (Image Processing Toolbox)を使用して RAW データの RGGB Bayer パターンのデモザイク処理を行います。処理後のイメージを表示して、表示を明るくします。ターゲット RGB イメージと比較して、最小限の処理を行った RGB イメージは暗く、色のバランスが取れておらずアーティファクトも目立ちます。学習済みの RAW-to-RGB ネットワークは、出力 RGB イメージがターゲット イメージと似たものになるように前処理演算を実行します。

inputPatchRGB = demosaic(inputPatch,"rggb");
imshow(rescale(inputPatchRGB))

テスト イメージの検証セットおよびテスト セットへの分割

テスト データには、RAW イメージ パッチと RGB イメージ パッチおよびフルサイズのイメージが含まれています。この例では、テスト イメージ パッチを検証用のセットとテスト用のセットに分割します。フルサイズのテスト イメージは、定性的なテスト用にのみ使用します。フルサイズ イメージでの学習済みイメージ処理パイプラインの評価を参照してください。

RAW テスト イメージ パッチおよび RGB テスト イメージ パッチを読み取るイメージ データストアを作成します。

testImageDir = fullfile(dataDir,"test");
dsTestRAW = imageDatastore(fullfile(testImageDir,"huawei_raw"),ReadSize=16);
dsTestRGB = imageDatastore(fullfile(testImageDir,"canon"),ReadSize=16);

テスト データを検証用と学習用の 2 つのセットにランダムに分割します。検証データ セットには 200 個のイメージが格納されます。テスト セットには残りのイメージが格納されます。

numTestImages = dsTestRAW.numpartitions;
numValImages = 200;

testIdx = randperm(numTestImages);
validationIdx = testIdx(1:numValImages);
testIdx = testIdx(numValImages+1:numTestImages);

dsValRAW = subset(dsTestRAW,validationIdx);
dsValRGB = subset(dsTestRGB,validationIdx);

dsTestRAW = subset(dsTestRAW,testIdx);
dsTestRGB = subset(dsTestRGB,testIdx);

データの前処理と拡張

センサーは、赤の光センサーを 1 つ、緑を 2 つ、および青を 1 つ含む Bayer パターンの繰り返しでカラー データを取得します。関数transformを使用して、ネットワークに期待できる 4 チャネルのイメージにデータを前処理します。関数 transform は、補助関数 preprocessRAWDataForRAWToRGB で指定された演算を使用してデータを処理します。この補助関数は、この例にサポート ファイルとして添付されています。

補助関数 preprocessRAWDataForRAWToRGB は、H x W x 1 の RAW イメージを、赤が 1 つ、緑が 2 つ、青が 1 つのチャネルで構成される H/2 x W/2 x 4 のマルチチャネル イメージに変換します。

RAWto4Channel.png

また、[0, 1] の範囲にスケールした、データ型 single にデータをキャストします。

dsTrainRAW = transform(dsTrainRAW,@preprocessRAWDataForRAWToRGB);
dsValRAW = transform(dsValRAW,@preprocessRAWDataForRAWToRGB);
dsTestRAW = transform(dsTestRAW,@preprocessRAWDataForRAWToRGB);

ターゲット RGB イメージは符号なし 8 ビット データとしてディスクに格納されます。メトリクスの計算とネットワーク設計をより簡単にするには、関数 transform および補助関数 preprocessRGBDataForRAWToRGB を使用してターゲット RGB 学習イメージを前処理します。この補助関数は、この例にサポート ファイルとして添付されています。

補助関数 preprocessRGBDataForRAWToRGB は、[0, 1] の範囲にスケールしたデータ型 single にイメージをキャストします。

dsTrainRGB = transform(dsTrainRGB,@preprocessRGBDataForRAWToRGB);
dsValRGB = transform(dsValRGB,@preprocessRGBDataForRAWToRGB);

関数combineを使用して、学習、検証、およびテスト イメージ セットの入力 RAW データとターゲット RGB データを結合します。

dsTrain = combine(dsTrainRAW,dsTrainRGB);
dsVal = combine(dsValRAW,dsValRGB);
dsTest = combine(dsTestRAW,dsTestRGB);

関数transformと補助関数 augmentDataForRAWToRGB を使用して学習データをランダムに拡張します。この補助関数は、この例にサポート ファイルとして添付されています。

補助関数 augmentDataForRAWToRGB は入力 RAW イメージとターゲット RGB 学習イメージのペアに、90 度の回転と水平方向の反転をランダムに適用します。

dsTrainAug = transform(dsTrain,@augmentDataForRAWToRGB);

拡張された学習データをプレビューします。

exampleAug = preview(dsTrainAug)
exampleAug=8×2 cell array
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}
    {224×224×4 single}    {448×448×3 single}

ネットワーク入力とターゲット イメージをモンタージュに表示します。ネットワーク入力のチャネル数は 4 であるため、最初のチャネルを [0, 1] の範囲に再スケーリングして表示します。入力 RAW イメージとターゲット RGB イメージには同じ拡張を適用します。

exampleInput = exampleAug{1,1};
exampleOutput = exampleAug{1,2};
montage({rescale(exampleInput(:,:,1)),exampleOutput})

学習時の学習データおよび検証データのバッチ処理

この例ではカスタム学習ループを使用します。カスタム学習ループ内にある観測値のミニバッチ処理の管理にはminibatchqueueオブジェクトが有効です。また、minibatchqueue オブジェクトは、深層学習アプリケーションで自動微分を可能にするdlarrayオブジェクトにデータをキャストします。

miniBatchSize = 2;
valBatchSize = 10;
trainingQueue = minibatchqueue(dsTrainAug,MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard",MiniBatchFormat="SSCB");
validationQueue = minibatchqueue(dsVal,MiniBatchSize=valBatchSize,MiniBatchFormat="SSCB");

minibatchqueue の関数nextからは、データの次のミニバッチが得られます。関数 next の 1 回の呼び出しからの出力をプレビューします。出力のデータ型は dlarray です。データは GPU で既に gpuArray にキャストされており、学習できる状態にあります。

[inputRAW,targetRGB] = next(trainingQueue);
whos inputRAW
  Name            Size                   Bytes  Class      Attributes

  inputRAW      224x224x4x2            1605640  dlarray              
whos targetRGB
  Name             Size                   Bytes  Class      Attributes

  targetRGB      448x448x3x2            4816904  dlarray              

U-Net ネットワーク層の設定

この例では U-Net ネットワークのバリエーションを使用します。U-Net では、最初の一連の畳み込み層に最大プーリング層が点在し、入力イメージの解像度を逐次下げていきます。これらの層に、一連の畳み込み層が続き、その中にアップサンプリング演算処理が点在し、入力イメージの解像度を逐次上げていきます。U-Net の名前は、このネットワークが文字「U」のように対称の形状で描けることに由来しています。

この例では、2 つの変更を含むシンプルな U-Net アーキテクチャを使用します。まず、ネットワークは最終的な転置畳み込み演算を、独自のピクセル シャッフル アップサンプリング (Depth-To-Space とも呼ばれる) 演算に置き換えます。次に、カスタム双曲線正接活性化層をネットワークの最終層として使用します。

ピクセル シャッフル アップサンプリング

畳み込み、およびそれに続くピクセル シャッフル アップサンプリングでは、超解像アプリケーション用のサブピクセル畳み込みを定義できます。サブピクセル畳み込みにより、転置畳み込みから発生する可能性のあるチェッカーボード アーティファクトが予防されます [6]。モデルは H/2 x W/2 x 4 の RAW 入力を W x H x 3 の RGB 出力にマッピングする必要があるため、モデルの最終的なアップサンプリング段階は、空間サンプルの数が入力から出力で増加する超解像と同様であると考えられます。

次の図は、2 x 2 x 4 の入力に対してピクセル シャッフル アンサンプリングがどのように動作するかを示しています。最初の 2 つの次元は空間次元、3 番目の次元はチャネル次元です。一般に、係数 S のピクセル シャッフル アンサンプリングでは H x W x C の入力を取り、S*H x S*W x CS2 の出力が得られます。

RawtoRGBpixelshuffle.png

ピクセル シャッフル関数は、所与の空間位置のチャネル次元からの情報を、アップサンプリング時に各チャネルがその近傍に対して安定的な空間位置を取る出力になるように S x S の空間ブロックにマッピングして、出力の空間次元を拡張します。

スケーリングした双曲線正接の活性化

双曲線正接活性化層は、層の入力に対して関数 tanh を適用します。この例では、関数 tanh をスケールおよびシフトしたバージョンを使用します。この関数は RGB ネットワーク出力が [0, 1] の範囲にほぼ収まるようにします。

f(x)=0.58*tanh(x)+0.5

RawtoRGBtanh.png

入力正規化のための学習セット統計の計算

tall を使用して学習データ セット全体の平均リダクションをチャネル単位で計算します。ネットワークの入力層は、平均統計値を使用して学習時およびテスト時の入力の平均センタリングを実行します。

dsIn = copy(dsTrainRAW);
dsIn.UnderlyingDatastore.ReadSize = 1;
t = tall(dsIn);
perChannelMean = gather(mean(t,[1 2]));

U-Net の作成

チャネルごとの平均を指定して初期サブネットワークの層を作成します。

inputSize = [256 256 4];
initialLayer = imageInputLayer(inputSize,Normalization="zerocenter", ...
    Mean=perChannelMean,Name="ImageInputLayer");

最初の符号化サブネットワークの層を追加します。最初の符号化器には 32 個の畳み込みフィルターがあります。

numEncoderStages = 4;
numFiltersFirstEncoder = 32;
encoderNamePrefix = "Encoder-Stage-";

encoderLayers = [
    convolution2dLayer([3 3],numFiltersFirstEncoder,Padding="same", ...
        WeightsInitializer="narrow-normal",Name=encoderNamePrefix+"1-Conv-1")
    leakyReluLayer(0.2,Name=encoderNamePrefix+"1-ReLU-1")
    convolution2dLayer([3 3],numFiltersFirstEncoder,Padding="same", ...
        WeightsInitializer="narrow-normal",Name=encoderNamePrefix+"1-Conv-2")
    leakyReluLayer(0.2,Name=encoderNamePrefix+"1-ReLU-2")
    maxPooling2dLayer([2 2],Stride=[2 2],Name=encoderNamePrefix+"1-MaxPool")  
    ];

さらに符号化サブネットワークの層を追加します。これらのサブネットワークは、groupNormalizationLayer を使用して各畳み込み層の後にチャネル単位のインスタンス正規化を追加します。各符号化器サブネットワークには、前の符号化器サブネットワークの 2 倍の数のフィルターがあります。

cnIdx = 1;
for stage = 2:numEncoderStages
    
    numFilters = numFiltersFirstEncoder*2^(stage-1);
    layerNamePrefix = encoderNamePrefix+num2str(stage);
    
    encoderLayers = [
        encoderLayers
        convolution2dLayer([3 3],numFilters,Padding="same", ...
            WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-Conv-1")
        groupNormalizationLayer("channel-wise",Name="cn"+num2str(cnIdx))
        leakyReluLayer(0.2,Name=layerNamePrefix+"-ReLU-1")
        convolution2dLayer([3 3],numFilters,Padding="same", ...
            WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-Conv-2")
        groupNormalizationLayer("channel-wise",Name="cn"+num2str(cnIdx+1))
        leakyReluLayer(0.2,Name=layerNamePrefix+"-ReLU-2")
        maxPooling2dLayer([2 2],Stride=[2 2],Name=layerNamePrefix+"-MaxPool")
        ];     
    
    cnIdx = cnIdx + 2;
end

ブリッジ層を追加します。ブリッジ サブネットワークには、最終的な符号化器サブネットワークおよび最初の復号化器サブネットワークの 2 倍の数のフィルターがあります。

numFilters = numFiltersFirstEncoder*2^numEncoderStages;
bridgeLayers = [
    convolution2dLayer([3 3],numFilters,Padding="same", ...
        WeightsInitializer="narrow-normal",Name="Bridge-Conv-1")
    groupNormalizationLayer("channel-wise",Name="cn7")
    leakyReluLayer(0.2,Name="Bridge-ReLU-1")
    convolution2dLayer([3 3],numFilters,Padding="same", ...
        WeightsInitializer="narrow-normal",Name="Bridge-Conv-2")
    groupNormalizationLayer("channel-wise",Name="cn8")
    leakyReluLayer(0.2,Name="Bridge-ReLU-2")];

最初の 3 つの復号化器サブネットワークの層を追加します。

numDecoderStages = 4;
cnIdx = 9;
decoderNamePrefix = "Decoder-Stage-";

decoderLayers = [];
for stage = 1:numDecoderStages-1
    
    numFilters = numFiltersFirstEncoder*2^(numDecoderStages-stage);
    layerNamePrefix = decoderNamePrefix+num2str(stage);  
    
    decoderLayers = [
        decoderLayers
        transposedConv2dLayer([3 3],numFilters,Stride=[2 2],Cropping="same", ...
            WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-UpConv")
        leakyReluLayer(0.2,Name=layerNamePrefix+"-UpReLU")
        depthConcatenationLayer(2,Name=layerNamePrefix+"-DepthConcatenation")
        convolution2dLayer([3 3],numFilters,Padding="same", ...
            WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-Conv-1")
        groupNormalizationLayer("channel-wise",Name="cn"+num2str(cnIdx))
        leakyReluLayer(0.2,Name=layerNamePrefix+"-ReLU-1")
        convolution2dLayer([3 3],numFilters,Padding="same", ...
            WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-Conv-2")
        groupNormalizationLayer("channel-wise",Name="cn"+num2str(cnIdx+1))
        leakyReluLayer(0.2,Name=layerNamePrefix+"-ReLU-2")
        ];        
    
    cnIdx = cnIdx + 2;    
end

最後の復号化器サブネットワークの層を追加します。このサブネットワークは、他の復号化器サブネットワークによって実行されたチャネル単位のインスタンス正規化を除外します。各復号化器サブネットワークには、前のサブネットワークの半分の数のフィルターがあります。

numFilters = numFiltersFirstEncoder;
layerNamePrefix = decoderNamePrefix+num2str(stage+1); 

decoderLayers = [
    decoderLayers
    transposedConv2dLayer([3 3],numFilters,Stride=[2 2],Cropping="same", ...
       WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-UpConv")
    leakyReluLayer(0.2,Name=layerNamePrefix+"-UpReLU")
    depthConcatenationLayer(2,Name=layerNamePrefix+"-DepthConcatenation")
    convolution2dLayer([3 3],numFilters,Padding="same", ...
        WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-Conv-1")
    leakyReluLayer(0.2,Name=layerNamePrefix+"-ReLU-1")
    convolution2dLayer([3 3],numFilters,Padding="same", ...
        WeightsInitializer="narrow-normal",Name=layerNamePrefix+"-Conv-2")
    leakyReluLayer(0.2,Name=layerNamePrefix+"-ReLU-2")];

U-Net の最終層を追加します。ピクセル シャッフル層は、ピクセル シャッフル アップサンプリングを使用して、最終の畳み込みからの活性化の H/2 x W/2 x 12 チャネル サイズから H x W x 3 チャネルの活性化に移行します。最終層は、双曲線正接関数を使用して、出力が目的の [0, 1] の範囲になるようにします。

finalLayers = [
    convolution2dLayer([3 3],12,Padding="same",WeightsInitializer="narrow-normal", ...
       Name="Decoder-Stage-4-Conv-3")
    pixelShuffleLayer("pixelShuffle",2)
    tanhScaledAndShiftedLayer("tanhActivation")];

layers = [initialLayer;encoderLayers;bridgeLayers;decoderLayers;finalLayers];
lgraph = layerGraph(layers);

符号化サブネットワークと復号化サブネットワークの層を接続します。

lgraph = connectLayers(lgraph,"Encoder-Stage-1-ReLU-2", ...
    "Decoder-Stage-4-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-2-ReLU-2", ...
    "Decoder-Stage-3-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-3-ReLU-2", ...
    "Decoder-Stage-2-DepthConcatenation/in2");
lgraph = connectLayers(lgraph,"Encoder-Stage-4-ReLU-2", ...
    "Decoder-Stage-1-DepthConcatenation/in2");
net = dlnetwork(lgraph);

ディープ ネットワーク デザイナーアプリを使用して、ネットワーク アーキテクチャを可視化します。

deepNetworkDesigner(lgraph)

特徴抽出ネットワークの読み込み

この関数は、イメージの特徴をさまざまな層で抽出できるように、事前学習済みの VGG-16 深層ニューラル ネットワークを変更します。この多層にわたる特徴は、コンテンツ損失の計算に使用されます。

事前学習済みの VGG-16 ネットワークを取得するには、vgg16をインストールします。必要なサポート パッケージがインストールされていない場合、ダウンロード用リンクが表示されます。

vggNet = vgg16;

VGG-16 ネットワークを特徴抽出に適したネットワークにするには、"relu5_3" までの層を使用します。

vggNet = vggNet.Layers(1:31);
vggNet = dlnetwork(layerGraph(vggNet));

モデル勾配と損失関数の定義

補助関数 modelGradients は学習データのバッチの勾配と全体的な損失を計算します。この関数は、この例のサポート関数の節で定義されています。

全体的な損失は、平均絶対誤差 (MAE) 損失とコンテンツ損失の 2 つの損失の重み付き和です。MAE 損失とコンテンツ損失の影響が全体的な損失に対してほぼ均等になるように、コンテンツ損失が重み付けされます。

lossOverall=lossMAE+weightFactor*lossContent

MAE 損失は、ネットワーク予測のサンプルとターゲット イメージのサンプルの間の距離 L1 にペナルティを課します。イメージ処理アプリケーションの場合、L2 より L1 の方が適している場合が多くなります。これは、ブレの影響を低減する効果があるためです [4]。この損失は、この例のサポート関数の節で定義された補助関数 maeLoss を使用して実装されます。

コンテンツ損失は、ネットワークが高位の構造コンテンツと低位のエッジと色情報を学習するときに効果があります。損失関数は、各活性化層の予測とターゲットの間の平均二乗誤差 (MSE) の重み付き和を計算します。この損失は、この例のサポート関数の節で定義された補助関数 contentLoss を使用して実装されます。

コンテンツ損失重み係数の計算

補助関数 modelGradients は入力引数としてコンテンツ損失の重み係数を必要とします。MAE 損失と重み付きコンテンツ損失が等しくなるように、学習データのサンプル バッチの重み係数を計算します。

RAW ネットワーク入力と RGB ターゲット出力のペアで構成される学習データのバッチをプレビューします。

trainingBatch = preview(dsTrainAug);
networkInput = dlarray((trainingBatch{1,1}),"SSC");
targetOutput = dlarray((trainingBatch{1,2}),"SSC");

関数forwardを使用して、未学習 U-Net ネットワークの応答を予測します。

predictedOutput = forward(net,networkInput);

予測 RGB イメージとターゲット RGB イメージの間の MAE 損失とコンテンツ損失を計算します。

sampleMAELoss = maeLoss(predictedOutput,targetOutput);
sampleContentLoss = contentLoss(vggNet,predictedOutput,targetOutput);

重み係数を計算します。

weightContent = sampleMAELoss/sampleContentLoss;

学習オプションの指定

ADAM 最適化の特性を制御するためにカスタム学習ループ内で使用する学習オプションを定義します。学習を 20 エポック行います。

learnRate = 5e-5;
numEpochs = 20;

ネットワークの学習または事前学習済みネットワークのダウンロード

既定では、この例は補助関数 downloadTrainedNetwork を使用して、事前学習済みバージョンの RAW-to-RGB ネットワークをダウンロードします。この事前学習済みのネットワークを使用することで、学習の完了を待たずに例全体を実行できます。

ネットワークに学習させるには、次のコードで変数 doTrainingtrue に設定します。カスタム学習ループでモデルに学習させます。それぞれの反復で次を行います。

  • 関数nextを使用して、現在のミニバッチ データを読み取ります。

  • 関数dlfevalと補助関数 modelGradients を使用してモデルの勾配を評価します。

  • 関数adamupdateと勾配情報を使用してネットワーク パラメーターを更新します。

  • 反復のたびに学習の進行状況プロットを更新し、計算されたさまざまな損失を表示します。

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™、および CUDA® 対応の NVIDIA® GPU が必要です。詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。学習には NVIDIA™ Titan RTX で約 88 時間を要します。ご使用の GPU ハードウェアによっては、さらに長い時間がかかる可能性もあります。

doTraining = false;
if doTraining
    
    % Create a directory to store checkpoints
    checkpointDir = fullfile(dataDir,"checkpoints",filesep);
    if ~exist(checkpointDir,"dir")
        mkdir(checkpointDir);
    end
    
    % Initialize training plot
    [hFig,batchLine,validationLine] = initializeTrainingPlotRAWToRGB;
    
    % Initialize Adam solver state
    [averageGrad,averageSqGrad] = deal([]);
    iteration = 0;
    
    start = tic;
    for epoch = 1:numEpochs
        reset(trainingQueue);
        shuffle(trainingQueue);
        while hasdata(trainingQueue)
            [inputRAW,targetRGB] = next(trainingQueue);  
            
            [grad,loss] = dlfeval(@modelGradients, ...
                net,vggNet,inputRAW,targetRGB,weightContent);
            
            iteration = iteration + 1;
            
            [net,averageGrad,averageSqGrad] = adamupdate(net, ...
                grad,averageGrad,averageSqGrad,iteration,learnRate);
              
            updateTrainingPlotRAWToRGB(batchLine,validationLine,iteration, ...
                loss,start,epoch,validationQueue,numValImages,valBatchSize, ...
                net,vggNet,weightContent);
        end
        % Save checkpoint of network state
        save(checkpointDir+"epoch"+epoch,"net");
    end

    % Save the final network state
    modelDateTime = string(datetime("now",Format="yyyy-MM-dd-HH-mm-ss"));
    save(fullfile(dataDir,"trainedRAWToRGBNet-"+modelDateTime+".mat"),"net");    

else
    trainedNet_url = "https://ssd.mathworks.com/supportfiles"+ ...
        "/vision/data/trainedRAWToRGBNet.mat";
    downloadTrainedNetwork(trainedNet_url,dataDir);
    load(fullfile(dataDir,"trainedRAWToRGBNet.mat"));
end

イメージ画質メトリクスの計算

MSSIM や PSNR などの参照ベースの画質メトリクスは、画質の定量的測定を可能にします。パッチ済みのテスト イメージは空間的にレジストレーションされ、サイズも同じであるため、その MSSIM や PSNR を計算することができます。

minibatchqueue オブジェクトを使用して、パッチ済みイメージのテスト セットを反復処理します。

patchTestSet = combine(dsTestRAW,dsTestRGB);
testPatchQueue = minibatchqueue(patchTestSet, ...
    MiniBatchSize=16,MiniBatchFormat="SSCB");

テスト セットを反復処理し、関数multissim (Image Processing Toolbox)およびpsnr (Image Processing Toolbox)を使用して各テスト イメージの MSSIM および PSNR を計算します。メトリクスはマルチチャネル入力に対して適切に定義されていないため、各カラー チャネルのメトリクスの平均を近似値として使用して、カラー イメージの MSSIM を計算します。

totalMSSIM = 0;
totalPSNR = 0;
while hasdata(testPatchQueue)
    [inputRAW,targetRGB] = next(testPatchQueue);
    outputRGB = forward(net,inputRAW);
    targetRGB = targetRGB ./ 255; 
    mssimOut = sum(mean(multissim(outputRGB,targetRGB),3),4);
    psnrOut = sum(psnr(outputRGB,targetRGB),4);
    totalMSSIM = totalMSSIM + mssimOut;
    totalPSNR = totalPSNR + psnrOut;
end

テスト セット全体の平均 MSSIM と平均 PSNR を計算します。この結果は、平均 MSSIM については [3] の同様の U-Net アプローチと一致し、平均 PSNR については [3] の PyNet アプローチと拮抗しています。[3] と比較して、損失関数の違いおよびピクセル シャッフル アップサンプリングの使用がこれらの違いの原因であると考えられます。

numObservations = dsTestRGB.numpartitions;
meanMSSIM = totalMSSIM / numObservations
meanMSSIM = 
  1(S) × 1(S) × 1(C) × 1(B) single gpuArray dlarray

    0.8401

meanPSNR = totalPSNR / numObservations
meanPSNR = 
  1(S) × 1(S) × 1(C) × 1(B) single gpuArray dlarray

   21.0730

フルサイズ イメージでの学習済みイメージ処理パイプラインの評価

高解像度のテスト イメージを取得する際に使用した携帯電話のカメラとデジタル一眼レフ カメラのセンサーは異なるため、シーンはレジストレーションされておらず、サイズも異なります。ネットワークおよびデジタル一眼レフ カメラの ISP からの高解像度イメージの参照ベースの比較は困難です。ただし、イメージ処理の目標は見た目の良いイメージを作成することであるため、イメージの定性的な比較は有効です。

携帯電話のカメラで取得したフルサイズ RAW イメージが格納されているイメージ データストアを作成します。

testImageDir = fullfile(dataDir,"test");
testImageDirRAW = "huawei_full_resolution";
dsTestFullRAW = imageDatastore(fullfile(testImageDir,testImageDirRAW));

フルサイズ RAW テスト セット内にあるイメージ ファイルの名前を取得します。

targetFilesToInclude = extractAfter(string(dsTestFullRAW.Files), ...
    fullfile(testImageDirRAW,filesep));
targetFilesToInclude = extractBefore(targetFilesToInclude,".png");

関数transformを使用し、データをネットワークに必要な形式に変換して RAW データを前処理します。関数 transform は、補助関数 preprocessRAWDataForRAWToRGB で指定された演算を使用してデータを処理します。この補助関数は、この例にサポート ファイルとして添付されています。

dsTestFullRAW = transform(dsTestFullRAW,@preprocessRAWDataForRAWToRGB);

ハイエンドのデジタル一眼レフ カメラで取得したフルサイズ RGB テスト イメージが格納されているイメージ データストアを作成します。Zurich RAW-to-RGB データ セットは RAW イメージよりフルサイズの RGB イメージの方を多く含むため、対応する RAW イメージのある RGB イメージのみを追加します。

dsTestFullRGB = imageDatastore(fullfile(dataDir,"full_resolution","canon"));
dsTestFullRGB.Files = dsTestFullRGB.Files( ...
    contains(dsTestFullRGB.Files,targetFilesToInclude));

ターゲット RGB イメージを読み取り、最初の数個のイメージをモンタージュで表示します。

targetRGB = readall(dsTestFullRGB);
montage(targetRGB,Size=[5 2],Interpolation="bilinear")

minibatchqueue オブジェクトを使用して、フルサイズ イメージのテスト セットを反復処理します。高解像度イメージの処理に十分なメモリを搭載した GPU デバイスがある場合は、出力環境を "gpu" として指定して GPU 上で予測を実行できます。

testQueue = minibatchqueue(dsTestFullRAW,MiniBatchSize=1, ...
    MiniBatchFormat="SSCB",OutputEnvironment="cpu");

フルサイズの RAW テスト イメージごとに、ネットワークでforwardを呼び出して出力 RGB イメージを予測します。

outputSize = 2*size(preview(dsTestFullRAW),[1 2]);
outputImages = zeros([outputSize,3,dsTestFullRAW.numpartitions],"uint8");

idx = 1;
while hasdata(testQueue)
    inputRAW = next(testQueue);
    rgbOut = forward(net,inputRAW);
    rgbOut = gather(extractdata(rgbOut));    
    outputImages(:,:,:,idx) = im2uint8(rgbOut);
    idx = idx+1;
end

モンタージュ ビューを確認して出力全体を把握します。類似した特性をもつ、見た目の良いイメージが出力されます。

montage(outputImages,Size=[5 2],Interpolation="bilinear")

ターゲット RGB イメージを、ネットワークで予測された対応するイメージと比較します。デジタル一眼レフ カメラのターゲット イメージより彩度の高い色が出力されます。シンプルな U-Net アーキテクチャからの色はデジタル一眼レフ カメラのターゲットと同じではありませんが、多くの場合、定性的に美しいイメージとなります。

imgIdx = 1;
imTarget = targetRGB{imgIdx};
imPredicted = outputImages(:,:,:,imgIdx);
montage({imTarget,imPredicted},Interpolation="bilinear")

RAW-to-RGB ネットワークの性能を改善するために、ネットワーク アーキテクチャは色およびコントラストを記述するグローバルな特徴からの複数のスケールを使用して、局所化された空間的特徴の詳細を学習します [3]。

サポート関数

モデル勾配関数

補助関数 modelGradients は勾配および全体的な損失を計算します。勾配情報は、モデル内の学習可能なパラメーターごとの層、パラメーター名、および値を含む table として返されます。

function [gradients,loss] = modelGradients(dlnet,vggNet,X,T,weightContent)
    Y = forward(dlnet,X);
    lossMAE = maeLoss(Y,T);
    lossContent = contentLoss(vggNet,Y,T);
    loss = lossMAE + weightContent.*lossContent;
    gradients = dlgradient(loss,dlnet.Learnables);
end

平均絶対誤差損失関数

補助関数 maeLoss はネットワーク予測 Y とターゲット イメージ T の平均絶対誤差を計算します。

function loss = maeLoss(Y,T)
    loss = mean(abs(Y-T),"all");
end

コンテンツ損失関数

補助関数 contentLoss は、各活性化層のネットワーク予測 Y とターゲット イメージ T の間の MSE の重み付き和を計算します。補助関数 contentLoss は、補助関数 mseLoss を使用して各活性化層の MSE を計算します。各活性化層の損失の影響が全体的なコンテンツ損失に対してほぼ均等になるように、重みが選択されます。

function loss = contentLoss(net,Y,T)

    layers = ["relu1_1","relu1_2","relu2_1","relu2_2", ...
        "relu3_1","relu3_2","relu3_3","relu4_1"];
    [T1,T2,T3,T4,T5,T6,T7,T8] = forward(net,T,Outputs=layers);
    [X1,X2,X3,X4,X5,X6,X7,X8] = forward(net,Y,Outputs=layers);
    
    l1 = mseLoss(X1,T1);
    l2 = mseLoss(X2,T2);
    l3 = mseLoss(X3,T3);
    l4 = mseLoss(X4,T4);
    l5 = mseLoss(X5,T5);
    l6 = mseLoss(X6,T6);
    l7 = mseLoss(X7,T7);
    l8 = mseLoss(X8,T8);
    
    layerLosses = [l1 l2 l3 l4 l5 l6 l7 l8];
    weights = [1 0.0449 0.0107 0.0023 6.9445e-04 2.0787e-04 2.0118e-04 6.4759e-04];
    loss = sum(layerLosses.*weights);  
end

平均二乗誤差損失関数

補助関数 mseLoss はネットワーク予測 Y とターゲット イメージ T の MSE を計算します。

function loss = mseLoss(Y,T)
    loss = mean((Y-T).^2,"all");
end

参考文献

1) Sumner, Rob. "Processing RAW Images in MATLAB". May 19, 2014. https://rcsumner.net/raw_guide/RAWguide.pdf.

2) Chen, Chen, Qifeng Chen, Jia Xu, and Vladlen Koltun. "Learning to See in the Dark." ArXiv:1805.01934 [Cs], May 4, 2018. http://arxiv.org/abs/1805.01934.

3) Ignatov, Andrey, Luc Van Gool, and Radu Timofte. "Replacing Mobile Camera ISP with a Single Deep Learning Model." ArXiv:2002.05509 [Cs, Eess], February 13, 2020. http://arxiv.org/abs/2002.05509. プロジェクトのウェブサイト

4) Zhao, Hang, Orazio Gallo, Iuri Frosio, and Jan Kautz. "Loss Functions for Neural Networks for Image Processing." ArXiv:1511.08861 [Cs], April 20, 2018. http://arxiv.org/abs/1511.08861.

5) Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. "Perceptual Losses for Real-Time Style Transfer and Super-Resolution." ArXiv:1603.08155 [Cs], March 26, 2016. http://arxiv.org/abs/1603.08155.

6) Shi, Wenzhe, Jose Caballero, Ferenc Huszár, Johannes Totz, Andrew P. Aitken, Rob Bishop, Daniel Rueckert, and Zehan Wang. "Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network." ArXiv:1609.05158 [Cs, Stat], September 23, 2016. http://arxiv.org/abs/1609.05158.

参考

| | | |

関連する例

詳細