How to calculate sum(A .* (B * C), 'all') [Ed. actually sum(A .* log(B * C), 'all')] efficiently when A is sparse and B*C is full and large?

Wenyu Zhang, 29 Oct 2021
Edited: James Tursa, 9 Nov 2021
I have three matrices, A of size [J,I], B of size [J,K], C of size [K,I]. A is a sparse matrix with more than 90% zeros, while B and C are positive double matrices. The typical values are J=1e5, I=1e4, K=50.
The problem is that B*C creates a full matrix of size [J,I], which leads to redundant memory usage because what I need is merely the elements (B * C)(find(A)). My current constraint is that I don't have enough memory for a full matrix of size [J,I]. I wonder if there's a smart way to avoid such unnecessary memory usage for calculating this specific expression?
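For scale, here is the arithmetic with the typical sizes above (J=1e5, I=1e4, K=50, A more than 90% zeros, 8-byte doubles). It shows how much smaller the needed entries are than the full product:

```latex
\underbrace{J \cdot I}_{10^5 \cdot 10^4 \,=\, 10^9 \text{ entries}} \times 8\ \text{bytes} = 8 \times 10^9\ \text{bytes} \approx 8\ \text{GB} \quad \text{(full } BC\text{)}

\operatorname{nnz}(A) < 0.1 \cdot 10^9 = 10^8
\;\Rightarrow\; (BC)(\texttt{find}(A)) \text{ needs} < 8 \times 10^8\ \text{bytes} \approx 0.8\ \text{GB}

\text{work: } \operatorname{nnz}(A) \cdot K < 10^8 \cdot 50 = 5 \times 10^9
\quad \text{vs.} \quad J \cdot I \cdot K = 5 \times 10^{10} \text{ multiply-adds}
```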
I have tried converting B into a tall array, but an error like "tall arrays are not allowed to contain sparse data" appears when .* is evaluated. I also tried converting A into a tall array using tall(full(A)), but that's not reasonable because I need to convert A to a full matrix first, and A is in fact not "tall" at all. Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not achieve the goal of removing redundant memory usage.
Thanks in advance!

Accepted Answer

James Tursa, 1 Nov 2021
Edited: James Tursa, 9 Nov 2021
Here is the straightforward mex code (i.e., no parallel sections) if you want to try it out. It computes the result directly in a loop without the need for large temporary memory allocations and data copying. You will need a supported C compiler installed. To compile it use the following at the command line:
mex sABC.c -R2018a
If you have an earlier version of MATLAB you can omit the -R2018a option.
To run it, simply call it as follows:
A = whatever
B = whatever
C = whatever
sABC(A,B,C)
The C source code:
/* File sABC.c
 * sABC(A,B,C) returns sum(A.*log(B*C),'all')
 *
 * A = sparse real double MxN
 * B = full real double MxK
 * C = full real double KxN
 *
 * Programmer: James Tursa
 * Date: 10/31/2021
 */
#include "mex.h"
#include <math.h>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double dot, result = 0.0;
    mwSize M, K, N;
    mwSize j, k, nrow;
    double *A, *B, *C, *b, *c;
    mwIndex *Air, *Ajc;

    /* Argument checks */
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs.");
    }
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs.");
    }
    if( !mxIsDouble(prhs[0]) || !mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ) {
        mexErrMsgTxt("A must be real sparse double.");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]) ||
        mxIsSparse(prhs[1]) || mxIsSparse(prhs[2]) ||
        mxIsComplex(prhs[1]) || mxIsComplex(prhs[2]) ) {
        mexErrMsgTxt("B and C must be real full double matrices.");
    }
    if( mxGetNumberOfDimensions(prhs[1]) != 2 || mxGetNumberOfDimensions(prhs[2]) != 2 ) {
        mexErrMsgTxt("B and C must be 2D.");
    }
    M = mxGetM(prhs[0]);
    N = mxGetN(prhs[0]);
    K = mxGetN(prhs[1]);
    if( M != mxGetM(prhs[1]) ||
        N != mxGetN(prhs[2]) ||
        K != mxGetM(prhs[2]) ) {
        mexErrMsgTxt("Dimensions not compatible.");
    }

    /* Calculate result, simple loop no parallel code */
    Air = mxGetIr(prhs[0]);
    Ajc = mxGetJc(prhs[0]);
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(prhs[2]);
    for( j=0; j<N; j++ ) {
        nrow = Ajc[j+1] - Ajc[j];  /* Number of row elements for this column */
        while( nrow-- ) {
            b = B + *Air++;        /* B row pointer */
            c = C + j*K;           /* C column pointer */
            dot = 0.0;
            for( k=0; k<K; k++ ) { /* dot product of B row and C column */
                dot += (*b) * (*c);
                b += M;
                c++;
            }
            result += *A++ * log(dot); /* Accumulate in result */
        }
    }
    plhs[0] = mxCreateDoubleScalar(result);
}
1 Comment
Wenyu Zhang, 1 Nov 2021
Edited: Wenyu Zhang, 1 Nov 2021
Wow, it's amazing that this mex code, even without parallel computing, is faster than my other attempts! Thank you for providing this example!


More Answers (4)

Matt J, 1 Nov 2021
Edited: Matt J, 1 Nov 2021
One more way,
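[kj,ki,a]=find(A);      % row indices, column indices and values of the nonzeros of A
C=C.';                  % transpose so that C(ki,n) is the original C(n,ki)
K=size(B,2);
accum=0;
for n=1:K
    accum=accum+ B(kj,n).*C(ki,n);   % builds (B*C)(kj,ki) for all nonzeros, one term at a time
end
result=log(accum).'*a;  % = sum(A.*log(B*C),'all')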

Matt J, 29 Oct 2021
Edited: Matt J, 30 Oct 2021
Use the equivalent expression (valid for the original version of the question, without the log),
sum((B.'*A).*C,'all')
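A small numerical check of that identity, written as a sketch with dense column-major matrices (A is J x I, B is J x K, C is K x I; the function names are hypothetical). The right-hand side only needs a K x I temporary instead of J x I. Note that the identity relies on linearity in B*C, so it does not carry over once the log is applied.

```c
#include <stddef.h>

/* sum(A.*(B*C),'all'): forms each (B*C)(j,i) on the fly. */
double lhs_sum(size_t J, size_t K, size_t I,
               const double *A, const double *B, const double *C)
{
    double s = 0.0;
    for (size_t i = 0; i < I; i++)
        for (size_t j = 0; j < J; j++) {
            double bc = 0.0;
            for (size_t k = 0; k < K; k++)
                bc += B[k * J + j] * C[i * K + k];     /* B(j,k) * C(k,i) */
            s += A[i * J + j] * bc;
        }
    return s;
}

/* sum((B.'*A).*C,'all'): (B.'*A)(k,i) is only K x I. */
double rhs_sum(size_t J, size_t K, size_t I,
               const double *A, const double *B, const double *C)
{
    double s = 0.0;
    for (size_t i = 0; i < I; i++)
        for (size_t k = 0; k < K; k++) {
            double bta = 0.0;
            for (size_t j = 0; j < J; j++)
                bta += B[k * J + j] * A[i * J + j];    /* B(j,k) * A(j,i) */
            s += bta * C[i * K + k];
        }
    return s;
}
```

With A = [1 0; 0 2], B = [1 2; 3 4] and C = [1 1; 2 0], both functions give 11.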
8 Comments
Matt J, 1 Nov 2021
Edited: Matt J, 2 Nov 2021
May I ask if there's any equivalent expression of sum(v .* (B * C), 'all') when v is a 1xI dense vector?
sum(B,1)*C*v.'
Wenyu Zhang, 2 Nov 2021
Edited: Wenyu Zhang, 2 Nov 2021
Thank you very much! It's not hard to prove that sum(B,1)*C*v' is equivalent to sum(v.*(B*C),'all'). Your expression is not only more elegant but also faster than mine.
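For reference, one way to write the proof, with v a 1 x I row vector and \mathbf{1} the length-J vector of ones (so \mathbf{1}^\top B is sum(B,1) and v^\top is v.'):

```latex
\sum_{j=1}^{J}\sum_{i=1}^{I} v_i\,(BC)_{ji}
  = \sum_{j=1}^{J}\sum_{i=1}^{I} v_i \sum_{k=1}^{K} B_{jk} C_{ki}
  = \sum_{k=1}^{K} \Big(\sum_{j=1}^{J} B_{jk}\Big)\Big(\sum_{i=1}^{I} C_{ki}\, v_i\Big)
  = \mathbf{1}^{\top} B \, C \, v^{\top}
```

Evaluated left to right this costs O(JK + KI) operations and never forms B*C.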



Matt J, 30 Oct 2021
Edited: Matt J, 30 Oct 2021
Yes I'm still calculating sum(A .* (log(B*C)),'all').
I would probably just break C down into a small number of chunks and loop, e.g.,
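Cr=reshape(C,K,I/10,10);               % 10 equal column chunks of C (requires I divisible by 10)
Acell=mat2cell(A,J,ones(1,10)*I/10);   % matching column chunks of A
mysum=0;
for n=1:10
    mysum=mysum+sum( Acell{n}.*log(B*Cr(:,:,n)) ,'all');   % only a J x I/10 temporary at a time
end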
Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not achieve the goal of removing redundant memory usage.
I'm not sure why you conclude this is not efficient, but regardless, I don't think you're going to be able to avoid it (in the case where you have the log operation in there) unless there is some particular structure to the sparsity pattern in A that you haven't told us about.
It's important to remember that there is a lot of parallel computation happening in a matrix multiplication. When parallel computation is involved, the number of computations isn't necessarily the thing that dominates performance.
5 Comments
Wenyu Zhang, 1 Nov 2021
In general, my I is not exactly 1e4. Is there an elegant way to reshape the matrix when I/10 is not an integer? The solution I could think of is to make Cr a cell like Acell.
Matt J, 1 Nov 2021
Edited: Matt J, 2 Nov 2021
The solution I could think of is to make Cr a cell like Acell.
Yes, that would be the way. You can use mat2tiles from the File Exchange:
Acell=mat2tiles(A,[inf,1e3]);
Ccell=mat2tiles(C,[inf,1e3]);
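When I is not a multiple of the tile width, the tiles for A and C just have to share the same column boundaries. A small C sketch of those boundaries, assuming (as I understand mat2tiles to behave, but check its File Exchange page) that every tile has the full width except possibly the last, which holds the remainder. Both helper names are hypothetical.

```c
#include <stddef.h>

/* Number of column tiles of the given width (width > 0): ceil(I / width). */
size_t num_tiles(size_t I, size_t width)
{
    return (I + width - 1) / width;
}

/* 0-based half-open column range [*start, *end) of tile t; the last tile
 * is clipped to I, so it holds the remainder when I % width != 0. */
void tile_bounds(size_t I, size_t width, size_t t, size_t *start, size_t *end)
{
    *start = t * width;
    *end = (*start + width < I) ? *start + width : I;
}
```

For example, with I = 10500 and width 1000 there are 11 tiles, and the last one covers columns 10000 to 10499.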



James Tursa, 29 Oct 2021
Edited: James Tursa, 29 Oct 2021
You could use this loop to avoid the memory usage, but it will run slowly because of the data copying going on in the background for the values, row, and column extractions from the variables. This extra data copying could be avoided in a mex routine if you really needed to recover that speed.
[row,col,v] = find(A);
mysum = 0;
for k=1:numel(v)
    mysum = mysum + v(k)*(B(row(k),:)*C(:,col(k)));
    % for the log version of the question: mysum = mysum + v(k)*log(B(row(k),:)*C(:,col(k)));
end
4 Comments
James Tursa, 31 Oct 2021
Do you have a supported C/C++ compiler installed? The code for this would be pretty straightforward.
Wenyu Zhang, 31 Oct 2021
Yes I have a supported C++ compiler. But I do not have enough basic knowledge about the mex routine. It may take some time for me to get started.


Release: R2019a
