netTrained = trainnet(sequences,targets,net,lossFcn,options): this function cannot be used when sequences contains complex numbers

Alexander Liao, July 9, 2024
Edited: Paras Gupta, July 18, 2024
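Question:
When calling netTrained = trainnet(sequences,targets,net,lossFcn,options), how can I use this function if sequences contains complex numbers?
The function documentation indicates that complex-valued inputs are supported: "This argument supports complex-valued predictors and targets."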
Code:
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
numChannels = betalen;
layers = [
sequenceInputLayer(numChannels)
lstmLayer(128)
fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
MaxEpochs=200, ...
SequencePaddingDirection="left", ...
Shuffle="every-epoch", ...
Plots="training-progress", ...
Verbose=false);
net = trainnet(XTrain,TTrain,layers,"mse",options);
Error output:
Error using trainnet (line 46)
Execution failed during layer(s) 'lstm'.
Error in HDL (line 66)
net = trainnet(XTrain,TTrain,layers,"mse",options);
Caused by:
Error using dlarray/lstm (line 105)
Invalid argument at position 1. Value must be real.

Answers (1)

Paras Gupta, July 18, 2024
Edited: Paras Gupta, July 18, 2024
Hi Alexander,
I understand that you are trying to use the "trainnet" function on complex-valued sequences and complex-valued targets.
You are correct that the documentation states that "trainnet" supports complex-valued predictors and targets. However, the built-in loss functions of "trainnet" (such as "mse") do not handle complex-valued targets, so you need to define a custom loss function that does.
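As a minimal illustration of what such a custom loss has to compute, the sketch below evaluates the mean squared magnitude |Y - T|^2 on plain numeric arrays. It is only an assumption-based stand-in: inside "trainnet", Y and T would be dlarray objects rather than ordinary arrays, and the names complexMSE and realOnlyLoss are hypothetical. For comparison, it also computes a loss that looks only at the real part of the error, which misses errors in the imaginary part entirely.

```matlab
% Hypothetical complex-aware MSE: mean of the squared magnitude |Y - T|^2.
% Plain numeric arrays stand in for the dlarray objects trainnet would pass.
complexMSE = @(Y, T) mean(real(Y(:) - T(:)).^2 + imag(Y(:) - T(:)).^2);

% For comparison: a loss that only looks at the real part of the error
realOnlyLoss = @(Y, T) mean(real(Y(:) - T(:)).^2);

Y = [1 + 1i, 2 - 1i];   % example predictions
T = [1 + 0i, 2 + 1i];   % example complex targets

lossFull = complexMSE(Y, T)       % errors are 1i and -2i, so (1 + 4)/2 = 2.5
lossReal = realOnlyLoss(Y, T)     % real parts match exactly, so 0
```

Writing the squared magnitude as real(...).^2 + imag(...).^2 keeps the expression smooth at zero error, unlike abs(...), which is one reason to prefer it in a loss.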
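In addition, the "sequenceInputLayer" in your model must be configured to handle complex-valued inputs. The LSTM layer accepts only real-valued data, which is why the call fails inside dlarray/lstm with "Value must be real." Setting the "SplitComplexInputs" option of "sequenceInputLayer" to true makes the input layer split each complex input into its real and imaginary parts, so the LSTM layer receives real-valued data.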
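To my understanding, the "SplitComplexInputs" option makes the input layer output twice as many channels as the input data: the real components followed by the imaginary components. The plain-MATLAB sketch below only mimics that idea on a single C-by-T sequence; it is not the layer's implementation, and the real-then-imaginary channel order is an assumption. Splitting the data this way yourself before training, with sequenceInputLayer(2*numChannels), would presumably be a manual alternative.

```matlab
% Hypothetical illustration of splitting a complex sequence into real and
% imaginary channels (the channel ordering used inside the layer is assumed).
numChannels = 2;
numTimesteps = 5;
X = randn(numChannels, numTimesteps) + 1i*randn(numChannels, numTimesteps);  % C-by-T, complex

% Stack real parts on top of imaginary parts along the channel dimension
XSplit = [real(X); imag(X)];    % (2*C)-by-T, real-valued

% The original complex sequence can be recovered from the two halves
XBack = complex(XSplit(1:numChannels, :), XSplit(numChannels+1:end, :));
```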
Below is an example that defines a custom loss function for complex-valued targets and passes it to "trainnet" as a function handle:
% dummy data
numSamples = 100;
numTimesteps = 10;
numChannels = 2;
realPart = randn(numSamples, numTimesteps, numChannels);
imagPart = randn(numSamples, numTimesteps, numChannels);
dataTrain = realPart + 1i * imagPart;
XTrain = permute(dataTrain(:,1:end-1,:),[1,3,2]);
% complex target
TTrain = permute(dataTrain(:,2:end,:),[1,3,2]);
% real target
% TTrain = rand(numSamples, numChannels, numTimesteps-1);
layers = [
sequenceInputLayer(numChannels, SplitComplexInputs=true) % split complex input into real and imaginary channels
lstmLayer(128)
fullyConnectedLayer(numChannels)];
options = trainingOptions("adam", ...
MaxEpochs=200, ...
SequencePaddingDirection="left", ...
Shuffle="every-epoch", ...
Plots="training-progress", ...
Verbose=false);
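% The built-in "mse" loss is not used here because the targets are complex:
% net = trainnet(XTrain, TTrain, layers, "mse", options);
% Pass the custom loss function as a function handle instead
net = trainnet(XTrain, TTrain, layers, @complexLoss, options);
function loss = complexLoss(Y, T)
% Mean squared magnitude of the complex error |Y - T|^2.
% Y is real-valued here, so the imaginary term does not depend on the learnable parameters.
difference = Y - T;
squaredMagnitude = real(difference).^2 + imag(difference).^2;
loss = mean(squaredMagnitude, 'all');
end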
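With "SplitComplexInputs" enabled, the LSTM and fully connected layers produce real-valued outputs, so a final fullyConnectedLayer(numChannels) cannot represent the imaginary part of complex targets. One possible alternative, which is not part of the original answer, is to let the last layer output 2*numChannels values and compare them with the real and imaginary parts of the target. The loss below is a hypothetical sketch on plain arrays, assuming channels lie along dimension 1.

```matlab
% Hypothetical variant: the network predicts real and imaginary parts as
% separate real-valued channels, e.g. with a final fullyConnectedLayer(2*numChannels).
% Y: (2*C)-by-... real prediction (real parts first, then imaginary parts)
% T: C-by-... complex target; channels are assumed to be along dimension 1.
splitTargetLoss = @(Y, T) mean(reshape(Y - [real(T); imag(T)], [], 1).^2);

T = [1 + 2i, 3 - 1i];        % one channel, two time steps
YExact = [1, 3; 2, -1];      % matches [real(T); imag(T)] exactly
YOff   = [1, 3; 0, -1];      % imaginary part of the first step is off by 2

lossExact = splitTargetLoss(YExact, T)   % 0
lossOff   = splitTargetLoss(YOff, T)     % (2^2)/4 = 1
```

Inside a "trainnet" loss function, Y and T would be dlarray objects, so the concatenation would need to follow the channel dimension of the data format (for example, located with finddim(Y, "C")); that adaptation is left as an assumption here.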
You can refer to the following documentation links for more information on the code above:
Hope this helps with your work.

Category: Deep Learning Toolbox

Release: R2024a
