Improve speed of linear interpolation in nested loops

Alessandro D on 24 Jul 2022
Commented: Alessandro D on 30 Aug 2022
I have to do one-dimensional linear interpolation many times inside four nested loops. My x-grid is sorted, so I can use interp1q, but the code is still too slow for my purposes. A simple vectorization that eliminates the innermost loop (leaving three loops instead of four) makes it much faster, but unfortunately still not fast enough for my problem. Any suggestions on how to improve the speed? Thanks.
Below is an MWE (please bear in mind that in my real problem the grids are larger).
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c =1:nk
% in general, the columns of b_gridp are *not* equal to each other
%b_gridp(:,k_c) = linspace(b_min,b_max,nb)'; %EDITED HERE
b_gridp(:,k_c) = linspace(b_min+rand,b_max-rand,nb)';
end
%% Slow, not vectorized code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx
bnext = pol_debt(k_c,b_c,x_c);
knext_ind = pol_kp_ind(k_c,b_c,x_c);
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
dexit = min(max(dexit_inter,0),1); % scalar
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
Elapsed time is 15.150729 seconds.
%% This is faster, but still not fast enough!
tic
stay_arr2 = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
bnext = pol_debt(k_c,b_c,x_c);
knext_ind = pol_kp_ind(k_c,b_c,x_c);
pol_exit_bx = pol_exitp(:,:,knext_ind); % dim: (nb,nx)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % dim is (1,nx')
dexit = min(max(dexit_inter,0),1); % dim is (1,nx')
stay_arr2(:,k_c,b_c,x_c) = 1-dexit;
end %k_c
end %b_c
end %x_c
toc
Elapsed time is 0.717499 seconds.
err = max(abs(stay_arr-stay_arr2),[],'all')
err = 0
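The vectorization over xp_c works because the interpolation routine treats each column of the value matrix as a separate data set, so one call with a scalar query returns a whole row. A minimal sketch of that behavior, using interp1 rather than interp1q and made-up toy data (the names x, V and xq are illustrative only, not part of the MWE):

```matlab
% Toy data (hypothetical, not taken from the MWE above)
x  = [0; 1; 2; 3];            % sorted grid, column vector
V  = [x.^2, 10*x, -x];        % three value sets, one per column
xq = 1.5;                     % single query point

% One call interpolates every column at xq and returns a 1-by-3 row
row = interp1(x, V, xq);

% Same result, computed one column at a time
row_loop = zeros(1, size(V,2));
for j = 1:size(V,2)
    row_loop(j) = interp1(x, V(:,j), xq);
end

disp(max(abs(row - row_loop)))   % difference between the two approaches
```

In the MWE, V plays the role of pol_exitp(:,:,knext_ind) and xq the role of bnext.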
5 Comments
Alessandro D on 25 Jul 2022
Edited: Alessandro D on 25 Jul 2022
Thanks for your suggestion! I think there are two separate issues here: (1) making the interpolation faster by exploiting the fact that the grid is equally spaced, and (2) vectorizing the code. I wrote a small function, find_loc_equi, that generates the bins and weights for interpolation when the grid is equally spaced. The new code is attached below. The speed improvement is minimal, however. What I don't know is how to vectorize the code further: e.g., is it possible to remove the inner loop over k_c?
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c =1:nk
% in general, the columns of b_gridp are *not* equal to each other
b_gridp(:,k_c) = linspace(b_min,b_max,nb)';
end
disp('start code')
start code
%% Non-vectorized code with interp1q
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx
bnext = pol_debt(k_c,b_c,x_c);
knext_ind = pol_kp_ind(k_c,b_c,x_c);
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
dexit = min(max(dexit_inter,0),1); % scalar
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
Elapsed time is 14.475729 seconds.
%% Vectorized code with faster bin location
tic
stay_arr2 = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital CAN I ELIMINATE THIS LOOP?
bnext = pol_debt(k_c,b_c,x_c);
knext_ind = pol_kp_ind(k_c,b_c,x_c);
pol_exit_bx = pol_exitp(:,:,knext_ind); % dim: (nb,nx)
[left_loc,weights] = find_loc_equi(b_gridp(:,knext_ind),bnext); % scalars
dexit_inter = weights*pol_exit_bx(left_loc,:)+(1-weights)*pol_exit_bx(left_loc+1,:); % dim is (1,nx')
dexit = min(max(dexit_inter,0),1); % dim is (1,nx')
stay_arr2(:,k_c,b_c,x_c) = 1-dexit;
end %k_c
end %b_c
end %x_c
toc
Elapsed time is 1.136634 seconds.
err2 = max(abs(stay_arr-stay_arr2),[],'all')
err2 = 1.5765e-14
function [bins,weights] = find_loc_equi(b_grid,bnext)
% Find the left interpolating node and the corresponding weight
% Assumptions: b_grid is equally spaced
% Thanks to Walter Roberson
n = size(b_grid,1);
deltas = b_grid(2,:)-b_grid(1,:);
initial_values = b_grid(1,:);
frac = (bnext-initial_values)./deltas;
bins = floor(frac)+1;
bins = min(n-1,max(1,bins));
weight_right = mod(frac,1);
weights = 1-weight_right;
weights(bnext>=b_grid(n)) = 0 ;
weights(bnext<=b_grid(1)) = 1 ;
end
Alessandro D on 30 Aug 2022
The last two lines of the function find_loc_equi should read:
weights(bnext>=b_grid(n,:)) = 0 ;
weights(bnext<=b_grid(1,:)) = 1 ;
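For reference, here is a sketch of the helper with this correction applied, checked against interp1 on toy data. The test grid, the values vals and the query points bq are assumptions made for illustration; only the function body comes from the thread.

```matlab
% Toy check of the corrected helper against interp1 (hypothetical data)
b_grid = linspace(-100, 300, 45)';      % equally spaced grid, column vector
vals   = sin(b_grid/50);                % arbitrary values on the grid
bq     = [-100, -37.2, 0, 123.4, 300];  % query points inside the grid

out = zeros(size(bq));
for i = 1:numel(bq)
    [loc, w] = find_loc_equi(b_grid, bq(i));
    out(i) = w*vals(loc) + (1-w)*vals(loc+1);   % w is the LEFT-node weight
end
err = max(abs(out - interp1(b_grid, vals, bq)));
disp(err)

function [bins,weights] = find_loc_equi(b_grid,bnext)
% Left interpolating node and its weight; assumes b_grid is equally spaced
n = size(b_grid,1);
deltas = b_grid(2,:)-b_grid(1,:);
initial_values = b_grid(1,:);
frac = (bnext-initial_values)./deltas;
bins = floor(frac)+1;
bins = min(n-1,max(1,bins));
weights = 1-mod(frac,1);
weights(bnext>=b_grid(n,:)) = 0;   % corrected: index the last row explicitly
weights(bnext<=b_grid(1,:)) = 1;   % corrected: index the first row explicitly
end
```

The two overrides at the end pin the weight at the grid endpoints, where floor and mod alone would select the wrong node.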


Accepted Answer

Bruno Luong on 25 Jul 2022
Edited: Bruno Luong on 25 Jul 2022
This seems to work
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c =1:nk
% in general, the columns of b_gridp are *not* equal to each other
b_gridp(:,k_c) = linspace(b_min,b_max,nb)';
end
disp('start code')
start code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx
bnext = pol_debt(k_c,b_c,x_c);
knext_ind = pol_kp_ind(k_c,b_c,x_c);
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
dexit = min(max(dexit_inter,0),1); % scalar
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
Elapsed time is 15.097401 seconds.
%% Full vectorized code
tic
bgridcommon = b_gridp(:,1);
Y = interp1(bgridcommon,(1:nb)',pol_debt); % nk x nb x nx
Yt = max(min(Y,nb-1),1); % not needed if the data never fall outside the grid
I = floor(Yt); % nk x nb x nx
W = Y-I;
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
K = repmat(pol_kp_ind,[1 1 1 nx]);
K = reshape(K,size(I));
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1);
stay_arr2 = 1-dexit;
toc
Elapsed time is 0.191299 seconds.
err = norm(stay_arr2(:)-stay_arr(:),Inf)
err = 8.6597e-15
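The expression pol_exitp(rhsilin+1) in the code above relies on MATLAB's column-major storage: adding 1 to a linear index moves to the next row within the same column and page, which is why the left node I is clamped to at most nb-1. A small sketch on a toy array (the array A and the indices i, j, k are made up for illustration):

```matlab
% Toy array (hypothetical sizes, same layout as pol_exitp: rows = b-grid)
A = reshape(1:24, [4 3 2]);        % 4 x 3 x 2, values equal linear indices

i = 2; j = 3; k = 2;               % a left node with i <= size(A,1)-1
lin = sub2ind(size(A), i, j, k);   % linear index of A(i,j,k)

% Column-major storage: lin+1 addresses the next row, same column and page
disp(A(lin+1) == A(i+1, j, k))     % displays 1 (true)
```

If i were allowed to equal size(A,1), lin+1 would spill into the next column, so the clamp is what keeps the right node valid.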
4 Comments
Bruno Luong on 26 Jul 2022
Edited: Bruno Luong on 26 Jul 2022
Sorry, please disregard my comment above about the loop. The bin interval is not the first index. Here is the corrected code, which works for variable bin vectors.
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55;
b_min = -100;
b_max = 300;
% Generate fake data
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
b_gridp = zeros(nb,nk);
for k_c =1:nk
% in general, the columns of b_gridp are *not* equal to each other
b_gridp(:,k_c) = linspace(b_min-rand(),b_max+rand(),nb)';
end
disp('start code')
start code
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx
bnext = pol_debt(k_c,b_c,x_c);
knext_ind = pol_kp_ind(k_c,b_c,x_c);
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
dexit = min(max(dexit_inter,0),1); % scalar
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
Elapsed time is 15.051721 seconds.
%% Full vectorized code
tic
K = pol_kp_ind;
bminK = reshape(b_gridp(1,K),size(K));
bmaxK = reshape(b_gridp(nb,K),size(K));
Y = 1 + (nb-1) * (pol_debt - bminK) ./ (bmaxK-bminK);
Yt = max(min(Y,nb-1),1); % not needed if the data never fall outside the grid
I = floor(Yt); % nk x nb x nx
W = Y-I;
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
K = reshape(repmat(K,[1 1 1 nx]),size(I));
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1);
stay_arr2 = 1-dexit;
toc
Elapsed time is 0.166581 seconds.
err = norm(stay_arr2(:)-stay_arr(:),Inf)
err = 1.7542e-14
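The line computing Y maps each query point to a fractional index into its own grid column, which only works when every column is equally spaced (the columns may have different endpoints). A minimal worked sketch on one toy grid, with made-up numbers (bg, vals and x are illustrative names, not from the code above):

```matlab
% Toy example (hypothetical numbers): one equally spaced grid, one query
nb   = 5;
bmin = -1; bmax = 3;                       % grid endpoints
bg   = linspace(bmin, bmax, nb)';          % [-1 0 1 2 3]'
vals = [10; 20; 40; 80; 160];              % values on the grid
x    = 1.25;                               % query point

Y = 1 + (nb-1)*(x - bmin)/(bmax - bmin);   % fractional index, here 3.25
I = floor(max(min(Y, nb-1), 1));           % left node, clamped to 1..nb-1
W = Y - I;                                 % weight on the right node

manual = (1-W)*vals(I) + W*vals(I+1);      % 0.75*40 + 0.25*80 = 50
disp([manual, interp1(bg, vals, x)])       % both entries should agree
```

For grids that are not equally spaced within a column, this closed-form index does not apply and a search for the bracketing nodes would be needed instead.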
Alessandro D on 27 Jul 2022
Thanks, this works and is significantly faster!


Release

R2020b
