MATLAB Answers

Cross-Entropy Minimization - Extreme Code Performance

4 ビュー (過去 30 日間)
Tommaso Belluzzo
Tommaso Belluzzo 2020 年 7 月 15 日
編集済み: Tommaso Belluzzo 2020 年 7 月 17 日
I'm working on a multivariate cross-entropy minimization model (for more details about it, see this paper, pp. 32-33). It's purpose is to adjust a prior multivariate distribution (in this case, a gaussian normal) with information on marginals coming from real observations.
The code at the end of the post represents my current implementation. The maths should have been correctly reproduced, unless I missed something critical. The real problem I'm struggling to deal with is the performance of the code.
In the first part of the model, cumulative probabilities have to be computed over all the orthants of the distribution density. This process has a time complexity of 2^N, where N is the number of entities included into the dataset. As long as the number of entities is less than 12, everything is fast enough on my PC. With 20 entities, which is my current target, the model needs to run mvncdf over 1048576 combinations of orthants and this takes forever to finish.
I already improved the code replacing for loops with parfor loops. I'm planning to improve it even further by replacing the built-in mvncdf function with a faster implementation (signally, this one) and by creating a micro-cache for mvncdf function calls (when filling bounds_2, the code actually recalculates cumulative probabilities that have been already computed for bounds_1).
I'm not very familiar with cross-entropy minimization models, so maybe there are math tricks I can use to simplify this calculation. Maybe the code can be vectorized even more. Well... any help or suggestion to improve the calculations speed is more than welcome!
clc();
clear();
% DATA
pods = [0.015; 0.02; 0.013; 0.007; 0.054; 0.034; 0.009; 0.065; 0.029; 0.205];
dts = [2.1; 2; 2.2; 2.4; 1.5; 1.8; 2.3; 1.5; 1.8; 0.8];
% Test of time complexity:
% pods = [pods; pods];
% dts = [dts; dts];
n = numel(pods);
c = eye(n);
% G / BOUNDS FOR 1
g1 = combn([0 1],n);
g1_s = size(g1,1);
bounds_1 = zeros(g1_s,1);
parfor i = 1:g1_s
interval = zeros(n,2);
for j = 1:n
if (g1(i,j) == 0)
interval(j,:) = [-Inf dts(j)];
else
interval(j,:) = [dts(j) Inf];
end
end
bounds_1(i) = mvncdf(interval(:,1),interval(:,2),0,c);
end
% G / BOUNDS FOR 2:N
gaux = combn([0 1],n-1);
g2 = cell(n,1);
g2_s = g1_s / 2;
bounds_2 = zeros(n,g2_s);
for k = 1:n
g2_k = [gaux(:,1:(k-1)) ones(size(gaux,1),1) gaux(:,k:end)];
parfor i = 1:g2_s
interval = zeros(n,2);
for j = 1:n
if (g2_k(i,j) == 0)
interval(j,:) = [-Inf dts(j)];
else
interval(j,:) = [dts(j) Inf];
end
end
bounds_2(k,i) = mvncdf(interval(:,1),interval(:,2),0,c);
end
g2{k} = g2_k;
end
% SOLUTION
options = optimset(optimset(@fsolve),'Display','iter','TolFun',1e-08,'TolX',1e-08);
cns = [1; pods];
x0 = zeros(size(pods,1)+1,1);
lm = fsolve(@(x)objective(x,n,g1,bounds_1,g2,bounds_2,cns),x0,options);
stop = 1;
% Objective function of the model.
function lm = objective(x,n,g1,bounds_1,g2,bounds_2,cns)
mu = x(1);
lambda = x(2:end);
lm = zeros(n + 1,1);
for i = 1:numel(bounds_1)
lm(1) = lm(1) + exp(-g1(i,:) * lambda) * bounds_1(i);
end
for i = 1:n
g2_k = g2{i,1};
for j = 1:size(bounds_2,2)
lm(i+1) = lm(i+1) + exp(-g2_k(j,:) * lambda) * bounds_2(i,j);
end
end
lm = (exp(-1-mu) * lm) - cns;
end
% All combinations of elements.
function [m,i] = combn(v,n)
if ((fix(n) ~= n) || (n < 1) || (numel(n) ~= 1))
error('Parameter N must be a scalar positive integer.');
end
if (isempty(v))
m = [];
i = [];
elseif (n == 1)
m = v(:);
i = (1:numel(v)).';
else
i = combn_local(1:numel(v),n);
m = v(i);
end
function y = combn_local(v,n)
if (n > 1)
[y{n:-1:1}] = ndgrid(v);
y = reshape(cat(n+1,y{:}),[],n);
else
y = v(:);
end
end
end

  0 件のコメント

サインインしてコメントする。

回答 (0 件)

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by