Applying SHAP to a Reinforcement Learning Algorithm

I am trying to apply SHAP to a reinforcement learning algorithm, and I am not sure whether MATLAB has the required SHAP functionality, such as shap.DeepExplainer(), which is part of a Python package.
If anyone has further information on how to apply SHAP to the neural network agent of a reinforcement learning model, please let me know.

6 Comments

Ive J on 5 Jun 2023
Have you looked at the documentation?
doc shapley
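To see what shapley expects before an agent is involved, here is a minimal sketch with a made-up linear black box and random background data (X, w, f and queryPoint are hypothetical, chosen only for illustration). The black box is a function handle that takes a matrix of observations and returns one output per row. For a linear function and the default interventional method, each Shapley value should come out close to w(j) times the difference between the query value and the mean of predictor j over X, which makes the result easy to check by hand.

```matlab
% Sketch: Shapley values for a black-box function handle (hypothetical data).
rng(0)                                  % reproducible background data
X = rand(200, 3);                       % 200 background observations, 3 predictors
w = [2; -1; 0.5];                       % weights of a made-up linear black box
f = @(x) x*w;                           % vectorized: one prediction per row
queryPoint = [0.9 0.1 0.5];             % the observation to explain
explainer = shapley(f, X, 'QueryPoint', queryPoint);
disp(explainer.ShapleyValues)           % one row per predictor
```

With an agent, f would be replaced by a function handle that wraps the actor network's prediction.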
Juliana Tavora on 23 Nov 2023
I have been wondering the same, @Mahsa Raeisinezhad. Have you found any further information? The documentation on shapley is not very clear and never mentions SHAP.
Mahsa Raeisinezhad on 25 Nov 2023 (edited)
% Assumes: shap is an n-by-d matrix of Shapley values (one row per query point),
% tbl is a table holding the same n query points (d predictor columns).
% Order predictors by mean absolute Shapley value (most important on top).
[~, sortedPredictorIndicesMEAN] = sort(mean(abs(shap), 1));
figure;
for p = 1:d
    scatter(shap(:,sortedPredictorIndicesMEAN(p)), ... % x-value of each point is the Shapley value
        p*ones(n,1), ... % y-value of each point is an integer corresponding to a predictor (jittered below)
        [], ... % marker size for each data point, taking the default here
        normalize(table2array(tbl(1:n,sortedPredictorIndicesMEAN(p))),'range',[1 256]), ... % colors based on feature values
        'filled', ... % fill the circles representing data points
        'YJitter','density', ... % jitter in y according to the density of points in this row
        'YJitterWidth',0.8)
    if p == 1
        hold on;
    end
end
title('Shapley Summary Plot');
xlabel('Shapley Value (impact on model output)')
yticks(1:d);
yticklabels(tbl.Properties.VariableNames(sortedPredictorIndicesMEAN));
% Set colormap as desired
colormap(CoolBlueToWarmRedColormap); % similar to the colormap used in many Shapley summary plots
% colormap(parula); % this is the default colormap
cb = colorbar('Ticks', [1 256], 'TickLabels', {'Low', 'High'});
cb.Label.String = "Scaled Feature Value";
cb.Label.FontSize = 12;
cb.Label.Rotation = 270;
set(gca, 'YGrid', 'on');
xline(0, 'LineWidth', 1);
hold off;
%%
function cmap = CoolBlueToWarmRedColormap()
% Define start point, middle luminance, and end point in L*ch colorspace
% https://www.mathworks.com/help/images/device-independent-color-spaces.html
% The three components of L*ch are luminance, chroma, and hue (hue in radians here).
blue_lch = [54 70 4.6588];     % starting blue point
l_mid = 40;                    % luminance of the midpoint
red_lch = [54 90 6.6378909];   % ending red point
nsteps = 256;
% Build matrix of L*ch colors that is nsteps x 3 in size.
% Luminance changes linearly from start to middle, and middle to end.
% Chroma and hue change linearly from start to end.
lch = [[linspace(blue_lch(1), l_mid, nsteps/2), linspace(l_mid, red_lch(1), nsteps/2)]', ... % luminance column
       linspace(blue_lch(2), red_lch(2), nsteps)', ... % chroma column
       linspace(blue_lch(3), red_lch(3), nsteps)'];    % hue column
% Convert L*ch to L*a*b, where a = c * cos(h) and b = c * sin(h)
lab = [lch(:,1) lch(:,2).*cos(lch(:,3)) lch(:,2).*sin(lch(:,3))];
% Convert L*a*b to RGB
cmap = lab2rgb(lab, 'OutputType', 'uint8');
end
Mahsa Raeisinezhad on 25 Nov 2023
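% Black-box function for shapley: maps a table of observations (one row per
% observation) to a column vector with one output per row.
% Use the actor of the trained agent, not a newly created (untrained) agent.
actorNet = getModel(getActor(pretrainedAgent));
myAct = @(tbl) predict_01(actorNet, tbl);
% Assumes tbl is a table of n observations collected from env (d features)
n = height(tbl);      % number of query points (decisions) to explain
d = width(tbl);       % number of observation features
shap = zeros(n,d);
explainer_ = cell(n,1);
explainer = shapley(myAct, tbl);   % tbl also serves as the background data
for i = 1:n
    explainer = fit(explainer, tbl(i,:));        % explain the i-th decision
    shap(i,:) = explainer.ShapleyValues{:,2}';   % Shapley value column
    % plot(explainer)                            % optional: bar chart for this decision
    explainer_{i} = explainer;
end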
Mahsa Raeisinezhad on 25 Nov 2023
function myAct_ = predict_01(actorNet, tbl)
% Returns, for each observation (row) of tbl, the largest output of the actor
% network (for a discrete-action PPO agent, the probability of the most likely
% action) as a column vector, which is the shape shapley expects.
% Assumes the actor network has a single feature input layer.
dataArray = table2array(tbl);                      % m-by-d numeric array
dlObservations = dlarray(single(dataArray'), 'CB'); % d-by-m: channel x batch
% Predict using the actor network
myAct = predict(actorNet, dlObservations);         % numActions-by-m
myAct_ = double(max(extractdata(myAct), [], 1))';  % m-by-1 column vector
end
Mahsa Raeisinezhad on 25 Nov 2023
The above is how I tried to use SHAP in MATLAB: I created a function handle that returns my neural network (agent) predictions for observations from the environment and applied SHAP to each decision. However, I still highly recommend transferring to Python and using the Python shap package.
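One way to sanity-check per-decision explanations like the ones above is the efficiency property: the Shapley values for a query point should add up to the black-box output at that point minus the average output over the background data. The sketch below checks this with a made-up nonlinear function standing in for the agent's output (f, X and q are hypothetical); with an agent, f would be the function handle passed to shapley.

```matlab
% Sketch: efficiency check for one explained decision (hypothetical black box).
rng(0)
X = rand(200, 4);                                        % made-up background data
f = @(x) max([x(:,1).*x(:,2), x(:,3) - x(:,4)], [], 2);  % made-up nonlinear black box
q = X(1,:);                                              % decision to explain
explainer = shapley(f, X, 'QueryPoint', q);
total = sum(explainer.ShapleyValues{:,2});               % sum of Shapley values
gap = f(q) - mean(f(X));                                 % output minus average output
fprintf('sum = %.6f, f(q) - mean(f(X)) = %.6f\n', total, gap)
```

A large mismatch usually points to a black box that does not return one value per input row.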


Answers (2)

Mahsa Raeisinezhad on 23 Nov 2023


I decided to transfer everything to Python and use Python packages. I used ONNX and TensorFlow to move the model over. If I have time in the future, I hope to write my own code to reproduce the same results in MATLAB.
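For the ONNX route, here is a minimal sketch of the MATLAB side, assuming the Deep Learning Toolbox Converter for ONNX Model Format support package is installed. A small hypothetical network stands in for the trained actor; with a real agent, the network would come from getModel(getActor(agent)). Loading the file in Python (for example, converting it for TensorFlow) is not shown.

```matlab
% Sketch: export a network to ONNX so it can be loaded from Python.
% The layers below are a hypothetical stand-in for a trained actor network;
% with a real agent use: actorNet = getModel(getActor(pretrainedAgent));
layers = [
    featureInputLayer(4)      % 4 observation features (hypothetical)
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(2)    % 2 discrete actions (hypothetical)
    softmaxLayer];
actorNet = dlnetwork(layers);
exportONNXNetwork(actorNet, "actor.onnx");   % requires the ONNX converter support package
```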

1 Comment

Drew on 5 Mar 2025
Starting in R2024a, MATLAB has new functionality that makes it easier to create Shapley summary plots. This is described in the release notes: https://www.mathworks.com/help/releases/R2024a/stats/release-notes.html.
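A rough sketch of what this could look like, based on our reading of those release notes (fit accepting several query points at once, and swarmchart drawing a summary plot for a shapley object); please verify both against the documentation for your release. The data X and the function f are hypothetical.

```matlab
% Sketch (R2024a or later): explain many decisions at once and draw a
% Shapley summary plot. X and f are hypothetical stand-ins.
rng(0)
X = rand(200, 3);                          % hypothetical observations (rows)
f = @(x) 2*x(:,1) - x(:,2) + 0.5*x(:,3);   % hypothetical vectorized black box
explainer = shapley(f, X);                 % X is the background data
explainer = fit(explainer, X(1:50,:));     % 50 query points in one call
swarmchart(explainer)                      % summary (swarm) plot
```

Compared with the manual scatter-based summary plot in the thread, this avoids the loop over query points.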
For an example on MATLAB Answers, see:


Ive J on 23 Nov 2023


The documentation and references are clear enough if you know what you are trying to achieve. Follow this example if you want to stay in MATLAB.

Category: Reinforcement Learning Toolbox

Asked: 5 Jun 2023

Commented: 5 Mar 2025
