How can I obtain Shapley values from a convolutional neural network?

Alok on 24 Oct 2024
Commented: Alok on 25 Oct 2024
clear all;
close all;
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
[XTest,YTest,anglesTest] = digitTest4DArrayData;
% Define the layers of the CNN
layers = [
imageInputLayer([28 28 1]) % Grayscale 28x28 images
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10) % Assuming 10 classes
softmaxLayer
classificationLayer];
% Define the training options
options = trainingOptions('adam', ...
'MaxEpochs',10, ...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');
% Train the CNN on the training data
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% Predict on the test data
YPred = classify(net, XTest);
accuracy = mean(YPred == categorical(YTest));
% Display the confusion matrix
confusionchart(categorical(YTest), YPred);
explainer = shapley( ...
@(XTest)PredictCNN(net,XTest,YTest(1)), ...
reshape(XTest,[5000,28*28]), "QueryPoint", reshape(XTest(:,:,1,1),[1,28*28]) );
function score = PredictCNN(net,XTest,YTest)
    YPred = predict(net,XTest);
    score = YPred(:,double(YTest));
end

Accepted Answer

Rahul on 25 Oct 2024
Hi @Alok,
I understand that you are trying to get Shapley values from a convolutional neural network. Here are a few adjustments to your code that should let the Shapley values be computed correctly:
Flattening: The ‘shapley’ function expects a 2D predictor matrix ([NumImages, NumFeatures]), whereas CNNs take 4D input ([Height, Width, Channels, NumSamples]), so the 4D images need to be reshaped into a 2D matrix before computing the Shapley values:
XTestFlatSample = reshape(XTestSample, [28*28, numSamples])';
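This line reshapes ‘XTestSample’ (28x28x1x200 for these grayscale images) into ‘XTestFlatSample’, a 2D matrix of size [200, 28*28] with one image per row, which is the format ‘shapley’ expects. The transpose matters: because MATLAB stores arrays in column-major order, reshaping directly to [numSamples, 28*28] (as in reshape(XTest,[5000,28*28]) in the question) would mix pixels from different images within each row.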
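The effect of the transpose can be checked without any toolbox. The sketch below uses a small random array X4 as a hypothetical stand-in for XTestSample, flattens it the way the answer does, verifies that each row holds exactly one image and that the round trip back to 4D is lossless, and shows that reshaping directly to [N, 28*28] does not put one image per row.

```matlab
% Hypothetical stand-in for XTestSample: N random 28x28x1 "images"
N = 5;
X4 = rand(28, 28, 1, N);

% Flatten as in the answer: reshape to [784, N], then transpose -> one image per row
Xflat = reshape(X4, [28*28, N])';

% Row 3 contains exactly the pixels of image 3 (in column-major order)
rowMatches = isequal(Xflat(3,:), reshape(X4(:,:,1,3), 1, []));

% Undo the flattening, as PredictCNN does, and recover the original 4D array
Xback = reshape(Xflat', 28, 28, 1, N);
roundTripOK = isequal(Xback, X4);

% Reshaping straight to [N, 784] interleaves pixels from different images
Xwrong = reshape(X4, [N, 28*28]);
wrongMatches = isequal(Xwrong(3,:), reshape(X4(:,:,1,3), 1, []));

fprintf('row match: %d, round trip: %d, direct reshape match: %d\n', ...
    rowMatches, roundTripOK, wrongMatches);   % prints 1, 1, 0
```

The same reshape(X', 28, 28, 1, size(X,1)) call is what the custom prediction function uses to rebuild images from the rows that ‘shapley’ passes in.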
QueryPoint: The ‘QueryPoint’ must use the same flattened format as the predictor data. Here it is ‘XTestFlatSample(1,:)’, so the first flattened test image is the observation being explained.
Parallelization: If you have Parallel Computing Toolbox, set 'UseParallel' to true in the ‘shapley’ call to distribute the computation across multiple workers and speed it up.
explainer = shapley( ...
@(X)PredictCNN(net, X, YTestSample(1)), ...
XTestFlatSample, ...
"QueryPoint", XTestFlatSample(1,:), ...
'UseParallel', true);
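Hard-coding 'UseParallel', true assumes Parallel Computing Toolbox is installed. A hedged alternative is to compute the flag first. The sketch below assumes canUseParallelPool (introduced in R2020b) is available and falls back to false when it is not; the variable name useParallel is illustrative.

```matlab
% Decide whether to request parallel computation instead of hard-coding true.
% canUseParallelPool is guarded with exist() in case it is unavailable in this release.
if exist('canUseParallelPool') ~= 0
    useParallel = canUseParallelPool;
else
    useParallel = false;
end
fprintf('UseParallel will be set to %d\n', useParallel);

% Then pass the flag to shapley, for example:
%   explainer = shapley(@(X)PredictCNN(net, X, YTestSample(1)), XTestFlatSample, ...
%       "QueryPoint", XTestFlatSample(1,:), 'UseParallel', useParallel);
```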
Custom Prediction Function: The CNN expects 4D input, but ‘shapley’ passes a 2D matrix, so the prediction function must reshape its input. It must also return a column vector with one score per image (here, the logit of the target class) rather than the scores for all classes.
function score = PredictCNN(net, X, targetClass)
    % Reshape the flattened input back to 4D
    X = reshape(X', 28, 28, 1, size(X,1));
    % Obtain logits from the CNN (a 1x1x10xN array)
    YPred = activations(net, X, 'fc');
    % Remove singleton dimensions, then extract the score for the target class
    YPred = squeeze(YPred);
    score = YPred(double(targetClass), :)';
end
  • “X = reshape(X', 28, 28, 1, size(X,1));” reshapes each row of the input back into the original [28, 28, 1] image format.
  • “YPred = squeeze(YPred);” turns the activations into a 10xN matrix, and “score = YPred(double(targetClass), :)';” selects the logits for the target class (the class label of the ‘QueryPoint’) and returns them as a column vector.
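The shape contract of the prediction function can be exercised without a trained network. In the sketch below, fakeActivations is a hypothetical stand-in for activations(net, X, 'fc'): it maps a 28x28x1xN batch to a 1x1x10xN array, which is the layout assumed here for the fully connected layer's output. The steps mirror PredictCNN and show why squeeze has to come before indexing: indexing A(targetClass, :) directly would fail because the first dimension of A has size 1.

```matlab
% HYPOTHETICAL stand-in for activations(net, X, 'fc'): a random linear map
% from a 28x28x1xN batch to a 1x1x10xN array of class scores
W = randn(10, 28*28);
fakeActivations = @(X4) reshape(W * reshape(X4, 28*28, []), 1, 1, 10, []);

targetClass = 3;            % class whose score is being explained
Xflat = rand(4, 28*28);     % 4 flattened images, one per row, as shapley passes them

% Steps performed by the custom prediction function
X4 = reshape(Xflat', 28, 28, 1, size(Xflat, 1));  % rows -> 28x28x1xN batch
A = fakeActivations(X4);                           % 1x1x10xN
S = squeeze(A);                                    % 10xN: classes in rows, images in columns
score = S(targetClass, :)';                        % Nx1: one score per image
disp(size(score))                                  % displays 4 1
```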
Here’s how the final code would look:
clear all;
close all;
[XTrain,YTrain,anglesTrain] = digitTrain4DArrayData;
[XTest,YTest,anglesTest] = digitTest4DArrayData;
% Define the layers of the CNN
layers = [
imageInputLayer([28 28 1]) % Assuming grayscale 28x28 images
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10) % Assuming 10 classes
softmaxLayer
classificationLayer];
% Define the training options
options = trainingOptions('adam', ...
'MaxEpochs',10, ...
'InitialLearnRate',1e-4, ...
'Verbose',false, ...
'Plots','training-progress');
% Train the CNN on the training data
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% Predict on the test data
YPred = classify(net, XTest);
accuracy = mean(YPred == categorical(YTest));
% Display the confusion matrix
confusionchart(categorical(YTest), YPred);
% Sample a smaller portion of XTest for Shapley computation
numSamples = 200; % Adjust to a lower sample size if necessary
XTestSample = XTest(:,:,:,1:numSamples);
YTestSample = YTest(1:numSamples);
% Flatten the sample images to a 2D matrix
XTestFlatSample = reshape(XTestSample, [28*28, numSamples])';
% Shapley value computation with parallel processing and smaller sample
explainer = shapley( ...
@(X)PredictCNN(net, X, YTestSample(1)), ... % Custom prediction function
XTestFlatSample, ... % Pass the flattened input sample images
"QueryPoint", XTestFlatSample(1,:), ...
'UseParallel', true); % Use parallel processing (requires Parallel Computing Toolbox)
% Custom function to reshape input and get raw predictions from CNN
function score = PredictCNN(net, X, targetClass)
    % Reshape the flattened input back to 4D [Height, Width, Channels, NumSamples]
    X = reshape(X', 28, 28, 1, size(X,1)); % Transpose X to match original dimensions
    % Get raw scores (logits) before softmax
    YPred = activations(net, X, 'fc'); % 'fc' refers to the fully connected layer
    YPred = squeeze(YPred); % Remove singleton dimensions: 1x1x10xN -> 10xN
    % Extract the score for the target class; double() converts the categorical label to its class index
    score = YPred(double(targetClass), :)'; % Return the score for the target class as a column vector
end
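Because each Shapley value corresponds to one pixel of the query image, it can help to fold the values back into a 28x28 map. In the sketch below, phi is a hypothetical placeholder filled with random numbers; in practice it would be the 784 values taken from the numeric column of explainer.ShapleyValues (check the table's variable names in your release). The reshape assumes the values are in the same column-major order as XTestFlatSample(1,:).

```matlab
% phi: HYPOTHETICAL placeholder for the 784 Shapley values of the query image,
% one per pixel, in the same order as the columns of XTestFlatSample
phi = randn(1, 28*28);

% Undo the flattening so each value sits at its pixel location
phiMap = reshape(phi, 28, 28);

% Show positive and negative contributions as a heat map
imagesc(phiMap);
axis image;
colorbar;
title('Per-pixel Shapley values (placeholder data)');
```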
For more information on using the ‘shapley’ function, refer to its documentation.
Hope this helps!
1 Comment
Alok on 25 Oct 2024
Dear Rahul
Thank you for your kind assistance.
Yes, the code worked!
Appreciate it a lot!
Kind regards,
Alok


More Answers (1)

Taylor on 24 Oct 2024
The shapley function expects the input data to be in a format suitable for the model. You are reshaping XTest to a 2D matrix with dimensions [5000, 28*28], assuming 5000 samples of 28x28 images. Ensure that XTest indeed has 5000 samples. If not, adjust the reshape dimensions accordingly.
The PredictCNN function is used as a handle in the shapley function. It takes XTest and YTest as inputs. However, the shapley function only passes the reshaped XTest. You will need to modify PredictCNN to handle this correctly, possibly by removing YTest from its input arguments.
Ensure that the XTest reshaping aligns with the actual number of samples:
numSamples = size(XTest, 4); % Adjust based on your dataset
explainer = shapley( ...
@(XTest)PredictCNN(net, XTest), ...
reshape(XTest, [numSamples, 28*28]), ...
"QueryPoint", reshape(XTest(:,:,1,1), [1, 28*28]) );
Modify PredictCNN to accommodate the input format expected by shapley:
function score = PredictCNN(net, XTest)
    YPred = predict(net, XTest);
    score = YPred;
end
1 Comment
Alok on 24 Oct 2024
Thanks for assisting with this issue. After executing the above suggestions, I got the following error message:
Error using shapley (line 300)
Unable to predict using the blackbox model.
Error in Shap_Example1 (line 54)
explainer = shapley( ...
Caused by:
Error using DAGNetwork/predict (line 195)
Incorrect input size. The input images must have a size of [28 28 1].
Size of XTest is:
size(XTest)
ans =
28 28 1 5000
This is the same error I was getting before.
I would appreciate any further suggestions.
