How to improve speed of calculating trace in a script?

Xiaohan Du on 3 Apr 2018
Commented: Xiaohan Du on 7 Apr 2018
Hi all,
In my project I have to calculate the trace of some matrix products. The following script demonstrates the task:
clear; clc;
% number of total tests.
nTest = 500;
% part 1. generate nTest*2 random matrices.
nd = 1000;
nt = 100;
dis = cell(nTest, 2);
dis = cellfun(@(v) rand(nd, nt), dis, 'UniformOutput', false);
% part 2. perform truncated-SVD on each matrix,
% only leave nRem singular vectors and values.
nRem = 50;
disSVD = cell(nTest, 2);
for isvd = 1:nTest
    for jsvd = 1:2
        [u, s, v] = svd(dis{isvd, jsvd}, 0);
        disSVD{isvd, jsvd} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
    end
end
% part 3. for each SVD result, perform trace to obtain disTrans. disTrans is
% non-symmetric, thus jtr needs to start from 1.
disTrans = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        disTrans(itr, jtr) = ...
            trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
    end
end
I ran the profiler to find out which part is the slowest. It turns out to be the trace calculation in part 3, because it is evaluated so many times (nTest^2 = 250,000 calls). Unfortunately, in my project the number of trace calculations is also very large. Any ideas on how to speed up the trace calculation? The profile is shown here:
Many thanks!
8 Comments
Steven Lord on 3 Apr 2018
Note that built-in functions like svd, cellfun, rand, etc. don't show up in your Profiler report but they do take time. If you open the entry for testuiTujSortImprove I believe you'll see a line in the Children table where the time spent in built-in functions will be reported, and given how many times you called svd in particular I think that will account for the "missing" time.
Xiaohan Du on 4 Apr 2018
I got it. Opening the entry for testuiTujSortImprove shows:
It seems that the operations inside trace, i.e.
u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}'
account for most of the time.


Accepted Answer

Christine Tobler on 3 Apr 2018
You can make the trace operator work faster as follows: Currently, the input is two truncated SVDs, A1 = U1 * S1 * V1' and A2 = U2 * S2 * V2', and you are computing
trace(A1'*A2)
Is that correct? After substituting A1 and A2, you can use the cyclic property of the trace, trace(A*B) == trace(B*A). Note that this only permits cyclic shifts of the factors: in general, trace(A*B*C) ~= trace(A*C*B) (see Wikipedia).
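For reference, the cyclic property follows directly from the definition of the trace. Written out for an m-by-n matrix X and an n-by-m matrix Y (X and Y are generic placeholders here, not variables from the script):

```latex
\begin{align*}
\operatorname{tr}(XY) &= \sum_{i=1}^{m} \sum_{j=1}^{n} X_{ij} Y_{ji}
  = \sum_{j=1}^{n} \sum_{i=1}^{m} Y_{ji} X_{ij} = \operatorname{tr}(YX), \\
\operatorname{tr}(ABC) &= \operatorname{tr}(BCA) = \operatorname{tr}(CAB)
  \quad \text{(take } X = A,\ Y = BC \text{, then } X = AB,\ Y = C\text{)}.
\end{align*}
```

Only cyclic shifts are covered by this argument. Swapping two neighbouring factors, as in trace(A*C*B), is not.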
So this means that you can rearrange
trace(V1*S1'*U1'*U2*S2*V2') == trace( (V2'*V1) * S1 * (U1'*U2) * S2)
Make sure the parentheses are placed as shown. The two parenthesized products each yield an nRem-by-nRem matrix, and all other operations then act on nRem-by-nRem matrices.
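Applied to part 3 of the script in the question, the rearrangement might look like the sketch below. It assumes the same {U, S, V} cell layout as disSVD, uses smaller sizes than the question so that it runs quickly, and fixes the random seed only for reproducibility. The names refVal and relErr are illustrative and do not appear in the original script.

```matlab
% Sketch: part 3 of the question with the rearranged trace expression.
rng(0);
nTest = 20; nd = 200; nt = 40; nRem = 10;   % reduced sizes (assumption)

% Same cell layout as disSVD in the question: {U, S, V} per entry.
disSVD = cell(nTest, 2);
for isvd = 1:nTest
    for jsvd = 1:2
        [u, s, v] = svd(rand(nd, nt), 0);
        disSVD{isvd, jsvd} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
    end
end

disTrans = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        % Parenthesized so that every intermediate result is nRem-by-nRem.
        disTrans(itr, jtr) = ...
            trace((u2{3}' * u1{3}) * u1{2} * (u1{1}' * u2{1}) * u2{2});
    end
end

% Compare one entry against the original expression from the question.
u1 = disSVD{1, 1};
u2 = disSVD{2, 2};
refVal = trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
relErr = abs(disTrans(1, 2) - refVal) / abs(refVal);
fprintf('relative difference for entry (1,2): %.2e\n', relErr);
```

The two expressions agree only up to floating-point rounding, so the comparison uses a relative difference rather than an exact equality test.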
By the way, you can also rewrite trace(A'*B) as sum(sum(A.*B)), but I'm not sure if this will give you a speedup for this case.
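The elementwise form can be checked on a pair of generic real matrices. This is a minimal sketch with arbitrary sizes; A and B here are plain random matrices, not the truncated SVDs from the question.

```matlab
% Sketch: trace(A'*B) versus sum(sum(A.*B)) for real matrices of equal size.
rng(0);                        % fixed seed, for reproducibility only
A = rand(300, 50);
B = rand(300, 50);

tTrace = trace(A' * B);        % forms a 50-by-50 product, then sums its diagonal
tSum   = sum(sum(A .* B));     % elementwise product, no matrix multiplication

relErr = abs(tTrace - tSum) / abs(tTrace);
fprintf('relative difference: %.2e\n', relErr);
```

Using this form with the truncated SVDs would require forming A1 and A2 explicitly as nd-by-nt matrices, which may offset any gain compared with working on the nRem-by-nRem factors.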
3 Comments
Christine Tobler on 4 Apr 2018
Trace is a very quick operation compared with the matrix multiplications, so most of the speed-up probably comes from those multiplications. The trace call itself should also become a bit faster, because it now acts on an nRem-by-nRem matrix instead of an nt-by-nt matrix.
You can try taking the function calls apart:
M3 = u2{3}' * u1{3};
M1 = u1{1}' * u2{1};
M = M3 * u1{2} * M1 * u2{2};
trace(M);
This way, the profiler will tell you how much time is spent in each line, and you can see how much time the trace call takes.
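As a complement to the profiler, a single evaluation of each form can be timed with timeit. The sketch below assumes the sizes from the question and one random pair of truncated SVDs; fOrig and fFast are illustrative names. Timings depend on the machine and MATLAB release, so no figures are quoted here.

```matlab
% Sketch: time one trace evaluation in the original and rearranged forms.
rng(0);
nd = 1000; nt = 100; nRem = 50;          % sizes from the question

[U1, S1, V1] = svd(rand(nd, nt), 0);
[U2, S2, V2] = svd(rand(nd, nt), 0);
u1 = {U1(:, 1:nRem), S1(1:nRem, 1:nRem), V1(:, 1:nRem)};
u2 = {U2(:, 1:nRem), S2(1:nRem, 1:nRem), V2(:, 1:nRem)};

% Original expression: evaluated left to right, with large intermediates.
fOrig = @() trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
% Rearranged expression: every intermediate result is nRem-by-nRem.
fFast = @() trace((u2{3}' * u1{3}) * u1{2} * (u1{1}' * u2{1}) * u2{2});

tOrig = timeit(fOrig);
tFast = timeit(fFast);
fprintf('original: %.3g s, rearranged: %.3g s per call\n', tOrig, tFast);
```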
Xiaohan Du on 7 Apr 2018
Thanks Chris!
