Main Content

拡散を使用したイメージの生成

R2023b 以降

この例では、ノイズ除去拡散確率モデル (DDPM) [1] を使用して新しいイメージを生成する方法を示します。

拡散モデルは、次の 2 つの主なステップから成る学習プロセスによってイメージの生成を学習します。

  • 順方向拡散 — モデルは、鮮明なイメージを入力として受け取り、それに繰り返しノイズを追加します。モデルは、ノイズが追加されたときのイメージの変化を観察することで、データ内の統計パターンと依存関係を捉えることを学習します。

  • 逆拡散 — モデルは、ノイズの多いバージョンから元の鮮明なイメージを再構築することを目指します。モデルは、順方向拡散処理によって各ステップで追加されたノイズを把握しているため、拡散処理を逆方向に繰り返し実行することができます。モデルは、追加されたノイズを予測し、それをイメージから減算することで、徐々にノイズを除去して元の鮮明なイメージを復元します。モデルは、ノイズを正確に予測することでイメージのノイズ除去を学習します。

ネットワークに学習させたら、ネットワークによって予測された追加ノイズをランダム ノイズから順次削除することで、新しいイメージを生成できます。

データの読み込み

関数 imageDatastore を使用して数字データをイメージ データストアとして読み込み、イメージ データが格納されているフォルダーを指定します。

dataFolder = fullfile(tempdir,"DigitsData");
unzip("DigitsData.zip",dataFolder);
imds = imageDatastore(dataFolder,IncludeSubfolders=true);

拡張イメージ データストアを使用し、イメージのサイズを 32×32 ピクセルに変更します。

imgSize = [32 32];
audsImds = augmentedImageDatastore(imgSize,imds);

順方向拡散処理

順方向拡散 (ノイズ付加) 処理では、結果とランダム ノイズの区別がつかなくなるまで、イメージにガウス ノイズを繰り返し追加します。各ノイズ付加ステップ t で、次の方程式を使用してガウス ノイズ ε を追加します。

xt+1=1-βtxt+βtε,

ここで、xt はノイズを含む t 番目のイメージで、βt は分散スケジュールです。分散スケジュールは、モデルがイメージにノイズを追加する方法を定義します。この例では、βmin=0.0001 から βmax=0.02 まで、t と共に線形に増加する 500 ステップの分散スケジュールを定義します。

numNoiseSteps = 500;
betaMin = 1e-4;
betaMax = 0.02;
varianceSchedule = linspace(betaMin,betaMax,numNoiseSteps);

テスト イメージに順方向拡散処理を適用します。拡張イメージ データストアから単一のイメージを抽出し、ピクセル値が [-1 1] の範囲になるように再スケーリングします。

img = read(audsImds);
img = img{1,1};
img = img{:};
img = rescale(img,-1,1);

この例の最後に定義されている補助関数 applyNoiseToImage を使用し、テスト イメージのノイズの量を徐々に増やします。中間点の出力を表示するには、numNoiseSteps ステップのノイズをイメージに適用します。

拡散処理を実行すると、左側の鮮明なイメージを起点として最終的なノイズ付加ステップまで処理が実行され、イメージとランダム ノイズの区別がつかなくなるまでノイズが追加されます。

tiledlayout("flow");
nexttile
imshow(img,[])
title("t = 0");
for i = 1:5
    nexttile
    noise = randn(size(img),like=img);
    noiseStepsToApply = numNoiseSteps/5 * i;
    noisyImg = applyNoiseToImage(img,noise,noiseStepsToApply,varianceSchedule);

    % Extract the data from the dlarray.
    noisyImg = extractdata(noisyImg);
    imshow(noisyImg,[])
    title("t = " + string(noiseStepsToApply));
end

Figure contains 6 axes objects. Hidden axes object 1 with title t = 0 contains an object of type image. Hidden axes object 2 with title t = 100 contains an object of type image. Hidden axes object 3 with title t = 200 contains an object of type image. Hidden axes object 4 with title t = 300 contains an object of type image. Hidden axes object 5 with title t = 400 contains an object of type image. Hidden axes object 6 with title t = 500 contains an object of type image.

各ステップで追加されたノイズがわかっていれば、処理を正確に逆転させることで元の鮮明なイメージを再現できます。次に、イメージの正確な統計分布がわかっている場合は、ランダム ノイズに対して逆方向の処理を実行し、学習データの分布から導出されたイメージを作成できます。

データセットの正確な統計分布は、解析的に計算するには複雑すぎます。ただし、この例では、深層学習ネットワークに学習させてそれを近似的に得る方法を示します。ネットワークに学習させたら、ネットワークによって予測された追加ノイズをランダム ノイズから順次削除 (ノイズ除去) することで、新しいイメージを生成できます。

