# Efficient Computation of Matrix Gradient

Shreyas Bharadwaj, 9 April 2024
Hi,
I am trying to compute the gradient of a matrix-valued function. I have computed the element-wise gradient as gradX(p,q) = sum over i,j of w(i,j) * conj(A(i,p)) * conj(B(q,j)) * AXB(i,j), and have verified numerically that it is correct (for my purposes of gradient descent).
My MATLAB implementation of the above gradient is:
```matlab
for p = 1:N
    for q = 1:N
        gradX(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
```
which I have also verified is correct numerically.
However, my issue is that N = 750, so this computation is extremely slow and impractical for gradient descent: on my desktop with 32 GB RAM and an Intel Xeon 3.7 GHz processor, one iteration takes around 10-15 minutes, and I expect to need several hundred iterations for convergence.
I was wondering if there is any obvious way I am missing to speed up or parallelize it. I have tried parfor but have not had any luck.
Thank you and I very much appreciate any suggestions.
##### 2 Comments
Bruno Luong, 9 April 2024

What is a typical size of w (or AXB)?
By the way, the first obvious optimization is to pre-multiply w with AXB.
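For concreteness, that pre-multiplication hoists the loop-invariant product `w .* AXB` out of the double loop (a sketch using the variable names from the question; it changes only where the product is computed, not the result):

```matlab
C = w .* AXB;   % loop-invariant: compute once instead of N^2 times
for p = 1:N
    for q = 1:N
        gradX(p,q) = sum(C .* (conj(A(:,p)) * conj(B(q,:))), 'all');
    end
end
```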
Shreyas Bharadwaj, 9 April 2024

Thank you, I will do that. All matrices, including w, are of size N x N, i.e. 750 x 750.


### Accepted Answer

Bruno Luong, 9 April 2024
The best version: the double sum collapses into two matrix products.
```matlab
N = 200; % 750 in the actual problem
w   = rand(N,N);
AXB = rand(N,N) + 1i*rand(N,N);
A   = rand(N,N) + 1i*rand(N,N);
B   = rand(N,N) + 1i*rand(N,N);

% Original double loop
tic
gradX_1 = zeros(N,N);                 % preallocate
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1 = toc                              % t1 = 15.1666

% Method 3: two matrix products
tic
C = w .* AXB;
gradX = A' * C * B';
t2 = toc                              % t2 = 0.0049

% relative error vs. the loop version
err = norm(gradX - gradX_1, 'fro') / norm(gradX_1, 'fro')   % err = 2.4063e-17
fprintf('New code version 3 is %g faster\n', t1/t2)
% New code version 3 is 3088.92 faster
```
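To see why the one-liner is correct, write out the matrix products element-wise (`'` in MATLAB is the conjugate transpose, $A^{\mathsf H}$). With $C = w \circ AXB$ the element-wise product, the $(p,q)$ entry is

$$
\left(A^{\mathsf H} C B^{\mathsf H}\right)_{pq}
= \sum_{i,j} \overline{A_{ip}}\, C_{ij}\, \overline{B_{qj}}
= \sum_{i,j} w_{ij}\, \overline{A_{ip}}\, \overline{B_{qj}}\, (AXB)_{ij},
$$

which is exactly the original double sum. This also reduces the cost from $O(N^4)$ (an $O(N^2)$ sum for each of $N^2$ entries) to $O(N^3)$ for the two matrix products, which explains the speedup of three orders of magnitude.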
##### 1 Comment
Shreyas Bharadwaj, 9 April 2024
Thank you very much! This is exactly what I was looking for.


### More Answers (1)

Bruno Luong, 9 April 2024
I propose this, with a timing test for N = 200:
```matlab
N = 200; % 750 in the actual problem
w   = rand(N,N);
AXB = rand(N,N) + 1i*rand(N,N);
A   = rand(N,N) + 1i*rand(N,N);
B   = rand(N,N) + 1i*rand(N,N);

% Original double loop
tic
gradX_1 = zeros(N,N);                 % preallocate
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1 = toc                              % t1 = 6.6905

% Version 1: hoist the loop-invariant product w .* AXB out of the loop
% and replace the element-wise sum with a single dot product
tic
C = w .* AXB;
C = reshape(C,1,[]);
gradX = zeros(N,N);
for p = 1:N
    Ap = A(:,p);
    for q = 1:N
        AB = Ap * B(q,:);
        AB = reshape(AB,1,[]);
        gradX(p,q) = C * AB';         % ' conjugates, matching the conj() calls
    end
end
t2 = toc                              % t2 = 1.0750

fprintf('New code version 1 is %g faster\n', t1/t2)
% New code version 1 is 6.22383 faster
```
##### 1 Comment
Shreyas Bharadwaj, 9 April 2024
Thank you!


MATLAB release: R2023b
