Computing the Hessian with dlgradient

Hi everyone.
I am using a training loop for my model in which the gradients are computed with dlgradient. As you know, when dlgradient is called (through dlfeval) with dlnet.Learnables, it returns a table that stores the layers, the parameters (weights and biases), and the gradient values. dlgradient takes the loss as a scalar together with dlnet.Learnables, and the data samples dlX and targets dlY enter through the computation of that loss.

I am interested in computing the Hessian of a small network using dlX and dlY. In fact, because I use a mini-batch dlX, it will be a sub-sampled Hessian (so storing this matrix is not a problem for me). However, I do not know how to apply dlgradient a second time to compute the Hessian. If anyone knows how, I would be very thankful.
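For reference, the sub-sampled Hessian mentioned above is usually taken to be the Hessian of the mini-batch loss. With θ ∈ ℝ^n the vector of all learnable parameters, B the mini-batch, and ℓ the per-sample loss (notation introduced here for illustration only):

```latex
H_B(\theta) = \nabla_\theta^2 \left( \frac{1}{|B|} \sum_{k \in B} \ell(x_k, y_k; \theta) \right)
            = \frac{1}{|B|} \sum_{k \in B} \nabla_\theta^2 \, \ell(x_k, y_k; \theta) \in \mathbb{R}^{n \times n}
```

The mini-batch lowers the cost of evaluating H_B, but the matrix is n-by-n regardless of |B|, so it is the small number of parameters that keeps it cheap to store.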

Answers (1)

Yash on 18 Dec 2023


Hi Mahsa,
To compute the Hessian with dlgradient, you differentiate twice inside the function that you evaluate with dlfeval. First, compute the gradient of the scalar loss with respect to the learnables, passing 'EnableHigherDerivatives',true so that the gradient itself stays traced. Then call dlgradient again on each entry of that gradient: the derivative of the i-th gradient entry with respect to all parameters is row i of the Hessian. This takes one extra dlgradient call per parameter, so it is only practical for small networks.
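In symbols, with θ ∈ ℝ^n the vector of all learnable parameters and g = ∇_θ L the gradient returned by the first dlgradient call, each Hessian entry is a first derivative of a gradient entry:

```latex
H_{ij} = \frac{\partial^2 L}{\partial \theta_i \, \partial \theta_j} = \frac{\partial g_i}{\partial \theta_j}, \qquad i, j = 1, \dots, n
```

So one more dlgradient call on g_i with respect to θ yields row i of H, n such calls yield the full matrix, and for a twice continuously differentiable loss H is symmetric (H_{ij} = H_{ji}).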
Here is a code snippet you can use as a reference. It assumes that dlnet is your dlnetwork, dlX and dlY are a mini-batch of input data and targets, and mse is your loss function.
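% dlgradient only works inside a function evaluated by dlfeval
H = dlfeval(@modelHessian, dlnet, dlX, dlY);

function H = modelHessian(dlnet, dlX, dlY)
    % Forward pass and scalar loss
    dlYPred = forward(dlnet, dlX);
    loss = mse(dlYPred, dlY);
    % First-order gradients (a table like dlnet.Learnables), kept traced
    grads = dlgradient(loss, dlnet.Learnables, 'EnableHigherDerivatives', true);
    % Flatten the gradient table into one column vector g
    g = cellfun(@(v) v(:), grads.Value, 'UniformOutput', false);
    g = cat(1, g{:});
    n = numel(g);
    H = zeros(n, n);
    for i = 1:n
        % Row i of the Hessian: gradient of g(i) w.r.t. every learnable;
        % 'RetainData' keeps the trace alive for the next iteration
        hi = dlgradient(g(i), dlnet.Learnables, 'RetainData', true);
        row = cellfun(@(v) extractdata(v(:)), hi.Value, 'UniformOutput', false);
        H(i, :) = cat(1, row{:}).';
    end
end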
The grads variable contains the gradients for each parameter, and the H variable contains the Hessian matrix.
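To check the approach in isolation, here is a minimal sketch on a tiny linear least-squares model whose Hessian is known in closed form, (2/N)·X·X'. The model, the random data, and the local function name lossHessian are hypothetical and not taken from the question; the sketch only assumes Deep Learning Toolbox (dlarray, dlfeval, dlgradient, extractdata).

```matlab
% Sanity check of the double-dlgradient approach on a tiny linear model.
% Hypothetical setup: loss(w) = mean((w'*x_n - y_n).^2) over N samples,
% whose exact Hessian is (2/N)*X*X'.
rng(0);                     % reproducible random data
N = 8;                      % mini-batch size
X = rand(3, N);             % 3 features x N observations
Y = rand(1, N);             % regression targets
w = dlarray(rand(3, 1));    % parameters to differentiate with respect to

% dlgradient only works inside a function evaluated by dlfeval
H = dlfeval(@lossHessian, w, X, Y);
Hexact = (2/N) * (X * X');
maxErr = max(abs(H - Hexact), [], 'all');
fprintf('Max deviation from exact Hessian: %g\n', maxErr);

function H = lossHessian(w, X, Y)
    % Scalar mean squared error of the linear prediction sum(w.*X, 1)
    loss = mean((sum(w .* X, 1) - Y).^2);
    % First derivatives, traced so they can be differentiated again
    g = dlgradient(loss, w, 'EnableHigherDerivatives', true);
    n = numel(w);
    H = zeros(n, n);
    for i = 1:n
        % Row i of the Hessian is the gradient of g(i) with respect to w;
        % 'RetainData' keeps the trace alive for the next loop iteration
        hi = dlgradient(g(i), w, 'RetainData', true);
        H(i, :) = extractdata(hi).';
    end
end
```

The printed deviation should be at floating-point round-off level, which confirms that differentiating each gradient entry once more reproduces the analytic Hessian.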
Hope this helps!

Release: R2021a
Asked: 4 Feb 2022
Answered: 18 Dec 2023
