How can I define a custom loss function using trainnet?

Matthew Murray, 29 March 2024
Edited: Matt J, 29 March 2024
Hello,
I am trying to define a custom loss function using trainnet. The documentation says:
If the trainnet function does not provide the loss function that you need for your task, then you can specify a custom loss function to the trainnet as a function handle. The function must have the syntax loss = f(Y,T), where Y and T are the predictions and targets, respectively.
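As a minimal sketch of that f(Y,T) syntax, here is a handle that collapses a whole mini-batch to one scalar. The mean-absolute-error loss and the name maeLoss are hypothetical, chosen only for illustration. The sketch assumes, as trainnet does for built-in losses, that Y and T arrive as dlarray objects of the same size.

```matlab
% Hypothetical custom loss with the documented syntax loss = f(Y,T):
% Y = predictions, T = targets. Averaging over "all" dimensions
% reduces the mini-batch to a single scalar dlarray.
maeLoss = @(Y,T) mean(abs(Y - T),"all");

% Quick check on random formatted dlarrays (height x width x channel x batch).
Y = dlarray(rand(16,16,3,4,"single"),"SSCB");
T = dlarray(rand(16,16,3,4,"single"),"SSCB");
L = maeLoss(Y,T);
size(L)   % scalar, as trainnet requires
```

The key point is the final reduction: without it, the handle would return an array, not a scalar.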
However, I am not sure how the predictions and targets are defined here. I am currently using trainnet as follows:
trainedNet = trainnet(dsTrain,layers,"mse",options);
dsTrain is a datastore containing the input and target images for the regression problem. However, I would like to change the loss to a custom function involving ssim. I would like something similar to the following, although I know this isn't quite right:
trainedNet = trainnet(dsTrain,layers,@(Y,targets) 1-ssim(Y,targets),options);
I get the following error message:
Error using trainnet
Value to differentiate is non-scalar. It must be a traced real dlarray scalar.
Thanks!

Answers (1)

Matt J, 29 March 2024
Edited: Matt J, 29 March 2024
If your network has multichannel output, that loss function returns one SSIM value per channel rather than a scalar, e.g.,
loss = @(Y,targets) 1-ssim(Y,targets);
[Y,T]=deal(dlarray(rand(5,4,8),'SSC'));   % one 5x4 image with 8 channels
L=loss(Y,T);
whos L
  Name      Size     Bytes  Class      Attributes
  L         1x1x8       70  dlarray
trainnet needs the loss to be a single real scalar, so you need to decide how you want these per-channel values reduced to one value.
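One possible reduction is to average the SSIM values, giving equal weight to every channel. This is a minimal sketch, and the name ssimLoss is hypothetical. It assumes ssim accepts the dlarray that trainnet passes during training, as the 'SSC' example shows. For batched 'SSCB' data it also assumes ssim returns one value per channel and observation, which the "all" reduction averages as well.

```matlab
% Reduce the per-channel (and, for batched data, per-observation)
% SSIM values to one scalar by averaging, then turn similarity into a loss.
ssimLoss = @(Y,T) 1 - mean(ssim(Y,T),"all");

% Same 8-channel test data as in the answer: the loss is now a scalar.
[Y,T] = deal(dlarray(rand(5,4,8),'SSC'));
L = ssimLoss(Y,T);
whos L
```

With that handle, the call from the question would become:
trainedNet = trainnet(dsTrain,layers,ssimLoss,options);
A plain mean is only one choice. A weighted mean or a sum over channels also yields a valid scalar, but each changes how strongly each channel influences training.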
