Main Content

encoderDecoderNetwork

符号化器-復号化器ネットワークの作成

R2021a 以降

説明

net = encoderDecoderNetwork(inputSize,encoder,decoder) は符号化器ネットワークと復号化器ネットワークを接続し、符号化器-復号化器ネットワーク net を作成します。

この関数には Deep Learning Toolbox™ が必要です。

net = encoderDecoderNetwork(inputSize,encoder,decoder,Name,Value) は、名前と値の引数を使用して符号化器-復号化器ネットワークの特性を変更します。

すべて折りたたむ

4 つの符号化器ブロックで構成される符号化器モジュールを作成します。

encoderBlock = @(block) [
    convolution2dLayer(3,2^(5+block),"Padding",'same')
    reluLayer
    convolution2dLayer(3,2^(5+block),"Padding",'same')
    reluLayer
    maxPooling2dLayer(2,"Stride",2)];
encoder = blockedNetwork(encoderBlock,4,"NamePrefix","encoder_");

4 つの復号化器ブロックで構成される復号化器モジュールを作成します。

decoderBlock = @(block) [
    transposedConv2dLayer(2,2^(10-block),'Stride',2)
    convolution2dLayer(3,2^(10-block),"Padding",'same')
    reluLayer
    convolution2dLayer(3,2^(10-block),"Padding",'same')
    reluLayer];
decoder = blockedNetwork(decoderBlock,4,"NamePrefix","decoder_");

ブリッジ層を作成します。

bridge = [
    convolution2dLayer(3,1024,"Padding",'same')
    reluLayer
    convolution2dLayer(3,1024,"Padding",'same')
    reluLayer
    dropoutLayer(0.5)];            

ネットワーク入力サイズを指定します。

inputSize = [224 224 3];

符号化器モジュール、ブリッジ、および復号化器モジュールを接続し、スキップ接続を追加して、U-Net ネットワークを作成します。

unet = encoderDecoderNetwork(inputSize,encoder,decoder, ...
    "OutputChannels",3, ...
    "SkipConnections","concatenate", ...
    "LatentNetwork",bridge)
unet = 
  dlnetwork with properties:

         Layers: [55x1 nnet.cnn.layer.Layer]
    Connections: [62x2 table]
     Learnables: [46x3 table]
          State: [0x3 table]
     InputNames: {'encoderImageInputLayer'}
    OutputNames: {'encoderDecoderFinalConvLayer'}
    Initialized: 1

  View summary with summary.

ネットワークを表示します。

analyzeNetwork(unet)

事前学習済みの GoogLeNet ネットワークから、4 つのダウンサンプリング演算をもつ GAN 符号化器ネットワークを作成します。

depth = 4;
[encoder,outputNames] = pretrainedEncoderNetwork('googlenet',depth);

符号化器ネットワークの入力サイズを決定します。

inputSize = encoder.Layers(1).InputSize;

まずサンプル データ入力を作成し、活性化を返す forward を呼び出すことで、符号化器ネットワーク内の活性化層の出力サイズを決定します。

exampleInput = dlarray(zeros(inputSize),'SSC');
exampleOutput = cell(1,length(outputNames));
[exampleOutput{:}] = forward(encoder,exampleInput,'Outputs',outputNames);

復号化器ブロック内のチャネル数を、各活性化の 3 番目のチャネルの長さとして決定します。

numChannels = cellfun(@(x) size(extractdata(x),3),exampleOutput);
numChannels = fliplr(numChannels(1:end-1));

1 つの復号化器ブロックの層の配列を作成する関数を定義します。

decoderBlock = @(block) [
    transposedConv2dLayer(2,numChannels(block),'Stride',2)
    convolution2dLayer(3,numChannels(block),'Padding','same')
    reluLayer
    convolution2dLayer(3,numChannels(block),'Padding','same')
    reluLayer];

符号化器モジュール内のダウンサンプリングされたブロックの数と同じ数のアップサンプリング ブロックをもつ復号化器モジュールを作成します。

decoder = blockedNetwork(decoderBlock,depth);

符号化器モジュールと復号化器モジュールを接続し、スキップ接続を追加して、U-Net ネットワークを作成します。

net = encoderDecoderNetwork([224 224 3],encoder,decoder, ...
   'OutputChannels',3,'SkipConnections','concatenate')
net = 
  dlnetwork with properties:

         Layers: [139x1 nnet.cnn.layer.Layer]
    Connections: [167x2 table]
     Learnables: [116x3 table]
          State: [0x3 table]
     InputNames: {'data'}
    OutputNames: {'encoderDecoderFinalConvLayer'}
    Initialized: 1

  View summary with summary.

ネットワークを表示します。

analyzeNetwork(net)

入力引数

すべて折りたたむ

ネットワーク入力サイズ。正の整数の 3 要素ベクトルとして指定します。inputSize の形式は [H W C] で、"H" は高さ、"W" は幅、"C" はチャネル数です。

例: [28 28 3] は 3 チャネル イメージに 28 × 28 ピクセルの入力サイズを指定します。

符号化器ネットワーク。dlnetwork (Deep Learning Toolbox) オブジェクトとして指定します。

復号化器ネットワーク。dlnetwork (Deep Learning Toolbox) オブジェクトとして指定します。このネットワークには 1 つの入力と 1 つの出力がなければなりません。

名前と値の引数

オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。

R2021a より前では、コンマを使用して名前と値をそれぞれ区切り、Name を引用符で囲みます。

例: 'SkipConnections',"concatenate" は、符号化器ネットワークと復号化器ネットワークの間のスキップ接続のタイプを連結として指定します。

符号化器と復号化器を接続するネットワーク。layer または layer の配列として指定します。

復号化器の出力に接続されたネットワーク。layer または layer の配列として指定します。'OutputChannels' 引数を指定した場合、最終的なネットワークは復号化器の最後の 1 行 1 列の畳み込み層の後に接続されます。

復号化器ネットワークの出力チャネルの数。正の整数として指定します。この引数を指定した場合、復号化器の最後の層では指定されたチャネル数で 1 行 1 列の畳み込み演算を実行します。

スキップ接続により活性化がマージされる符号化器/復号化器層のペアの名前。次のいずれかの値を指定します。

  • "auto" — 関数 encoderDecoderNetwork が符号化器/復号化器層のペアの名前を自動的に決定します。

  • M 行 2 列 string 配列 — 最初の列は符号化器層の名前、2 番目の列は各復号化器層の名前となります。

'SkipConnections' 引数を "none" に指定した場合、関数 encoderDecoderNetwork は 'SkipConnectionNames' の値を無視します。

データ型: char | string

符号化器ネットワークと復号化器ネットワークの間のスキップ接続タイプ。"none""auto"、または "concatenate" として指定します。

データ型: char | string

出力引数

すべて折りたたむ

符号化器/復号化器ネットワーク。dlnetwork (Deep Learning Toolbox) オブジェクトとして返されます。

バージョン履歴

R2021a で導入