How do I create a custom layer with 2 inputs?

Mark White, 3 February 2022
Answered: MR, 26 October 2022
I have defined a custom layer endStateLayer to take two inputs, but when I try to create the dlnetwork I get an error saying the second input is not connected.
endStateLayer is designed to take the sequence output from an LSTM layer together with the sequence lengths, so that the end state of each observation can be identified. (The sequences are padded at the end.) A second custom layer, sequenceLengthLayer, determines the sequence lengths.
This is my network:
[XTrain,YTrain] = japaneseVowelsTrainData;
numObservations = length(XTrain);
embedDim = size(XTrain{1},1);
latentDim = 9;
numHiddenUnits = 100;
filterSize = 3;
numFilters = 32;
% define encoder network
layersEnc = [
sequenceInputLayer( embedDim, 'Name', 'in' )
lstmLayer( numHiddenUnits, 'OutputMode', 'sequence' )
endStateLayer( 'Name', 'endstate' )
fullyConnectedLayer( latentDim ) ];
lgraphEnc = layerGraph( layersEnc );
lgraphEnc = addLayers( lgraphEnc, ...
sequenceLengthLayer( 0, 'Name', 'seqlen' ) ); % padding indicator is 0
lgraphEnc = connectLayers( lgraphEnc, ...
'in', 'seqlen' );
lgraphEnc = connectLayers( lgraphEnc, ...
'seqlen', 'endstate/len' );
dlnetEnc = dlnetwork( layersEnc );
According to analyzeNetwork( lgraphEnc ), the layer graph appears to be fully connected.
The errors it reports are to be expected, because a layer graph intended for a dlnetwork has no output layer.
However, I get the following error when I try to create the dlnetwork:
dlnetEnc = dlnetwork( layersEnc );
Error using dlnetwork/initialize (line 481)
Invalid network.
Error in dlnetwork (line 218)
net = initialize(net, dlX{:});
Error in lstmTest (line 26)
dlnetEnc = dlnetwork( layersEnc );
Caused by:
Example inputs: Incorrect number of example network inputs. 0 example network inputs provided but network has 2 inputs including 1 unconnected layer inputs.
Layer 'endstate': Unconnected input. Each input must be connected to input data or to the output of another layer.
Detected unconnected inputs:
input 'len'
This is the sequenceLengthLayer definition. I don't think the details of the predict function matter here, but I'm showing it for completeness. (I haven't been able to check it fully because I can't set up the network.)
classdef sequenceLengthLayer < nnet.layer.Layer & ...
        nnet.layer.Formattable
    properties
        % (Optional) Layer properties.
        PaddingIndicator
    end
    properties (Learnable)
        % Layer learnable parameters.
    end
    methods
        function layer = sequenceLengthLayer( padIndicator, NameValueArgs )
            % layer = sequenceLengthLayer( padIndicator )
            % creates a sequenceLengthLayer object that determines
            % the length of the input sequence
            % Parse input arguments.
            arguments
                padIndicator = 0;
                NameValueArgs.Name = '';
            end
            name = NameValueArgs.Name;
            % Set layer name.
            layer.Name = name;
            % Set layer description.
            layer.Description = "Sequence length layer for padding " ...
                + join(string(padIndicator));
            % Set layer type.
            layer.Type = "Sequence Length";
            % Set the padding indicator.
            layer.PaddingIndicator = padIndicator;
        end
        function L = predict( layer, X )
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %   layer - Layer to forward propagate through
            %   X     - Input data, specified as a formatted dlarray
            %           with 'T' and 'C' dimensions
            % Outputs:
            %   L     - Sequence lengths, returned as a formatted
            %           dlarray with format 'CB'.
            maxLength = size( X, 3 );
            miniBatchSize = size(X, 2);
            L = zeros( 1, miniBatchSize, 'like', X);
            % find where X contains padding using the first channel
            isPadding = (X(1,:,:)==layer.PaddingIndicator);
            isPadding = logical( extractdata(isPadding) );
            for i = 1:miniBatchSize
                found = false;
                padStart = 0;
                while ~found && padStart<(maxLength-1)
                    % find where the padding begins
                    padStart = find( isPadding( padStart+1:end,i ), 1, 'first' );
                    % check if the padding continues to the end
                    found = all(isPadding( padStart:end, i ));
                end
                if isempty( padStart )
                    % no padding found - go to the end
                    padStart = maxLength;
                end
                L(i) = padStart;
            end
            L = dlarray( L, 'CB' );
        end
    end
end
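For reference, the same length computation can be written without the while loop. The following is a minimal sketch on a plain numeric array, not the layer from this post. It assumes the data is laid out as channels × observations × time steps (the 'CBT' order the layer receives), that padding appears only at the end, and that a padded step is one whose first channel equals the padding value. The helper name endPaddedLengths is hypothetical.

```matlab
% Hypothetical sketch, not the sequenceLengthLayer from the post.
X = zeros(2, 3, 5);           % 2 channels, 3 observations, 5 time steps
X(:, 1, 1:5) = 1;             % observation 1: no padding (length 5)
X(:, 2, 1:3) = 2;             % observation 2: length 3, padded with 0
X(:, 3, 1) = 3;               % observation 3: length 1, padded with 0
L = endPaddedLengths(X, 0);
disp(L)                       % expected: 5 3 1

function L = endPaddedLengths(X, padValue)
    % Assumes X is C-by-B-by-T and padding appears only at the end.
    [~, numObs, numSteps] = size(X);
    % B-by-T logical mask: true where the first channel is real data.
    isData = reshape(X(1, :, :), numObs, numSteps) ~= padValue;
    % Weight each step by its index; the row maximum is the index of
    % the last real step, i.e. the length (0 if the row is all padding).
    L = max(isData .* (1:numSteps), [], 2).';
end
```

Taking the index of the last non-padding step avoids tracking a relative index inside a search loop. Like the loop version, it would underestimate a length if a genuine final step happened to equal the padding value in the first channel.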
This is the endStateLayer definition, where I think the problem lies: it is the layer reporting that its second input is not connected. Is my constructor correct?
classdef endStateLayer < nnet.layer.Layer & ...
        nnet.layer.Formattable
    properties
        % (Optional) Layer properties.
    end
    properties (Learnable)
        % Layer learnable parameters.
    end
    methods
        function layer = endStateLayer( NameValueArgs )
            % layer = endStateLayer( 'Name', name )
            % creates an endStateLayer object that extracts the
            % state of a sequence, X, at a specified point, L
            % Parse input arguments.
            arguments
                NameValueArgs.Name = '';
            end
            name = NameValueArgs.Name;
            % Set layer name.
            layer.Name = name;
            % Set layer description.
            layer.Description = "End state layer ";
            % Set layer type.
            layer.Type = "End State";
            % Set the inputs.
            layer.NumInputs = 2;
            layer.InputNames = { 'in', 'len' };
        end
        function Z = predict( layer, X, L )
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %   layer - Layer to forward propagate through
            %   X     - Input sequence data, specified as a
            %           formatted dlarray with 'T' and 'C' dimensions
            %   L     - Input sequence length data.
            % Outputs:
            %   Z     - Output of layer forward function returned as
            %           a formatted dlarray with format 'CB'.
            miniBatchSize = size(X, 2);
            Z = zeros( size(X,1), size(X,2), 'like', X);
            for i = 1:miniBatchSize
                Z(:,i) = X(:, i, L(i));
            end
            Z = dlarray(Z, 'CB');
        end
    end
end
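As a point of comparison, the gather that predict performs can be written without a loop by flattening the observation and time dimensions. This is a minimal sketch on a plain numeric array rather than the layer from this post. It assumes the data is channels × observations × time steps and that every length is an integer between 1 and the padded length; the helper name gatherEndStates is hypothetical. Inside a real layer, L may first need converting to numeric values (for example with extractdata) before it is used as an index.

```matlab
% Hypothetical sketch, not the endStateLayer from the post.
H = reshape(1:24, 2, 3, 4);   % 2 channels, 3 observations, 4 time steps
L = [4 2 1];                  % sequence length of each observation
Z = gatherEndStates(H, L);
disp(Z)                       % expected: [19 9 5; 20 10 6]

function Z = gatherEndStates(H, L)
    % Assumes H is C-by-B-by-T and every L(i) is an integer in 1..T.
    [numChannels, numObs, ~] = size(H);
    % Flatten to C-by-(B*T): column b + (t-1)*B holds H(:, b, t).
    Hflat = reshape(H, numChannels, []);
    % Pick column L(b) for each observation b in one indexing step.
    Z = Hflat(:, (1:numObs) + (L - 1) * numObs);
end
```

Because MATLAB stores arrays in column-major order, observation b at time step t lands in column b + (t-1)·B of the flattened matrix, so a single index vector replaces the per-observation loop.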
I can't see what I have done wrong. I have reviewed the examples online. What have I missed?

Answers (1)

MR, 26 October 2022
Hi Mark,
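I ran your code in MATLAB R2022b and was able to reproduce your error message. Then I ran analyzeNetwork(lgraphEnc,"TargetUsage","dlnetwork") and got no errors. In a nutshell, I think it is just a typo: sequenceLengthLayer and its connections were added to lgraphEnc, but the network is built from layersEnc, the original layer array, in which the 'len' input of endstate is never connected. Use
dlnetEnc = dlnetwork(lgraphEnc); instead of dlnetEnc = dlnetwork( layersEnc );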
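To see why this one change matters: addLayers and connectLayers return an updated layer graph, so the extra layer and its connections exist only in lgraphEnc, while layersEnc remains the plain four-layer array. The sketch below reproduces the same situation with built-in layers only, using additionLayer as a stand-in for the two-input endStateLayer; the layer names and sizes are illustrative assumptions, not taken from the post.

```matlab
% Stand-in for the question's setup using only built-in layers.
% additionLayer(2) plays the role of the two-input endStateLayer.
layers = [
    sequenceInputLayer(12, 'Name', 'in')
    lstmLayer(100, 'OutputMode', 'sequence', 'Name', 'lstm')
    additionLayer(2, 'Name', 'add')       % inputs 'add/in1', 'add/in2'
    fullyConnectedLayer(9, 'Name', 'fc')];

lgraph = layerGraph(layers);              % connects lstm -> add/in1
lgraph = addLayers(lgraph, fullyConnectedLayer(100, 'Name', 'branch'));
lgraph = connectLayers(lgraph, 'in', 'branch');
lgraph = connectLayers(lgraph, 'branch', 'add/in2');

% The layer array is unchanged: it still has 4 layers and no 'branch'.
fprintf('layers: %d, lgraph: %d\n', numel(layers), numel(lgraph.Layers));
% prints "layers: 4, lgraph: 5"

% Building from the array fails because 'add/in2' is unconnected ...
try
    dlnetwork(layers);
    arrayFailed = false;
catch err
    arrayFailed = true;
    disp(err.message)
end

% ... while building from the layer graph succeeds.
net = dlnetwork(lgraph);
```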

Category: Build Deep Neural Networks

Release: R2021b
