How to use fmincon to fit parameters to observed data?

Jacob Elder, 19 July 2019
Commented: Jacob Elder, 20 July 2019
Hello,
I am trying to fit the following function to observed output data (y):
```matlab
function model = TDFit(param, input)
% TDFit  Value estimate on each trial under linear TD learning.
%   param = [alpha g]: learning rate and discount factor
%   input = [stim r]:  stimulus index (1-5) and reward on each trial
alpha = param(1);
g     = param(2);
stim  = input(:,1);
r     = input(:,2);

% one-hot encode the stimuli; rows for values outside 1-5 stay all-zero
% (preallocating keeps X the same length as stim)
X = zeros(length(stim), 5);
for j = 1:length(stim)
    if any(stim(j) == 1:5)
        X(j, stim(j)) = 1;
    end
end

% initialization
[N, D] = size(X);
w = zeros(D, 1);        % weights
X = [X; zeros(1, D)];   % add buffer row so X(n+1,:) exists on the last trial
model = zeros(N, 1);    % preallocate output (one value estimate per trial)

% run TD learning
for n = 1:N
    h    = X(n,:) - g*X(n+1,:);  % temporal difference features
    V    = X(n,:)*w;             % value estimate
    rhat = h*w;                  % reward prediction
    dt   = r(n) - rhat;          % prediction error
    w    = w + alpha*dt*h';      % weight update
    model(n) = V;
end
end
```
Previously, I have been using lsqcurvefit to fit the parameters, given xdata (input data) and ydata (observed output data):
```matlab
% import dataframe
TempDf = importfile('fullDf.csv');
% get unique subject IDs
uIds = unique(TempDf.subID);
% get unique cluster numbers
uClust = unique(TempDf.clustType);
% remove rows with missing self-evaluations: a NaN propagates to all subsequent trials
TempDf(any(ismissing(TempDf.selfRespT1), 2), :) = [];
% initialize struct array for fitted parameters (avoids reusing a stale paramDf)
paramDf = struct('subID', {}, 'alpha', {}, 'discount', {}, 'resnorm', {});
% initialize data structure for TD output
fullOutput = struct('N', {}, 'trialInClust', {}, 'trialNum', {}, 'w1', {}, 'w2', {}, 'w3', {}, 'w4', {}, 'w5', {}, 'dt', {}, 'rhat', {}, 'V', {}, 'subID', {});
% loop through all subjects
for k = 1:length(uIds)
    % ID of current subject
    h = uIds(k);
    % subset dataframe for current subject
    subDf = TempDf(TempDf.subID == h, :);

    % ### Parameter fitting for current subject ### %
    % column array for states/predictions (e.g. self-evaluations)
    selfEval = table2array(subDf(:, {'selfRespT1'}));
    % column array for rewards (e.g. feedback)
    rewards = table2array(subDf(:, {'feedback'}));
    % column array for cluster cues (probability)
    stimulus = table2array(subDf(:, {'clustType'}));
    % data to fit parameters to
    xData = [stimulus rewards];
    yData = selfEval;
    % function handle for parameter fitting
    fun = @(param, xData) TDFit(param, xData);
    % starting values for the learning rate (alpha) and discount factor, respectively
    x0 = [0.35 0.45];
    % fit parameters, with both bounded to [0, 1]
    [x, resnorm, residual] = lsqcurvefit(fun, x0, xData, yData, [0 0], [1 1]);
    % append current subject's parameters
    paramDf(k) = struct('subID', h, 'alpha', x(1), 'discount', x(2), 'resnorm', resnorm);
    % ### Finished parameter fitting for current subject ### %

    % subset only the columns relevant to the TD function
    subDf = subDf(:, [5, 8, 14]);
    % apply TD function to relevant columns:
    % trial total; self evaluation; feedback; learning rate; discount
    TD_out = TD(subDf{:,1}, subDf{:,3}, subDf{:,2}, x(1), x(2));
    % repeat subID once per row of output
    C = num2cell(repelem(h, height(subDf)));
    % add subID field to the TD output struct array
    [TD_out.subID] = C{:};
    % append TD output to other subjects' data
    fullOutput = [fullOutput, TD_out];
end
% remove 'N' (trials in cluster); only relevant for the TD script looping
% through clusters within participants
%fullOutput = rmfield(fullOutput, 'N');
% write to csv
writetable(struct2table(paramDf), 'parameters.csv', 'Delimiter', ',', 'QuoteStrings', true);
writetable(struct2table(fullOutput), 'TD_outputPtp.csv', 'Delimiter', ',', 'QuoteStrings', true);
```
However, my understanding is that for reinforcement learning models you usually search for the optimal parameters by minimizing an objective function, typically the negative log likelihood of the observed data. So I have read that I should be using fmincon or fminsearch instead. But if I use those, it is less clear to me how to fit the function to the observed output data: lsqcurvefit takes the observed output data as an argument, whereas fmincon only takes the objective function, a starting point x0, constraints, and bounds.
How can I do essentially what I am doing above, but use fmincon to fit the parameters to the observed data (y output data)?
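A minimal sketch of a negative-log-likelihood objective, assuming (this is an assumption, not something stated in the question) that each observed self-evaluation equals the model's value estimate plus independent Gaussian noise with unknown standard deviation sigma. Under that assumption, -log L(alpha, g, sigma) = N/2 * log(2*pi*sigma^2) + sum((y - V).^2) / (2*sigma^2), so for any fixed sigma, minimizing it over alpha and g is the same as minimizing the sum of squared residuals that lsqcurvefit already minimizes. The function name gaussianNLL is hypothetical.

```matlab
function nll = gaussianNLL(params, modelFun, xData, yData)
% gaussianNLL  Negative log likelihood of yData under an ASSUMED
%   Gaussian noise model: yData = modelFun(theta, xData) + N(0, sigma^2).
%   params = [theta, sigma]; the last element is the noise SD.
theta = params(1:end-1);
sigma = params(end);
if sigma <= 0
    nll = Inf;            % invalid noise SD: reject this point
    return;
end
resid = yData - modelFun(theta, xData);   % observed minus predicted
N = numel(yData);
nll = N/2*log(2*pi*sigma^2) + sum(resid.^2)/(2*sigma^2);
end
```

The observed data reach the solver through an anonymous function, for example `nllFun = @(p) gaussianNLL(p, @TDFit, xData, yData);` followed by `pHat = fmincon(nllFun, [0.35 0.45 1], [], [], [], [], [0 0 1e-6], [1 1 Inf]);` (fmincon requires Optimization Toolbox).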

Accepted Answer

Walter Roberson, 19 July 2019
For this kind of fitting, lsqcurvefit() is typically much faster. However, for nonlinear models it is common for lsqcurvefit() to miss the basin of attraction of the best fit and still be entirely satisfied with the result it returns. (If you are looking for the reddest bird in the yard and there are robins in plain sight, you might be 99% confident that you have found the reddest bird, having entirely failed to look around the corner at the small niche where the cardinals nest.)
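One common way to reduce (not eliminate) this risk is to restart the local solver from several starting points and keep the best result. A minimal sketch, assuming finite bounds so that random starts can be drawn inside them; multiStartLsq is a hypothetical helper name, and lsqcurvefit requires Optimization Toolbox. (Global Optimization Toolbox also provides a MultiStart object for this purpose.)

```matlab
function [pBest, resnormBest] = multiStartLsq(modelFun, xData, yData, lb, ub, nStarts)
% multiStartLsq  Run lsqcurvefit from nStarts random starting points
%   drawn uniformly inside [lb, ub] and keep the fit with the smallest
%   resnorm (sum of squared residuals). lb and ub must be finite.
opts = optimoptions('lsqcurvefit', 'Display', 'off');
resnormBest = Inf;
pBest = [];
for k = 1:nStarts
    x0 = lb + rand(size(lb)).*(ub - lb);   % uniform random start in the box
    [p, resnorm] = lsqcurvefit(modelFun, x0, xData, yData, lb, ub, opts);
    if resnorm < resnormBest               % keep the best local solution
        resnormBest = resnorm;
        pBest = p;
    end
end
end
```

With the question's model: `[p, rn] = multiStartLsq(@TDFit, xData, yData, [0 0], [1 1], 20);`. If the best resnorm is reached from many different starts, that is some (not conclusive) evidence that the basin found is the best one.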
When lsqcurvefit() returns a solution far worse than the best achievable one, it is because the objective has multiple local minima. And that is a problem for fmincon() too, as fmincon() gets stuck in local minima: fmincon() is not a global optimizer. fmincon() has an advantage when there are constraints that rule out the false minima; in some useful cases, fmincon() can then find the correct minimum.
fminsearch() is a non-local minimizer: it gets stuck in local minima less often than fmincon() does. It is not really a global minimizer, though: it can get stuck in local minima too. And fminsearch() handles constraints poorly: it has no bounds argument, and the most you can do, by hacking around a bit, is to make fminsearch() stop searching when it reaches a boundary, with no ability to bounce off a constraint.
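A different, commonly used workaround (not the boundary hack described above) is to let fminsearch search over unbounded variables and map them into the allowed range. A minimal sketch, assuming both parameters must lie strictly between 0 and 1, as with the [0 0]/[1 1] bounds in the question; the logistic transform and the helper name fitBoundedFminsearch are assumptions of this sketch, not something from this thread. The transformed search can approach, but never exactly reach, 0 or 1.

```matlab
function [pBest, fval] = fitBoundedFminsearch(objective, x0)
% fitBoundedFminsearch  Minimize objective(p) over the open box (0,1)^n
%   with fminsearch (base MATLAB), by searching over unbounded u and
%   mapping p = 1 ./ (1 + exp(-u)). x0 must lie strictly inside (0,1).
toBox   = @(u) 1 ./ (1 + exp(-u));   % R^n -> (0,1)^n (logistic)
fromBox = @(p) log(p ./ (1 - p));    % (0,1)^n -> R^n (logit, the inverse)

opts = optimset('Display', 'off');
% fminsearch only ever sees u; the objective is always evaluated in p-space
[uBest, fval] = fminsearch(@(u) objective(toBox(u)), fromBox(x0), opts);
pBest = toBox(uBest);
end
```

With the question's model and a sum-of-squares objective: `sse = @(p) sum((TDFit(p, xData) - yData).^2); pHat = fitBoundedFminsearch(sse, [0.35 0.45]);`.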
MATLAB does not have any algorithms that can guarantee global minima (it can be proved that no universal global minimizer exists): MATLAB just has algorithms that use strategies to keep searching.
ga() does not get stuck in local minima in the same way (although you can construct surfaces that it does not do well on).
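A minimal sketch of calling ga() with bounds, assuming Global Optimization Toolbox is installed; fitWithGa is a hypothetical wrapper. ga is stochastic, so results vary between runs unless the random seed is fixed with rng.

```matlab
function [pBest, fval] = fitWithGa(objective, lb, ub)
% fitWithGa  Minimize objective(p) subject to lb <= p <= ub with the
%   genetic algorithm (Global Optimization Toolbox).
nvars = numel(lb);                          % number of parameters
opts = optimoptions('ga', 'Display', 'off');
% no linear or nonlinear constraints, only bounds
[pBest, fval] = ga(objective, nvars, [], [], [], [], lb, ub, [], opts);
end
```

With the question's model: `sse = @(p) sum((TDFit(p, xData) - yData).^2); pHat = fitWithGa(sse, [0 0], [1 1]);`. Setting the ga option 'HybridFcn' to @fmincon refines the best point found by ga with a local search afterwards.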
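In each case you would parameterize the objective function, capturing the data to be compared against (for example in an anonymous function), because these solvers only pass the parameter vector to the objective.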
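A minimal sketch of that parameterization for fmincon (Optimization Toolbox), assuming a sum-of-squared-errors objective: the observed yData are captured by an anonymous function, so fmincon only ever sees a function of the parameter vector. The wrapper name fitWithFmincon is hypothetical.

```matlab
function [pBest, sse] = fitWithFmincon(modelFun, x0, xData, yData, lb, ub)
% fitWithFmincon  Fit model parameters to observed data with fmincon.
%   modelFun(p, xData) must return predictions the same size as yData.
%   The data are captured inside the anonymous objective below.
objective = @(p) sum((modelFun(p, xData) - yData).^2);  % sum of squared errors

opts = optimoptions('fmincon', 'Display', 'off');
% no linear constraints (A, b, Aeq, beq) or nonlinear constraints; bounds only
[pBest, sse] = fmincon(objective, x0, [], [], [], [], lb, ub, [], opts);
end
```

With the question's model: `[p, sse] = fitWithFmincon(@TDFit, [0.35 0.45], xData, yData, [0 0], [1 1]);`. This minimizes the same sum of squared residuals as the lsqcurvefit call; to fit by likelihood instead, replace the objective with a negative log likelihood captured the same way.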
5 comments (two shown here)
Walter Roberson, 20 July 2019
I have never worked with log likelihood, sorry.
Jacob Elder, 20 July 2019
No problem. You were more than helpful. Thank you again.


Category: Solver Outputs and Iterative Display
