How to get Shapley values for a neural network trained in MATLAB? It keeps giving errors...
Hi there,
I want to get the Shapley values of my pre-trained ANN.
It is a regression model.
Its input is a 7×5120 double array
and its output is a 1×5120 double array.
I'm confused by the idea of Shapley values... sorry.
3 Comments
Ronit
26 Aug 2024
Moved: Angelo Yeo
26 Aug 2024
Hello,
In game theory, the Shapley value of a player is the average marginal contribution of that player in a cooperative game. That is, Shapley values allocate the total gain generated by a cooperative game fairly among the individual players. In the context of machine learning prediction, the Shapley value of a feature for a query point explains the contribution of that feature to the prediction (the response for regression, or the score of each class for classification) at the specified query point.
The Shapley value of a feature is the part of the deviation of the prediction at the query point from the average prediction that is attributable to that feature. For each query point, the Shapley values of all features sum to the total deviation of the prediction from the average.
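The two statements above can be written compactly as follows. The notation is standard but not taken from this thread: M is the set of all features (here the 7 inputs), f is the model, x is the query point, and v_x(S) is the average prediction when the features in S are fixed at their values in x and the remaining features are drawn from the background (training) data, which is one common choice of value function.

```latex
\begin{aligned}
\varphi_j(x) &= \sum_{S \subseteq M \setminus \{j\}}
  \frac{|S|!\,(|M| - |S| - 1)!}{|M|!}
  \bigl[ v_x(S \cup \{j\}) - v_x(S) \bigr],
  \qquad v_x(S) = \mathbb{E}\bigl[ f(x_S, X_{\bar{S}}) \bigr] \\
\sum_{j \in M} \varphi_j(x) &= v_x(M) - v_x(\emptyset) = f(x) - \mathbb{E}\bigl[ f(X) \bigr]
\end{aligned}
```

The first line is the weighted average of feature j's marginal contributions; the second is the "sum equals deviation from the average prediction" property. With 7 predictors, each feature's sum runs over 2^6 = 64 subsets. The count doubles with every added predictor, which is why practical implementations usually approximate it.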
For implementation details, please refer to the MATLAB documentation page "Shapley Values for Machine Learning Model".
Note: the shapley function in MATLAB is part of the Statistics and Machine Learning Toolbox.
I hope this helps with your query!
Angelo Yeo
26 Aug 2024
Can you be more specific about your model and the error message? It would be best if you could share your model (code and data) and the steps to reproduce the error.
한용
26 Aug 2024
Edited: Angelo Yeo
26 Aug 2024
Answers (1)
I do not have your model and dataset, so I used random samples. The key is to use yticklabels. Would this work for you?
clc;
clear;
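%% Compute Shapley values
% Demo neural network on random data (replace with your own pre-trained network).
% configure only matches the network to the data sizes; this demo network is not trained.
x = randn(7, 150);   % 7 predictors x 150 observations (one observation per column)
t = randn(1, 150);   % 1 response x 150 observations
net = fitnet(10);
net = configure(net,x,t);
% view(net)
f = @(x) net(x')'; % Wrap the network as a function: shapley passes one observation per row, the network expects one per column
x_veri_shapley = x(:,101:end)'; % Transpose so that each row is one sample
x_train_shapley = x(:, 1:100)'; % Transpose so that each row is one sample (for 7x5120 inputs this gives a 5120x7 matrix)
% Sampling example
num_samples = size(x_veri_shapley,1); % Number of query points to draw (here: all 50)
idx = randperm(size(x_veri_shapley, 1), num_samples);
x_veri_shapley_sampled = x_veri_shapley(idx, :);
% Set 'UseParallel' to true to enable parallel computation (requires Parallel Computing Toolbox)
explainer = shapley(f, x_train_shapley, 'QueryPoints', x_veri_shapley_sampled,'UseParallel', false);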
%%
% plot(explainer)
%%
% Copy the MeanAbsoluteShapley table
shapley_table = explainer.MeanAbsoluteShapley;
% Rename the predictors (TeX markup such as _{sig} is rendered as a subscript)
desired_variable_names = ["PGA", "Dur_{sig}", "Sa_{max}", "Tm", "CAV_{max}", "Arias_{max}", "f_{1}"];
shapley_table.Predictor = desired_variable_names(:); % Replace with the new predictor names
% Sort Shapley values and names in ascending order: barh draws the first bar at the bottom,
% so the largest value ends up at the top
[sorted_values, sort_index] = sort(shapley_table.ShapleyValue, 'ascend');
sorted_names = shapley_table.Predictor(sort_index);
% Horizontal bar chart (largest value at the top)
% figure;
% barh(sorted_values);
% set(gca, 'YTickLabel', sorted_names);
% xlabel('Mean absolute Shapley value');
% ylabel('Predictor');
% title('Shapley importance plot');
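%%
close all;
figure(10);
plot(explainer,QueryPointIndices=30);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
% A single query point is ordered by its own Shapley values, not by the mean order in sorted_names,
% so map the default tick labels back to predictor indices instead
% (assumes shapley's default names x1, ..., x7 for a numeric matrix X)
predIdx = str2double(extractAfter(string(hAxes.YTickLabel), "x"));
yticklabels(hAxes, desired_variable_names(predIdx)) % use "yticklabels" to change the YTickLabels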
figure(11);
plot(explainer);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
yticklabels(hAxes, sorted_names) % use "yticklabels" to change the YTickLabels
figure(12);
swarmchart(explainer);
hAxes = gca;
hAxes.TickLabelInterpreter = "tex";
yticklabels(hAxes, sorted_names) % use "yticklabels" to change the YTickLabels