Locate index and weights for linear interpolation for a monotonically increasing vector

Fredrik P, 22 March 2022 (edited 22 March 2022)
I have a performance-critical function that amounts to finding everything needed to interpolate a point (or a vector of points) on a space given by a discrete grid, but without actually doing the interpolation. My existing code for this is very fast (about three times faster than the function I'm posting here), but here I'm trying to approach the problem from a different angle: both the vector of values and the vector of grid points will always be monotonically increasing, and I want to see whether that fact can be used to boost performance. Can you see anything I might do to improve performance?
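For reference, with a grid $X_1 < X_2 < \dots < X_M$ and a value $x$, "everything needed" for linear interpolation is an interval index $k$ and two weights. For an interior value ($X_1 \le x < X_M$), the quantities the function below computes are

$$
X_k \le x < X_{k+1}, \qquad
w_1 = \frac{X_{k+1} - x}{X_{k+1} - X_k}, \qquad
w_2 = \frac{x - X_k}{X_{k+1} - X_k}, \qquad
w_1 + w_2 = 1,
$$

so that data $f_j = f(X_j)$ on the grid interpolate as $\hat f(x) = w_1 f_k + w_2 f_{k+1}$. By default, values at or below $X_1$ get $k = 1$, $(w_1, w_2) = (1, 0)$ and values at or above $X_M$ get $k = M - 1$, $(w_1, w_2) = (0, 1)$; the `locateBelow`/`locateAbove` options instead apply the same formula, which extrapolates.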
My current timings are:
values = sort(rand(1, 100))';
gridpoints = linspace(0.1, 0.9, 60)';
[timeit(@() locateVector(values, gridpoints)), timeit(@() locateVector(values, gridpoints, "binarySearch", true))]
% ans =
%
% 1.0e-03 *
%
% 0.1004 0.0957
And locateVector.m looks like this:
function [indices, weights] = locateVector(values, gridpoints, NameValueArgs)
%LOCATEVECTOR Locate first node on grid below a given value.
%
% PURPOSE
% Locate values on a grid and compute each value's relative proximity
% to its two adjacent grid nodes
%
% INPUT
% x (values) N * 1 strictly increasing vector, numeric values
% to be located on grid
% X (gridpoints) M * 1 (or 1 * M) strictly increasing vector, grid
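% NameValueArgs logical name-value pairs: 'locateBelow' and
% 'locateAbove' (extrapolate weights for values below X(1) or
% above X(end) instead of clamping them), 'binarySearch' (use
% binary rather than incremental search); all default to false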
%
% OUTPUT
% indices array sized as x, indices of first node from X below
% each element of x
% weights array sized [2, size(x)], relative proximity to
% adjacent nodes from X for each element of x
%
% EXAMPLES
% [indices, weights] = locateVector(x, X) returns, for each element of x,
% the index of the last node in X that is at or below that x (or 1 if x
% is at or below X(1), or numel(X) - 1 if x is at or above X(end)), as
% well as the relative proximities to that node of X and the node above
% it (unless x is at or below X(1), in which case the weights are [1, 0],
% or x is at or above X(end), in which case the weights are [0, 1]).
%
% M-FILES required: none
%
% MAT-FILES required: none
arguments
values (:, 1) {mustBeNumeric, mustBeIncreasing}
gridpoints (:, 1) {mustBeNumeric, mustBeIncreasing}
NameValueArgs.locateBelow (1, 1) {mustBeLogical} = false
NameValueArgs.locateAbove (1, 1) {mustBeLogical} = false
NameValueArgs.binarySearch (1, 1) {mustBeLogical} = false
end
% Preallocate
indices = ... Indices of first node below (or 1 if no nodes below)
ones(length(values), 1);
weights = ... Relative proximity of the two neighboring nodes
nan(2, length(values));
low = 1;
high = length(gridpoints) - 1;
if NameValueArgs.binarySearch
searchFunction = @binarySearch;
else
searchFunction = @incrementalSearch;
end
for ix = 1:numel(values)
if values(ix) <= gridpoints(1)
% indices(ix) = 1; % No need as indices are initialized to 1
if ~NameValueArgs.locateBelow
weights(:, ix) = [1; 0];
end
elseif values(ix) >= gridpoints(end)
indices(ix:end) = high;
if ~NameValueArgs.locateAbove
weights(:, ix:end) = repmat( ...
[0; 1], ...
1, length(values) - ix + 1 ...
);
end
break;
elseif low == high % gridpoints(end - 1) <= values(ix) < gridpoints(end)
indices(ix) = high;
else
indices(ix) = ...
searchFunction(low, high, values(ix), gridpoints);
low = indices(ix);
end
end
noWeights = isnan(weights(1, :));
weights(:, noWeights) = ( ...
[ ...
gridpoints(indices(noWeights) + 1) - values(noWeights), ...
values(noWeights) - gridpoints(indices(noWeights)) ...
] ./ ( ...
gridpoints(indices(noWeights) + 1) ...
- gridpoints(indices(noWeights)) ...
))';
end
function low = incrementalSearch(low, ~, value, gridpoints)
while value >= gridpoints(low + 1)
low = low + 1;
end
end
function low = binarySearch(low, high, value, gridpoints)
low = low - 1;
high = high + 1;
while low < high - 1
mid = floor((low + high) / 2);
if value < gridpoints(mid)
high = mid;
else
low = mid;
end
end
end
function mustBeIncreasing(a)
%MUSTBEINCREASING Validate that value is an increasing vector or issue error
%
% MUSTBEINCREASING(a) issues an error if a is not a strictly increasing vector.
if ~isvector(a) || isscalar(a) || ~all(diff(a) > 0)
throwAsCaller(MException( ...
'Increasing:notIncreasing', ...
'Input must be a strictly increasing vector.' ...
));
end
end
function mustBeLogical(a)
%MUSTBELOGICAL Validate that value is of class logical or issue error
%
% MUSTBELOGICAL(a) issues an error if a isn't a logical.
if ~islogical(a)
throwAsCaller(MException( ...
'Logical:notLogical', ...
'Input must be logical.' ...
));
end
end
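One small, untested idea within the same approach: locateVector calls the search through a function handle once per element, and function-handle calls in MATLAB can carry noticeable overhead. The sketch below inlines the incremental (merge-style) scan into the loop and computes each weight pair as it goes. It is hypothetical, not the posted code: it implements only the default clamping (no locateBelow, locateAbove or binarySearch options), assumes sorted values and a strictly increasing grid with at least two points, and reuses the test data from the timings above. Whether it is actually faster would have to be checked with timeit.

```matlab
% Hypothetical sketch (not the posted locateVector): merge-style scan with
% the interval search inlined. Assumes sorted values and a strictly
% increasing grid with at least two points; default clamping only.
values = sort(rand(100, 1));
gridpoints = linspace(0.1, 0.9, 60)';

n = numel(values);
m = numel(gridpoints);
indices = ones(n, 1);    % left node of each value's interval
weights = zeros(2, n);   % weights on gridpoints(indices) and gridpoints(indices + 1)
k = 1;                   % current interval; only ever moves forward

for ix = 1:n
    v = values(ix);
    if v <= gridpoints(1)
        weights(:, ix) = [1; 0];              % at/below the grid: clamp to node 1
    elseif v >= gridpoints(m)
        indices(ix:n) = m - 1;                % all remaining values are at/above the grid
        weights(1, ix:n) = 0;
        weights(2, ix:n) = 1;
        break
    else
        % advance until gridpoints(k) <= v < gridpoints(k + 1)
        while v >= gridpoints(k + 1)
            k = k + 1;
        end
        indices(ix) = k;
        h = gridpoints(k + 1) - gridpoints(k);
        weights(:, ix) = [gridpoints(k + 1) - v; v - gridpoints(k)] / h;
    end
end
```

Because the values are sorted, `k` never moves backwards, so the whole scan visits each grid point at most once.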
2 Comments
John D'Errico, 22 March 2022
timeit(@() locateVector(values, gridpoints))
Unrecognized function or variable 'mustBeLogical'.
Fredrik P, 22 March 2022 (edited 22 March 2022)
Sorry about the missing code! (Fixed now in my edit.) I forgot that mustBeLogical.m was something that I wrote myself, even though the documentation (also written by me) explicitly stated that it was...


Answers (1)

John D'Errico, 22 March 2022 (edited 22 March 2022)
I cannot compare your code to anything meaningful because your code is not complete. It lacks mustBeLogical, and possibly other tools.
But I see that on my computer, using an existing tool like discretize to do the work takes roughly 1/100 of the time your code does. Yes, my computer may be faster or slower than yours. But not by a factor of 100.
timeit(@() discretize(values, gridpoints))
ans =
1.1458e-06
Is there a good reason why you want to use SLOW code like that? Why would you not want to use an existing tool to solve the problem? Yes, your code produced weights. But that is a trivial one-line computation once you know where the points lie.
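To make the "one line" concrete, here is a sketch, assuming the same test data as in the question, that uses discretize for the interval lookup, clamps out-of-range values the way locateVector does by default, and then computes all weights in a single vectorized expression. The variable names `x`, `lo` and `hi` are illustrative.

```matlab
% Sketch: interval lookup with discretize, then all weights in one expression.
values = sort(rand(100, 1));
gridpoints = linspace(0.1, 0.9, 60)';
m = numel(gridpoints);

% discretize returns NaN outside [gridpoints(1), gridpoints(end)]; clamp those
% to the first or last interval, as locateVector does by default.
indices = discretize(values, gridpoints);
indices(values < gridpoints(1)) = 1;
indices(values >= gridpoints(end)) = m - 1;

% Clamping the values too makes out-of-range weights come out as [1; 0] or [0; 1].
x = min(max(values, gridpoints(1)), gridpoints(end));
lo = gridpoints(indices);
hi = gridpoints(indices + 1);
weights = [(hi - x) ./ (hi - lo), (x - lo) ./ (hi - lo)]';   % 2-by-N, like locateVector
```

Applying the same expression to the unclamped `values` instead of `x` would give the extrapolated weights that the locateBelow/locateAbove options request.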
Edit:
I see you added the mustBeLogical code. So now I can run it as a comparison.
timeit(@() locateVector(values, gridpoints))
ans =
6.5403e-05
timeit(@() discretize(values, gridpoints))
ans =
1.0422e-06
God, your computer is really slow. Anyway, locateVector is a little worse than 60x slower than discretize.
You still have not said why you think you need to use that code instead of something better that is already part of MATLAB. Computing linear interpolation weights is a one-line expression. So, to answer whether there is something you can do to improve performance: yes. Don't use that code.
1 Comment
Fredrik P, 22 March 2022 (edited 22 March 2022)
No, there isn't any particular reason to use slow code :-D My faster version actually uses histc, which, at least on my machines, performs better than discretize for locating the indices. Here, I wanted to see whether I could get more performance by exploiting the fact that both vectors are monotonically increasing.
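The histc-based version mentioned here is not posted in the thread, so the following is only a guess at what such a lookup could look like, not the author's code. It relies on the second output of histc (the bin index of each element, 0 for elements outside the edges). MathWorks documents histc as not recommended, suggesting histcounts instead. The test data are assumed to be the same as in the question.

```matlab
% Hypothetical histc-based locator (not the author's unposted code).
values = sort(rand(100, 1));
gridpoints = linspace(0.1, 0.9, 60)';
m = numel(gridpoints);

% Second output of histc: k with gridpoints(k) <= v < gridpoints(k + 1),
% m for v == gridpoints(end), and 0 for values outside the grid.
[~, indices] = histc(values, gridpoints);
indices(values < gridpoints(1)) = 1;         % below the grid: first interval
indices(values >= gridpoints(end)) = m - 1;  % at/above the grid: last interval
```

The weights would then follow from the same one-line expression as with discretize.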
Below are the timings, including my implementation of what you suggested (further below). I have left the arguments block unchanged so that it does not distort the comparison.
[timeit(@() locateVector(values, gridpoints)), timeit(@() locateVector(values, gridpoints, "binarySearch", true)), timeit(@() locate(values, gridpoints))]
% ans =
%
% 1.0e-03 *
%
% 0.0922 0.1345 0.0358
(My faster version clocked in at 1.0e-03 * 0.0287, i.e. about 28.7 µs.)
function [indices, weights] = locate(values, gridpoints, NameValueArgs)
arguments
values (:, 1) {mustBeNumeric, mustBeIncreasing}
gridpoints (:, 1) {mustBeNumeric, mustBeIncreasing}
NameValueArgs.locateBelow (1, 1) {mustBeLogical} = false
NameValueArgs.locateAbove (1, 1) {mustBeLogical} = false
NameValueArgs.binarySearch (1, 1) {mustBeLogical} = false
end
% Preallocate
weights = ... Relative proximity of the two neighboring nodes
nan(2, length(values));
indices = discretize(values, gridpoints);
below = values < gridpoints(1);
indices(below) = 1;
above = values >= gridpoints(end);
indices(above) = length(gridpoints) - 1;
if ~NameValueArgs.locateBelow
weights(:, below) = repmat([1; 0], 1, sum(below));
end
if ~NameValueArgs.locateAbove
weights(:, above) = repmat([0; 1], 1, sum(above));
end
noWeights = isnan(weights(1, :));
weights(:, noWeights) = ( ...
[ ...
gridpoints(indices(noWeights) + 1) - values(noWeights), ...
values(noWeights) - gridpoints(indices(noWeights)) ...
] ./ ( ...
gridpoints(indices(noWeights) + 1) ...
- gridpoints(indices(noWeights)) ...
))';
end
function mustBeIncreasing(a)
%MUSTBEINCREASING Validate that value is an increasing vector or issue error
%
% MUSTBEINCREASING(a) issues an error if a is not a strictly increasing vector.
if ~isvector(a) || isscalar(a) || ~all(diff(a) > 0)
throwAsCaller(MException( ...
'Increasing:notIncreasing', ...
'Input must be a strictly increasing vector.' ...
));
end
end
function mustBeLogical(a)
%MUSTBELOGICAL Validate that value is of class logical or issue error
%
% MUSTBELOGICAL(a) issues an error if a isn't a logical.
if ~islogical(a)
throwAsCaller(MException( ...
'Logical:notLogical', ...
'Input must be logical.' ...
));
end
end
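For completeness, one way to use the monotonicity of both inputs without any loop is to merge them: the interval index of each value is just the number of grid points at or below it. This is a swapped-in technique, not something proposed in the thread. The sketch concatenates the two vectors, relies on sort being stable (equal elements keep their original order) so that a grid point equal to a value sorts before it, and reads the counts off with cumsum. Because `values` is already sorted, the counts come out in the original order of `values` and need no permutation back. It still performs a full O((N+M) log(N+M)) sort, so it is not obviously faster than histc or discretize and would have to be timed.

```matlab
% Sketch (swapped-in technique, not from the thread): locate sorted values on a
% sorted grid by merging the two vectors and counting grid points with cumsum.
values = sort(rand(100, 1));
gridpoints = linspace(0.1, 0.9, 60)';
m = numel(gridpoints);

% Grid points go first, so on exact ties the stable sort places a grid point
% before an equal value and it is counted as "at or below" that value.
[~, order] = sort([gridpoints; values]);
isGrid = order <= m;
nGridSoFar = cumsum(isGrid);     % grid points up to each merged position
counts = nGridSoFar(~isGrid);    % number of gridpoints <= values(ix), in the order of values

% counts(ix) = k means gridpoints(k) <= values(ix) < gridpoints(k + 1);
% clamp to the valid interval range 1..m-1 as locateVector does.
indices = min(max(counts, 1), m - 1);
```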


Category: Matrices and Arrays

Release: R2021b
