encoderDecoderNetwork
構文
説明
例
符号化器と復号化器のブロックからの U-Net ネットワークの作成
4 つの符号化器ブロックで構成される符号化器モジュールを作成します。
encoderBlock = @(block) [ convolution2dLayer(3,2^(5+block),"Padding",'same') reluLayer convolution2dLayer(3,2^(5+block),"Padding",'same') reluLayer maxPooling2dLayer(2,"Stride",2)]; encoder = blockedNetwork(encoderBlock,4,"NamePrefix","encoder_");
4 つの復号化器ブロックで構成される復号化器モジュールを作成します。
decoderBlock = @(block) [ transposedConv2dLayer(2,2^(10-block),'Stride',2) convolution2dLayer(3,2^(10-block),"Padding",'same') reluLayer convolution2dLayer(3,2^(10-block),"Padding",'same') reluLayer]; decoder = blockedNetwork(decoderBlock,4,"NamePrefix","decoder_");
ブリッジ層を作成します。
bridge = [ convolution2dLayer(3,1024,"Padding",'same') reluLayer convolution2dLayer(3,1024,"Padding",'same') reluLayer dropoutLayer(0.5)];
ネットワーク入力サイズを指定します。
inputSize = [224 224 3];
符号化器モジュール、ブリッジ、および復号化器モジュールを接続し、スキップ接続を追加して、U-Net ネットワークを作成します。
unet = encoderDecoderNetwork(inputSize,encoder,decoder, ... "OutputChannels",3, ... "SkipConnections","concatenate", ... "LatentNetwork",bridge)
unet = dlnetwork with properties: Layers: [55x1 nnet.cnn.layer.Layer] Connections: [62x2 table] Learnables: [46x3 table] State: [0x3 table] InputNames: {'encoderImageInputLayer'} OutputNames: {'encoderDecoderFinalConvLayer'} Initialized: 1 View summary with summary.
ネットワークを表示します。
analyzeNetwork(unet)
事前学習済みの GoogLeNet からの U-Net の作成
事前学習済みの GoogLeNet ネットワークから、4 つのダウンサンプリング演算をもつ GAN 符号化器ネットワークを作成します。
depth = 4;
[encoder,outputNames] = pretrainedEncoderNetwork('googlenet',depth);
符号化器ネットワークの入力サイズを決定します。
inputSize = encoder.Layers(1).InputSize;
まずサンプル データ入力を作成し、活性化を返す forward
を呼び出すことで、符号化器ネットワーク内の活性化層の出力サイズを決定します。
exampleInput = dlarray(zeros(inputSize),'SSC'); exampleOutput = cell(1,length(outputNames)); [exampleOutput{:}] = forward(encoder,exampleInput,'Outputs',outputNames);
復号化器ブロック内のチャネル数を、各活性化の 3 番目のチャネルの長さとして決定します。
numChannels = cellfun(@(x) size(extractdata(x),3),exampleOutput); numChannels = fliplr(numChannels(1:end-1));
1 つの復号化器ブロックの層の配列を作成する関数を定義します。
decoderBlock = @(block) [ transposedConv2dLayer(2,numChannels(block),'Stride',2) convolution2dLayer(3,numChannels(block),'Padding','same') reluLayer convolution2dLayer(3,numChannels(block),'Padding','same') reluLayer];
符号化器モジュール内のダウンサンプリングされたブロックの数と同じ数のアップサンプリング ブロックをもつ復号化器モジュールを作成します。
decoder = blockedNetwork(decoderBlock,depth);
符号化器モジュールと復号化器モジュールを接続し、スキップ接続を追加して、U-Net ネットワークを作成します。
net = encoderDecoderNetwork([224 224 3],encoder,decoder, ... 'OutputChannels',3,'SkipConnections','concatenate')
net = dlnetwork with properties: Layers: [139x1 nnet.cnn.layer.Layer] Connections: [167x2 table] Learnables: [116x3 table] State: [0x3 table] InputNames: {'data'} OutputNames: {'encoderDecoderFinalConvLayer'} Initialized: 1 View summary with summary.
ネットワークを表示します。
analyzeNetwork(net)
入力引数
inputSize
— ネットワーク入力サイズ
正の整数の 3 要素ベクトル
ネットワーク入力サイズ。正の整数の 3 要素ベクトルとして指定します。inputSize
の形式は [H W C] で、"H" は高さ、"W" は幅、"C" はチャネル数です。
例: [28 28 3]
は 3 チャネル イメージに 28 × 28 ピクセルの入力サイズを指定します。
encoder
— 符号化器ネットワーク
dlnetwork
オブジェクト
符号化器ネットワーク。dlnetwork
(Deep Learning Toolbox) オブジェクトとして指定します。
decoder
— 復号化器ネットワーク
dlnetwork
オブジェクト
復号化器ネットワーク。dlnetwork
(Deep Learning Toolbox) オブジェクトとして指定します。このネットワークには 1 つの入力と 1 つの出力がなければなりません。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN
として指定します。ここで、Name
は引数名で、Value
は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
R2021a より前では、コンマを使用して名前と値の各ペアを区切り、Name
を引用符で囲みます。
例: 'SkipConnections',"concatenate"
は、符号化器ネットワークと復号化器ネットワークの間のスキップ接続のタイプを連結として指定します。
LatentNetwork
— 符号化器と復号化器を接続するネットワーク
[]
(既定値) | layer オブジェクト | layer オブジェクトの配列
符号化器と復号化器を接続するネットワーク。layer または layer の配列として指定します。
FinalNetwork
— 復号化器の出力に接続されたネットワーク
[]
(既定値) | layer オブジェクト | layer オブジェクトの配列
復号化器の出力に接続されたネットワーク。layer または layer の配列として指定します。'OutputChannels
' 引数を指定した場合、最終的なネットワークは復号化器の最後の 1 行 1 列の畳み込み層の後に接続されます。
OutputChannels
— 出力チャネル数
[]
(既定値) | 正の整数
復号化器ネットワークの出力チャネルの数。正の整数として指定します。この引数を指定した場合、復号化器の最後の層では指定されたチャネル数で 1 行 1 列の畳み込み演算を実行します。
SkipConnectionNames
— 符号化器/復号化器層のペアの名前
"auto"
(既定値) | M 行 2 列の string 配列
スキップ接続により活性化がマージされる符号化器/復号化器層のペアの名前。次のいずれかの値を指定します。
"auto"
— 関数encoderDecoderNetwork
が符号化器/復号化器層のペアの名前を自動的に決定します。M 行 2 列 string 配列 — 最初の列は符号化器層の名前、2 番目の列は各復号化器層の名前となります。
'SkipConnections
' 引数を "none"
に指定した場合、関数 encoderDecoderNetwork
は 'SkipConnectionNames
' の値を無視します。
データ型: char
| string
SkipConnections
— スキップ接続のタイプ
"none"
(既定値) | "auto"
| "concatenate"
符号化器ネットワークと復号化器ネットワークの間のスキップ接続タイプ。"none"
、"auto"
、または "concatenate"
として指定します。
データ型: char
| string
出力引数
バージョン履歴
R2021a で導入
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)