ネットワークの定義

拡散ネットワークは、イメージ入力と、イメージに追加されるノイズ ステップの数を表すスカラー特徴入力の 2 つの入力を受け取ります。ネットワークは、イメージに追加されたものであるとモデルが予測したノイズを表す単一のイメージを出力します。

ネットワーク アーキテクチャは、[1] で使用されているネットワークに基づいています。ネットワークは、繰り返し実行される次の 2 種類の処理単位を中心に構築されています。

  • 残差ブロック

  • 注意ブロック

残差ブロックは、スキップ接続を使用して畳み込み演算を実行します。

注意ブロックは、スキップ接続を使用して自己注意演算を実行します。

このネットワークは、U-Net [2] に似た符号化器-復号化器構造を持っています。入力イメージを繰り返しダウンサンプリングして解像度を落としてから処理を行い、その後、繰り返しアップサンプリングして元のサイズに戻します。ネットワークは、各解像度で残差ブロックを使用して入力を処理します。ネットワークは、自己注意層を含む注意ブロックを使用し、16×16 ピクセルおよび 8×8 ピクセルの解像度で入力イメージの各部分間の相関関係を学習します。自己注意層を使用するために、ネットワークはカスタムの SpatialFlattenLayer を使用し、単一の空間次元となるように活性化の形状を変更します。ネットワークは、注意を適用した後、カスタムの SpatialUnflattenLayer を使用して活性化の形状を 2 次元イメージに戻します。この例では、両方の層がサポート ファイルとして含まれています。

ネットワークは、正弦波位置符号化を使用してノイズ ステップ入力を符号化し、それを各残差ブロックに追加して、追加されたさまざまな量のノイズを区別できるように学習を行います。ネットワークは、深さ連結を使用し、各ダウンサンプリング残差ブロックの出力を、それらを補完するアップサンプリング残差ブロックの出力と結合します。

拡散ネットワークを作成するには、この例にサポート ファイルとして添付されている関数 createDiffusionNetwork を使用します。この例では、1 つのカラー チャネルをもつグレースケール イメージを使用します。RGB イメージでネットワークに学習させるには、numInputChannels の値を 3 に変更します。

numInputChannels = 1;
net = createDiffusionNetwork(numInputChannels)
net = 
  dlnetwork with properties:

         Layers: [205×1 nnet.cnn.layer.Layer]
    Connections: [244×2 table]
     Learnables: [274×3 table]
          State: [0×3 table]
     InputNames: {'imageinput'  'input'}
    OutputNames: {'conv_29'}
    Initialized: 1

  View summary with summary.

拡散ネットワークには全体で 205 個の層があります。ネットワーク アーキテクチャの全体を確認するには、ディープ ネットワーク デザイナーを使用します。

deepNetworkDesigner(net)

モデル損失関数の定義

この例のモデル損失関数セクションにリストされている関数 modelLoss を作成します。この関数は、DDPM ネットワーク、さまざまな量のノイズが適用されたノイズの多いイメージのミニバッチ、各イメージに追加されたノイズの量に対応するノイズ ステップ値のミニバッチ、および対応するターゲット ノイズ値を入力として受け取ります。

学習オプションの指定

ミニバッチ サイズを 128 として 50 エポック学習させます。

miniBatchSize = 128;
numEpochs = 50;

ADAM 最適化のオプションを指定します。

  • 学習率 0.0005

  • 勾配の減衰係数 0.9

  • 2 乗勾配の減衰係数 0.9999

learnRate = 0.0005;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.9999;

モデルの学習

指定されたノイズ ステップでイメージに追加されたノイズを予測するように DDPM に学習させます。モデルに学習させるには、各エポックでイメージをシャッフルしてミニバッチをループ処理します。各ミニバッチでは、各イメージに対して次の処理が行われます。

  1. 1 から numNoiseSteps. の間のランダムなノイズ ステップ数 N を選択します。

  2. 標準正規分布から導出され、サイズがイメージと同じであるランダム ノイズの行列を生成します。これは、ネットワークが予測を学習する際のターゲットとして使用します。

  3. ターゲット ノイズをイメージに N 回適用します。

次に、ネットワーク出力とイメージに追加された実際のノイズとの間のモデル損失関数、および損失関数の勾配を計算します。勾配降下法を使用し、損失関数の勾配に沿ってネットワークの学習可能なパラメーターを更新します。

モデルの学習は大量の計算を必要とする処理であるため、数時間かかる場合があります。この例の実行時間を節約するには、doTrainingfalse に設定して事前学習済みのネットワークを読み込みます。自分でネットワークに学習させるには、doTrainingtrue に設定します。

