blockedNetwork
説明
例
U-Net スタイルの符号化器の作成
層の配列を作成する関数を定義します。最初のブロックは、畳み込み層に 32 個のフィルターをもちます。連続するブロックごとにフィルターの数が倍になります。
unetBlock = @(block) [
convolution2dLayer(3,2^(5+block))
reluLayer
convolution2dLayer(3,2^(5+block))
reluLayer
maxPooling2dLayer(2,"Stride",2)];
複数の層を含む 4 つの反復するブロックからなるネットワークを作成します。ネットワーク内のすべての層の名前に接頭辞 "encoder_" を追加します。
net = blockedNetwork(unetBlock,4,"NamePrefix","encoder_")
net = dlnetwork with properties: Layers: [20x1 nnet.cnn.layer.Layer] Connections: [19x2 table] Learnables: [16x3 table] State: [0x3 table] InputNames: {'encoder_Block1Layer1'} OutputNames: {'encoder_Block4Layer5'} Initialized: 0 View summary with summary.
サイズが [224 224 3] の入力用のネットワークの重みを初期化します。
net = initialize(net,dlarray(zeros(224,224,3),"SSC"));
ネットワークを表示します。
analyzeNetwork(net)
事前学習済みの GoogLeNet からの U-Net の作成
事前学習済みの GoogLeNet ネットワークから、4 つのダウンサンプリング演算をもつ GAN 符号化器ネットワークを作成します。
depth = 4;
[encoder,outputNames] = pretrainedEncoderNetwork('googlenet',depth);
符号化器ネットワークの入力サイズを決定します。
inputSize = encoder.Layers(1).InputSize;
まずサンプル データ入力を作成し、活性化を返す forward
を呼び出すことで、符号化器ネットワーク内の活性化層の出力サイズを決定します。
exampleInput = dlarray(zeros(inputSize),'SSC'); exampleOutput = cell(1,length(outputNames)); [exampleOutput{:}] = forward(encoder,exampleInput,'Outputs',outputNames);
復号化器ブロック内のチャネル数を、各活性化の 3 番目のチャネルの長さとして決定します。
numChannels = cellfun(@(x) size(extractdata(x),3),exampleOutput); numChannels = fliplr(numChannels(1:end-1));
1 つの復号化器ブロックの層の配列を作成する関数を定義します。
decoderBlock = @(block) [ transposedConv2dLayer(2,numChannels(block),'Stride',2) convolution2dLayer(3,numChannels(block),'Padding','same') reluLayer convolution2dLayer(3,numChannels(block),'Padding','same') reluLayer];
符号化器モジュール内のダウンサンプリングされたブロックの数と同じ数のアップサンプリング ブロックをもつ復号化器モジュールを作成します。
decoder = blockedNetwork(decoderBlock,depth);
符号化器モジュールと復号化器モジュールを接続し、スキップ接続を追加して、U-Net ネットワークを作成します。
net = encoderDecoderNetwork([224 224 3],encoder,decoder, ... 'OutputChannels',3,'SkipConnections','concatenate')
net = dlnetwork with properties: Layers: [139x1 nnet.cnn.layer.Layer] Connections: [167x2 table] Learnables: [116x3 table] State: [0x3 table] InputNames: {'data'} OutputNames: {'encoderDecoderFinalConvLayer'} Initialized: 1 View summary with summary.
ネットワークを表示します。
analyzeNetwork(net)
入力引数
fun
— 層のブロックを作成する関数
関数
層のブロックを作成する関数。次のシグネチャをもつ関数として指定します。
block = fun(blockIndex)
fun
およびblockIndex
への入力は、範囲 [1,numBlocks
] の整数です。fun
およびblock
からの出力は、層または層配列です。
numBlocks
— ブロック数
正の整数
ネットワーク内のブロック数。正の整数として指定します。
namePrefix
— すべての層の名前の接頭辞
""
(既定値) | string | 文字ベクトル
ネットワーク内のすべての層の接頭辞。string または文字ベクトルとして指定します。
データ型: char
| string
出力引数
ヒント
blockedNetwork
によって返されるdlnetwork
(Deep Learning Toolbox) は初期化されていないため、学習や推定に使用できる状態ではありません。ネットワークを初期化するには、関数initialize
(Deep Learning Toolbox) を使用します。関数
encoderDecoderNetwork
を使用して符号化器ネットワークを復号化器ネットワークに接続します。
バージョン履歴
R2021a で導入
MATLAB コマンド
次の MATLAB コマンドに対応するリンクがクリックされました。
コマンドを MATLAB コマンド ウィンドウに入力して実行してください。Web ブラウザーは MATLAB コマンドをサポートしていません。
Select a Web Site
Choose a web site to get translated content where available and see local events and offers. Based on your location, we recommend that you select: .
You can also select a web site from the following list:
How to Get Best Site Performance
Select the China site (in Chinese or English) for best site performance. Other MathWorks country sites are not optimized for visits from your location.
Americas
- América Latina (Español)
- Canada (English)
- United States (English)
Europe
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)