Computing Hessian by dlgradient
Hi everyone.
I am using a custom training loop for my model in which the gradients are computed by dlgradient. As you know, dlgradient (called through dlfeval) returns a table that stores the layers, the parameters (weights and biases), and the gradient values. We also know that dlgradient takes the loss as a scalar together with dlnet.Learnables, and that the loss is computed from the data samples dlX and the targets dlY. I am interested in computing the Hessian for a small network using dlX and dlY. In fact, I will be computing a sub-sampled Hessian if I use a mini-batch dlX (so storing this matrix is not a problem). However, I do not know how to apply dlgradient a second time to compute the Hessian. If someone knows, I would be very grateful.
Answers (1)
Yash
18 Dec 2023
Hi Mahsa,
To compute the Hessian with dlgradient, you differentiate twice inside the same function that dlfeval evaluates. First, call dlgradient on the scalar loss with the 'EnableHigherDerivatives' option set to true, so that the gradient itself stays traced. Then call dlgradient again on each element of that gradient. Each of these calls returns the second-order partial derivatives of the loss for one parameter paired with every parameter, that is, one row of the Hessian.
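In symbols, as a sketch with notation that I am introducing here (it is not from the question): let \theta be all learnables stacked into one vector of length n, B the mini-batch, \ell the per-sample loss, and f the network.

```latex
\[
L_B(\theta) = \frac{1}{|B|} \sum_{(x,y) \in B} \ell\big(f(x;\theta),\, y\big), \qquad
g(\theta) = \nabla_\theta L_B(\theta), \qquad
H_{ij} = \frac{\partial^2 L_B}{\partial \theta_i \, \partial \theta_j}
       = \frac{\partial g_i}{\partial \theta_j}.
\]
```

Row i of H is the gradient of the scalar g_i, which is why one extra gradient call per gradient element is enough. For a twice continuously differentiable loss, H is symmetric, so H_{ij} = H_{ji}.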
Here is a code snippet you can use as a reference to understand what I want to convey:
Assuming that dlnet is your network, dlX and dlY are your data samples and targets, and mse is your loss function.
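% Evaluate inside dlfeval so that the computation is traced
H = dlfeval(@modelHessian, dlnet, dlX, dlY);

function H = modelHessian(dlnet, dlX, dlY)
    % Forward pass and scalar loss
    dlYPred = forward(dlnet, dlX);
    loss = mse(dlYPred, dlY);
    % First-order gradients, kept traced so they can be differentiated again
    grads = dlgradient(loss, dlnet.Learnables, 'EnableHigherDerivatives', true);
    % Flatten all gradients into one column vector
    g = cellfun(@(v) v(:), grads.Value, 'UniformOutput', false);
    g = vertcat(g{:});
    n = numel(g);
    H = zeros(n, n);
    for i = 1:n
        % Row i of the Hessian: gradient of g(i) with respect to every learnable
        rowGrads = dlgradient(g(i), dlnet.Learnables, 'RetainData', true);
        row = cellfun(@(v) extractdata(v(:)), rowGrads.Value, 'UniformOutput', false);
        H(i, :) = vertcat(row{:})';
    end
end
The grads variable contains the gradients in the same table layout as dlnet.Learnables, and H is the n-by-n Hessian, where n is the total number of scalar learnables, ordered as in dlnet.Learnables. This needs one backward pass per parameter, so it is only practical for small networks. Higher-order derivatives may not be supported for every layer type, so please check the dlgradient documentation for your release.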
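As a sanity check of the double-differentiation idea, here is a minimal sketch on a toy function whose Hessian is known analytically. The function f(x) = x1^2*x2 + x2^3 and the helper name toyHessian are my own illustrative choices, not part of your model. Its Hessian is [2*x2, 2*x1; 2*x1, 6*x2], so at x = [1; 2] the result should match [4, 2; 2, 12].

```matlab
% Toy check of computing a Hessian with two nested dlgradient calls.
% Requires Deep Learning Toolbox. toyHessian is a hypothetical helper.
x0 = dlarray([1; 2]);
H = dlfeval(@toyHessian, x0);
disp(H)

function H = toyHessian(x)
    % Scalar function f(x) = x1^2*x2 + x2^3, written with elementwise products
    f = x(1).*x(1).*x(2) + x(2).*x(2).*x(2);
    % First derivative, traced so that it can be differentiated again
    g = dlgradient(f, x, 'EnableHigherDerivatives', true);
    n = numel(x);
    H = zeros(n, n);
    for i = 1:n
        % Row i of the Hessian is the gradient of g(i) with respect to x
        row = dlgradient(g(i), x, 'RetainData', true);
        H(i, :) = extractdata(row)';
    end
end
```

If this reproduces the analytic Hessian on your installation, the same pattern should carry over to a network loss, with the learnables table taking the place of x.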
Hope this helps!