doTraining = false;

minibatchqueue オブジェクトを使用し、イメージのミニバッチの処理および管理を行います。各ミニバッチで、この例の最後に定義されているカスタム ミニバッチ前処理関数 preprocessMinibatch を使用し、ピクセル値が [-1 1] の範囲になるようにイメージを再スケーリングします。

mbq = minibatchqueue(audsImds, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB", ...
    PartialMiniBatch="discard");

Adam 最適化のパラメーターを初期化します。

averageGrad = [];
averageSqGrad = [];

モデルのパフォーマンスを追跡するには、trainingProgressMonitor オブジェクトを使用します。モニター用に合計反復回数を計算します。

numObservationsTrain = numel(imds.Files);
numIterationsPerEpoch = ceil(numObservationsTrain/miniBatchSize);
numIterations = numEpochs*numIterationsPerEpoch;

TrainingProgressMonitor オブジェクトを初期化します。監視オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

if doTraining
    monitor = trainingProgressMonitor(...
        Metrics="Loss", ...
        Info=["Epoch","Iteration"], ...
        XLabel="Iteration");
end

ネットワークに学習をさせます。

if doTraining
    iteration = 0;
    epoch = 0;

    while epoch < numEpochs && ~monitor.Stop
        epoch = epoch + 1;
        shuffle(mbq);

        while hasdata(mbq) && ~monitor.Stop
            iteration = iteration + 1;

            img = next(mbq);

            % Generate random noise.
            targetNoise = randn(size(img),Like=img);

            % Generate a random noise step.
            noiseStep = dlarray(randi(numNoiseSteps,[1 miniBatchSize],Like=img),"CB");

            % Apply noise to the image.
            noisyImage = applyNoiseToImage(img,targetNoise,noiseStep,varianceSchedule);

            % Compute loss.
            [loss,gradients] = dlfeval(@modelLoss,net,noisyImage,noiseStep,targetNoise);

            % Update model.
            [net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad,averageSqGrad,iteration, ...
                learnRate,gradientDecayFactor,squaredGradientDecayFactor);

            % Record metrics.
            recordMetrics(monitor,iteration,Loss=loss);
            updateInfo(monitor,Epoch=epoch,Iteration=iteration);
            monitor.Progress = 100 * iteration/numIterations;
        end

        % Generate and display a batch of generated images.
        numImages = 16;
        displayFrequency = numNoiseSteps + 1;
        generateAndDisplayImages(net,varianceSchedule,imgSize,numImages,numInputChannels,displayFrequency);
    end
else
    % If doTraining is false, download and extract the pretrained network from the MathWorks website.
    pretrainedNetZipFile = matlab.internal.examples.downloadSupportFile('nnet','data/TrainedDiffusionNetwork.zip');
    unzip(pretrainedNetZipFile);
    load("DiffusionNetworkTrained/DiffusionNetworkTrained.mat");
end

新しいイメージの生成

サポート関数 generateImages を使用し、学習済みのネットワークを使用してイメージのバッチを生成します。ノイズ除去処理の実行状況を示すために、10 ノイズ ステップごとに中間イメージを表示します。

この関数は、ランダムなイメージを起点として拡散処理を逆方向に繰り返し実行します。また、ネットワークを使用して各ステップでノイズを予測し、それを除去します。この関数はネットワークの予測を numNoiseSteps 回評価するため、数分かかる場合があります。完了すると、ネットワークは新しいイメージを生成します。

numImages = 9;
displayFrequency = 10;
figure
generatedImages = generateAndDisplayImages(net,varianceSchedule,imgSize,numImages,numInputChannels,displayFrequency);

Figure contains 9 axes objects. Hidden axes object 1 contains an object of type image. Hidden axes object 2 contains an object of type image. Hidden axes object 3 contains an object of type image. Hidden axes object 4 contains an object of type image. Hidden axes object 5 contains an object of type image. Hidden axes object 6 contains an object of type image. Hidden axes object 7 contains an object of type image. Hidden axes object 8 contains an object of type image. Hidden axes object 9 contains an object of type image.

参考文献

  1. Ho, Jonathan, Ajay Jain, and Pieter Abbeel. “Denoising Diffusion Probabilistic Models.” In Advances in Neural Information Processing Systems, 33:6840–51. Curran Associates, Inc., 2020. https://proceedings.neurips.cc/paper_files/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf

  2. Ronneberger, O., P. Fischer, and T. Brox. "U-Net: Convolutional Networks for Biomedical Image Segmentation." Medical Image Computing and Computer-Assisted Intervention (MICCAI). Vol. 9351, 2015, pp. 234–241.

サポート関数

順方向ノイズ付加関数

