How to calculate sum(A .* log(B * C), 'all') efficiently when A is sparse and B*C is full and large? (The question was first posted as sum(A .* (B * C), 'all'), without the log.)
I have three matrices, A of size [J,I], B of size [J,K], C of size [K,I]. A is a sparse matrix with more than 90% zeros, while B and C are positive double matrices. The typical values are J=1e5, I=1e4, K=50.
The problem is that B*C creates a full matrix of size [J,I], which wastes memory because all I need are the elements (B * C)(find(A)). My current constraint is that I don't have enough memory for a full matrix of size [J,I]. Is there a smart way to avoid this unnecessary memory usage when calculating this specific expression?
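For scale, here is a rough estimate from the sizes above. It assumes 8-byte doubles, the usual 64-bit MATLAB sparse layout (one value and one row index per nonzero, plus one column pointer per column), and the stated upper bound of 10% nonzeros in A:

```latex
\begin{aligned}
\text{full } BC:&\quad 8\,J I = 8\cdot 10^{5}\cdot 10^{4} = 8\times 10^{9}\ \text{bytes} \approx 8\ \text{GB}\\
\text{sparse } A:&\quad \operatorname{nnz}(A) \le 0.1\,J I = 10^{8},\qquad 16\operatorname{nnz}(A) + 8(I+1) \approx 1.6\times 10^{9}\ \text{bytes}\\
\text{work, full } BC:&\quad J I K = 5\times 10^{10}\ \text{multiply-adds}\\
\text{work, needed entries only}:&\quad \operatorname{nnz}(A)\,K \le 5\times 10^{9}\ \text{multiply-adds}
\end{aligned}
```

So under these assumptions the entries actually needed are at most a tenth of the full product, in both memory and arithmetic.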
I have tried converting B into a tall array, but an error like "tall arrays are not allowed to contain sparse data" appears when .* is evaluated. I also tried converting A into a tall array using tall(full(A)), but that's not reasonable because I need to store A as a full matrix first, and A is in fact not "tall" at all. Another way I tried to reduce memory usage is to divide A and C into blocks and calculate the expression in parts (using a for-loop). However, this is not efficient, and it does not achieve the goal of removing the redundant memory usage.
Thanks in advance!
Accepted Answer
James Tursa
1 Nov 2021
Edited: James Tursa
9 Nov 2021
Here is the straightforward MEX code (i.e., no parallel sections) if you want to try it out. It computes the result directly in a loop, without large temporary memory allocations or data copying. You will need a supported C compiler installed. To compile it, use the following at the command line:
mex sABC.c -R2018a
If you have a MATLAB version earlier than R2018a, omit the -R2018a option.
To run it, simply call:
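J = 1000; I = 500; K = 50;   % small test sizes; substitute your own data
A = sprand(J,I,0.1);         % sparse J-by-I, about 90% zeros
B = rand(J,K);               % full positive J-by-K
C = rand(K,I);               % full positive K-by-I
sABC(A,B,C)                  % matches sum(A.*log(B*C),'all') up to rounding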
The C source code:
/* File sABC.c
 * sABC(A,B,C) returns sum(A.*log(B*C),'all')
 *
 * A = sparse real double MxN
 * B = full real double MxK
 * C = full real double KxN
 *
 * Programmer: James Tursa
 * Date: 10/31/2021
 */
#include "mex.h"
#include <math.h>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double dot, result = 0.0;
    mwSize M, K, N;
    mwSize j, k, nrow;
    double *A, *B, *C, *b, *c;
    mwIndex *Air, *Ajc;

    /* Argument checks */
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs.");
    }
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs.");
    }
    if( !mxIsDouble(prhs[0]) || !mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ) {
        mexErrMsgTxt("A must be real sparse double.");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]) ||
        mxIsSparse(prhs[1]) || mxIsSparse(prhs[2]) ||
        mxIsComplex(prhs[1]) || mxIsComplex(prhs[2]) ) {
        mexErrMsgTxt("B and C must be real full double matrices.");
    }
    if( mxGetNumberOfDimensions(prhs[1]) != 2 || mxGetNumberOfDimensions(prhs[2]) != 2 ) {
        mexErrMsgTxt("B and C must be 2D.");
    }
    M = mxGetM(prhs[0]);
    N = mxGetN(prhs[0]);
    K = mxGetN(prhs[1]);
    if( M != mxGetM(prhs[1]) ||
        N != mxGetN(prhs[2]) ||
        K != mxGetM(prhs[2]) ) {
        mexErrMsgTxt("Dimensions not compatible.");
    }

    /* Calculate result, simple loop no parallel code */
    Air = mxGetIr(prhs[0]);
    Ajc = mxGetJc(prhs[0]);
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(prhs[2]);
    for( j=0; j<N; j++ ) {
        nrow = Ajc[j+1] - Ajc[j];  /* Number of row elements for this column */
        while( nrow-- ) {
            b = B + *Air++;  /* B row pointer */
            c = C + j*K;     /* C column pointer */
            dot = 0.0;
            for( k=0; k<K; k++ ) {  /* dot product of B row and C column */
                dot += (*b) * (*c);
                b += M;
                c++;
            }
            result += *A++ * log(dot);  /* Accumulate in result */
        }
    }
    plhs[0] = mxCreateDoubleScalar(result);
}
More Answers (4)
Matt J
30 Oct 2021
Edited: Matt J
30 Oct 2021
You wrote: "Yes, I'm still calculating sum(A .* (log(B*C)),'all')."
I would probably just break C down into a small number of chunks and loop, e.g.,
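Cr = reshape(C,K,I/10,10);               % Cr(:,:,n) is the n-th block of I/10 columns of C (needs mod(I,10)==0)
Acell = mat2cell(A,J,ones(1,10)*I/10);   % the matching column blocks of A
mysum = 0;
for n = 1:10
    % only a J-by-I/10 block of B*C is formed per iteration
    mysum = mysum + sum( Acell{n}.*log(B*Cr(:,:,n)) ,'all');
end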
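A rough sense of what this chunking buys, assuming n_c equal column blocks (n_c = 10 in the loop above) and 8-byte doubles: each full temporary such as B*Cr(:,:,n), or its log, has J*I/n_c entries.

```latex
8\,\frac{J I}{n_c}\ \text{bytes} = 8\cdot\frac{10^{5}\cdot 10^{4}}{10} = 8\times 10^{8}\ \text{bytes} \approx 800\ \text{MB per full temporary}
```

A larger n_c lowers this peak further at the cost of more, smaller matrix products; the total multiply-add count stays J*I*K either way.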
You wrote: "Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage."
I'm not sure why you conclude this is not efficient, but regardless, I don't think you can avoid it when the log operation is in the expression, unless there is some particular structure in the sparsity pattern of A that you haven't told us about.
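To spell out why the log matters: without it, the double sum can be regrouped so that the J-by-I product never appears. In MATLAB terms, sum(A.*(B*C),'all') equals sum(B.*(A*C.'),'all'), where A*C.' (sparse times full) is only J-by-K.

```latex
\sum_{j,i} A_{ji}\,(BC)_{ji}
  = \sum_{j,i} A_{ji} \sum_{k} B_{jk} C_{ki}
  = \sum_{j,k} B_{jk} \sum_{i} A_{ji} C_{ki}
  = \sum_{j,k} B_{jk}\,(A C^{\mathsf T})_{jk}
```

With the log, each term becomes A_ji * log(sum_k B_jk C_ki), and the log of a sum does not split over k, so this regrouping is no longer available.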
It's important to remember that a matrix multiplication involves a lot of parallel computation. When parallel computation is involved, the raw number of operations isn't necessarily what dominates performance.
James Tursa
29 Oct 2021
Edited: James Tursa
29 Oct 2021
You could use this loop to avoid the large temporary, but it will run slowly because of the data copying that goes on in the background when the values and the row and column slices are extracted from the variables. This extra data copying could be avoided in a MEX routine if you really needed to recover that speed.
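[row,col,v] = find(A);   % row indices, column indices and values of the nonzeros of A
mysum = 0;
for k = 1:numel(v)
    % B(row(k),:)*C(:,col(k)) is the single element (B*C)(row(k),col(k));
    % for the log version of the question use v(k)*log(B(row(k),:)*C(:,col(k)))
    mysum = mysum + v(k)*(B(row(k),:)*C(:,col(k)));
end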
4 Comments
James Tursa
31 Oct 2021
Do you have a supported C/C++ compiler installed? The code for this would be pretty straightforward.