Calculating the gradient of an equation in higher dimensions?

Rohit Gupta, 5 June 2018
Edited: Christine Tobler, 6 June 2018
Suppose I have a loss function in higher dimensions, say a Tucker decomposition.
Are there any tools for automatic gradient calculation and optimization?

Answers (2)

Christine Tobler, 5 June 2018
If the matrices and N-D arrays have a fixed size, you can create symbolic variables for them:
A = sym('A', [3 3 3])
Then multiply these together as described by the Tucker decomposition. The n-mode product can be computed by permuting, reshaping, and applying matrix multiplication.
However, the output of sym/gradient will not be in a compact closed-form matrix notation: it will just contain a formula with all individual elements of the matrices you constructed. Here's an example of what I mean:
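A minimal sketch of the permute/reshape/multiply approach for a 3-way Tucker model X = G ×1 A1 ×2 A2 ×3 A3. It uses plain numeric arrays so it runs without the Symbolic Math Toolbox; the core tensor G, the cell array of factor matrices A, and all sizes are illustrative choices, not something from the thread.

```matlab
% Tucker reconstruction X = G x_1 A{1} x_2 A{2} x_3 A{3}.
% Each n-mode product: permute mode n to the front, unfold (reshape),
% multiply by the factor matrix, fold back (reshape), undo the permutation.
G = randn(2, 3, 4);                           % core tensor, R1 x R2 x R3
A = {randn(5, 2), randn(6, 3), randn(7, 4)};  % factor matrices, In x Rn

X = G;
for n = 1:3
    sz = size(X);
    sz(end+1:3) = 1;                 % keep trailing singleton dims explicit
    order = [n, setdiff(1:3, n)];    % move mode n to the front
    Xn = reshape(permute(X, order), sz(n), []);   % mode-n unfolding
    Yn = A{n} * Xn;                  % apply the factor along mode n
    newsz = [size(A{n}, 1), sz(order(2:end))];
    X = ipermute(reshape(Yn, newsz), order);      % fold back
end
disp(size(X))                        % 5 6 7
```

The same loop works for any mode order, since each n-mode product only touches one dimension at a time.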
>> A = sym('A', [3 3])
A =
[ A1_1, A1_2, A1_3]
[ A2_1, A2_2, A2_3]
[ A3_1, A3_2, A3_3]
>> x = sym('x', [3 1])
x =
x1
x2
x3
>> gradient(x.'*A*x, x)
ans =
2*A1_1*x1 + A1_2*x2 + A1_3*x3 + A2_1*x2 + A3_1*x3
A1_2*x1 + A2_1*x1 + 2*A2_2*x2 + A2_3*x3 + A3_2*x3
A1_3*x1 + A2_3*x2 + A3_1*x1 + A3_2*x2 + 2*A3_3*x3
% Compare to the simple matrix form of the gradient:
>> simplify(gradient(x.'*A*x, x) == (A+A.')*x)
ans =
TRUE
TRUE
TRUE
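The same identity can also be sanity-checked numerically, without the Symbolic Math Toolbox. A minimal sketch with random A and x (illustrative data), comparing central differences against (A+A.')*x:

```matlab
% Numerical check of grad(x.'*A*x) = (A + A.')*x using central differences.
n = 3;
A = randn(n);
x = randn(n, 1);
f = @(x) x.' * A * x;            % quadratic form (A is captured here)
h = 1e-6;                        % finite-difference step (assumed value)
g_fd = zeros(n, 1);
for k = 1:n
    e = zeros(n, 1);
    e(k) = h;                    % perturb the k-th component only
    g_fd(k) = (f(x + e) - f(x - e)) / (2*h);   % central difference
end
g_exact = (A + A.') * x;
disp(max(abs(g_fd - g_exact)))   % only rounding error remains
```

For a quadratic form, the central difference has no truncation error, so any remaining difference comes from floating-point rounding.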
2 comments
Rohit Gupta, 6 June 2018
Thanks, but I tried it for
g = gradient((norm(A-U*S*V')),U)
where A is 10-by-10, U is 10-by-4, S is 4-by-4 and V is 10-by-4. It hangs.
Christine Tobler, 6 June 2018
Edited: Christine Tobler, 6 June 2018
The problem here is not the gradient so much as the norm: the 2-norm of a matrix is defined through its singular value decomposition (it is the largest singular value), and the closed-form expression for the singular value decomposition is quite long even for a 2-by-2 or 3-by-3 matrix.
If you really need to use the matrix 2-norm, the best way to compute the gradient might be through numeric differentiation.
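A minimal sketch of that numeric route, using central differences on each entry of U. The sizes match the question; the random data and the step size h = 1e-6 are assumptions for illustration. This costs 2*numel(U) evaluations of norm, i.e. 80 here.

```matlab
% Central-difference gradient of f(U) = norm(A - U*S*V') (matrix 2-norm)
% with respect to every entry of U. Sizes: A 10x10, U 10x4, S 4x4, V 10x4.
A = randn(10, 10);
U = randn(10, 4);
S = randn(4, 4);
V = randn(10, 4);
f = @(U) norm(A - U*S*V');       % norm(X) is the largest singular value of X
h = 1e-6;                        % step size (assumed; tune as needed)
G = zeros(size(U));
for k = 1:numel(U)
    E = zeros(size(U));
    E(k) = h;                    % perturb one entry of U (linear indexing)
    G(k) = (f(U + E) - f(U - E)) / (2*h);
end
```

This only approximates the gradient where the largest singular value of A - U*S*V' is simple; at points where it is repeated, the 2-norm is not differentiable.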
If you can use the Frobenius norm here instead (ideally the squared Frobenius norm), the gradient is much easier to compute:
A = sym('A', [3 3]);
U = sym('U', [3 2]);
V = sym('V', [3 2]);
S = sym('S', [2 2]);
gradient(norm(A - U*S*V', 'fro')^2)
This finishes in a few seconds, and the result is probably simple enough that you could deduce the closed-form matrix version from it.
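For real matrices, working through the algebra for f = ||A - U*S*V'||_F^2 with residual R = A - U*S*V' gives grad_U f = -2*R*V*S', grad_S f = -2*U'*R*V and grad_V f = -2*R'*U*S. A minimal sketch that evaluates these closed forms on random data (sizes from the question) and checks the U-gradient against central differences:

```matlab
% Closed-form gradients of f = norm(A - U*S*V', 'fro')^2 for real data,
% with a central-difference check of the U-gradient.
A = randn(10, 10);
U = randn(10, 4);
S = randn(4, 4);
V = randn(10, 4);
R = A - U*S*V';                  % residual
gradU = -2 * R * V * S';         % df/dU, 10x4
gradS = -2 * U' * R * V;         % df/dS, 4x4
gradV = -2 * R' * U * S;         % df/dV, 10x4

f = @(U) norm(A - U*S*V', 'fro')^2;
h = 1e-4;                        % f is quadratic in U, so central
                                 % differences are exact up to rounding
gradU_fd = zeros(size(U));
for k = 1:numel(U)
    E = zeros(size(U));
    E(k) = h;
    gradU_fd(k) = (f(U + E) - f(U - E)) / (2*h);
end
disp(max(abs(gradU(:) - gradU_fd(:))))   % should be tiny
```

Because the closed forms are ordinary matrix products, they can be passed directly to a gradient-based optimizer instead of the element-wise symbolic result.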

John D'Errico, 5 June 2018
help sym/gradient
1 comment
Rohit Gupta, 5 June 2018
Edited: Rohit Gupta, 5 June 2018
Will this work for complicated functions involving n-mode products, traces, etc.?