順方向ノイズ付加関数 applyNoiseToImage は、イメージ img、ガウス ノイズの行列 noiseToApply、イメージに適用するノイズ ステップの数を示す整数 noiseStep、および長さ numNoiseSteps のベクトルで構成される分散スケジュール varianceSchedule を入力として受け取ります。

関数は、各ステップで次の式を使用し、標準正規分布から導出されたランダム ノイズ ε をイメージ xt に適用します。

xt+1=1-βtxt+βtε.

ノイズの多いイメージの生成処理を高速化するには、次の方程式を使用して複数のノイズ付加ステップを一度に適用します。

xt=αtx0+1-αtε,

ここで、x0 は元のイメージで、αt=i=1t(1-βi) です。

function noisyImg = applyNoiseToImage(img,noiseToApply,noiseStep,varianceSchedule)
alphaBar = cumprod(1 - varianceSchedule);
alphaBarT = dlarray(alphaBar(noiseStep),"CBSS");

noisyImg = sqrt(alphaBarT).*img + sqrt(1 - alphaBarT).*noiseToApply;
end

モデル損失関数

モデル損失関数は、DDPM ネットワーク net、ノイズの多い入力イメージ X のミニバッチ、各イメージに追加されたノイズの量に対応するノイズ ステップ値 Y のミニバッチ、および対応するターゲット ノイズ値 T を入力として受け取ります。この関数は、net 内の学習可能なパラメーターに関する損失および損失の勾配を返します。

function [loss, gradients] = modelLoss(net,X,Y,T)
% Forward data through the network.
noisePrediction = forward(net,X,Y);

% Compute mean squared error loss between predicted noise and target.
loss = mse(noisePrediction,T);

gradients = dlgradient(loss,net.Learnables);
end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。

  2. イメージの範囲が [-1,1] となるように再スケーリングします。

function X = preprocessMiniBatch(data)
% Concatenate mini-batch.
X = cat(4,data{:});

% Rescale the images so that the pixel values are in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);
end

イメージ生成関数

generateImages 関数は、学習済みの DDPM ネットワーク net、分散スケジュール varianceSchedule、イメージ サイズ imageSize、および目的のイメージの数 numImages を入力として受け取ります。この関数は、次の逆拡散処理を使用して生成された numImages 個のイメージのバッチを返します。

  1. 標準正規分布から導出され、サイズが目的のイメージと同じであるガウス ランダム ノイズで構成されるノイズの多いイメージを生成します。

  2. ネットワークを使用して、追加されたノイズ εpred を予測します。

  3. 次の式を使用して、イメージからこのノイズを除去します。

xt-1=11-βt(xt-βt1-αtεpred)+σtz,

ここで、事後分散は σt2=1-αt-11-αtβt であり、z はガウス ランダム ノイズの別の行列です。

4.numNoiseSteps から 1 までカウントダウンしながら、各ノイズ ステップに対して手順 2 ~ 3 を繰り返します。

function images = generateAndDisplayImages(net,varianceSchedule,imageSize,numImages,numChannels,displayFrequency)
% Generate random noise.
images = randn([imageSize numChannels numImages]);

% Compute variance schedule parameters.
alphaBar = cumprod(1 - varianceSchedule);
alphaBarPrev = [1 alphaBar(1:end-1)];
posteriorVariance = varianceSchedule.*(1 - alphaBarPrev)./(1 - alphaBar);

% Reverse the diffusion process.
numNoiseSteps = length(varianceSchedule);

for noiseStep = numNoiseSteps:-1:1
    if noiseStep ~= 1
        z = randn([imageSize,numChannels,numImages]);
    else
        z = zeros([imageSize,numChannels,numImages]);
    end

    % Predict the noise using the network.
    predictedNoise = predict(net,images,noiseStep);

    sqrtOneMinusBeta = sqrt(1 - varianceSchedule(noiseStep));
    addedNoise = sqrt(posteriorVariance(noiseStep))*z;
    predNoise = varianceSchedule(noiseStep)*predictedNoise/sqrt(1 - alphaBar(noiseStep));

    images = 1/sqrtOneMinusBeta*(images - predNoise) + addedNoise;

    % Display intermediate images.
    if mod(noiseStep,displayFrequency) == 0
        tLay = tiledlayout("flow");
        title(tLay,"t = "+ noiseStep)
        for ii = 1:numImages
            nexttile
            imshow(images(:,:,:,ii),[])
        end
        drawnow
    end
end

% Display final images.
tLay = tiledlayout("flow");
title(tLay,"t = 0")
for ii = 1:numImages
    nexttile
    imshow(images(:,:,:,ii),[])
end
end

参考

| | | | |

関連するトピック