Error with dlarray format, but why?

Song Decn, 1 June 2021
Answered: Ben, 20 June 2023
K>> lstm(dlX, hiddenState, initialCellState, inputWeights, ...
recurrentWeights, bias)
Error using deep.internal.dlarray.validateWeights (line 9)
'U' dimension (if not a formatted dlarray, second dimension) of weights must have size
NumFeatures, where NumFeatures is the size of the 'C' dimension of the input data.
Can any expert help me solve this issue? I am also still quite confused about the dlarray format labels C, S, T, U and B.
Is there a simple explanation or tutorial on how to use them?
Many thanks.
1 Comment
Matt J, 18 June 2023 (edited 18 June 2023)
We need to examine dims(dlX) and the sizes of all your input variables.
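A minimal sketch of that check. The values in the first block are hypothetical stand-ins built from the sizes in the help lstm example; in practice you would drop them and run only the loop on your own variables from the failing call. dims returns an empty char for an unformatted dlarray.

```matlab
% Hypothetical stand-ins for the question's variables; replace them with
% your own dlX, hiddenState, ... and keep only the reporting loop.
numFeatures = 10; numObservations = 32; sequenceLength = 64; numHiddenUnits = 3;
dlX              = dlarray(randn(numFeatures,numObservations,sequenceLength),'CBT');
hiddenState      = dlarray(zeros(numHiddenUnits,numObservations),'CB');
initialCellState = dlarray(zeros(numHiddenUnits,numObservations),'CB');
inputWeights     = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrentWeights = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias             = dlarray(zeros(4*numHiddenUnits,1),'C');

% Print the format labels and the size of every lstm input.
names = {'dlX','hiddenState','initialCellState','inputWeights','recurrentWeights','bias'};
vals  = {dlX, hiddenState, initialCellState, inputWeights, recurrentWeights, bias};
for k = 1:numel(names)
    fprintf('%-17s dims = %-4s size = %s\n', names{k}, dims(vals{k}), mat2str(size(vals{k})));
end
```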

Answers (1)

Ben, 20 June 2023
This error appears to be thrown when inputWeights has the wrong size. For example, take this example code from help lstm:
numFeatures = 10;
numObservations = 32;
sequenceLength = 64;
X = dlarray(randn(numFeatures,numObservations,sequenceLength), 'CBT');
% Create formatted dlarrays for the lstm parameters with three
% hidden units.
numHiddenUnits = 3;
H0 = dlarray(randn(numHiddenUnits,numObservations),'CB');
C0 = dlarray(randn(numHiddenUnits,numObservations),'CB');
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
recurrent = dlarray(randn(4*numHiddenUnits,numHiddenUnits),'CU');
bias = dlarray(randn(4*numHiddenUnits,1),'C');
% Apply an lstm calculation
[Y,hiddenState,cellState] = lstm(X,H0,C0,weights,recurrent,bias);
If you now give the weights the wrong size in the second dimension, you get the same error:
errorWeights = dlarray(randn(4*numHiddenUnits,numFeatures+1),'CU');
lstm(X,H0,C0,errorWeights,recurrent,bias); % throws error
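One way to catch this before lstm throws is to compare the weights against the size of the input's 'C' dimension yourself. This is a sketch under the same example sizes; the variable names weightsOK and numInputFeatures are illustrative only, not part of any toolbox API. finddim returns the index of the labelled dimension in a formatted dlarray.

```matlab
% Hypothetical pre-check using the example sizes from help lstm.
numFeatures = 10; numObservations = 32; sequenceLength = 64; numHiddenUnits = 3;
X = dlarray(randn(numFeatures,numObservations,sequenceLength),'CBT');
errorWeights = dlarray(randn(4*numHiddenUnits,numFeatures+1),'CU');

cDim = finddim(X,'C');             % index of the channel dimension in X
numInputFeatures = size(X,cDim);   % the NumFeatures value lstm compares against
weightsOK = size(errorWeights,1) == 4*numHiddenUnits && ...
            size(errorWeights,2) == numInputFeatures;
if ~weightsOK
    % Report the mismatch instead of letting lstm throw.
    fprintf('inputWeights is %s, expected [%d %d]\n', ...
        mat2str(size(errorWeights)), 4*numHiddenUnits, numInputFeatures);
end
```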
This suggests that your inputWeights have the wrong size for lstm. inputWeights must have size 4*NumHiddenUnits-by-NumFeatures, and can be either a formatted dlarray (with labels) or an unformatted one:
% both of these are valid - the format label U is just to specify that this
% dimension doesn't correspond to any of the standard named labels S -
% spatial, C - channel, T - time, B - batch.
weights = dlarray(randn(4*numHiddenUnits,numFeatures),'CU');
weights = dlarray(randn(4*numHiddenUnits,numFeatures));
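The same idea applies to the input data: if you would rather not attach labels at all, lstm also accepts unformatted dlarrays when you describe the layout of X with the 'DataFormat' name-value argument. A sketch with the example sizes, assuming unformatted H0 and C0 of size numHiddenUnits-by-numObservations; the output Y then has numHiddenUnits in its 'C' position and keeps the dimension order of X.

```matlab
% Unformatted inputs; the layout of X is given by 'DataFormat' instead of labels.
numFeatures = 10; numObservations = 32; sequenceLength = 64; numHiddenUnits = 3;
X  = dlarray(randn(numFeatures,numObservations,sequenceLength));   % C x B x T, no labels
H0 = dlarray(zeros(numHiddenUnits,numObservations));
C0 = dlarray(zeros(numHiddenUnits,numObservations));
weights   = dlarray(randn(4*numHiddenUnits,numFeatures));           % 4*H x NumFeatures
recurrent = dlarray(randn(4*numHiddenUnits,numHiddenUnits));        % 4*H x H
bias      = dlarray(zeros(4*numHiddenUnits,1));                     % 4*H x 1

[Y,hiddenState,cellState] = lstm(X,H0,C0,weights,recurrent,bias,'DataFormat','CBT');
disp(size(Y))   % channel dimension of Y equals numHiddenUnits
```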
If you list the sizes as @Matt J suggests, we can debug the issue further.
