How to extract partial derivatives of some specific layer in the back-propagation of a deep learning model?
Say I have a deep learning model, and after training I call this model net.
When I input some images into net, I want to have the partial derivatives $\partial h / \partial \theta$, where $h$ are the outputs of the relu1 layer and $\theta$ are the parameters of all trainable weights of the layers before relu1.
You can see that $h$ (i.e. the output of relu1) will have a size of $24 \times 24 \times 20$. I write the size of the training weights before relu1 as $|\Theta|$, where $\Theta$ would be the set of all trainable parameters of the layers before relu1. Therefore $\partial h / \partial \theta$ should have the size of $24 \times 24 \times 20 \times |\Theta|$.
How can I get $\partial h / \partial \theta$ in the code? Many thanks!
My current code
%% Load Data
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
numTrainFiles = 50;
[imdsTrain,imdsValidation] = splitEachLabel(imds,numTrainFiles,'randomize');

%% Define Network Architecture
inputSize = [28 28 1];
numClasses = 10;
layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    fullyConnectedLayer(numClasses,'Name','fc2')
    softmaxLayer('Name','softmax')
    classificationLayer];

%% Train Network
options = trainingOptions('sgdm', ...
    'MaxEpochs',4, ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(imdsTrain,layers,options);
Answers (1)
Dinesh Yadav
26 Nov 2019
Hi
Kindly go through the following link and examples in it.
Once you have the output of the reluLayer, you can use dlgradient to compute the partial derivatives of that output.
Hope it helps.
3 Comments
Dinesh Yadav
27 Nov 2019
I don't think there is a way to do it with dlgradient without using loops. If you want to do it without loops, you will have to write your own custom gradient function.
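For reference, here is a minimal sketch of the loop-based approach, not a definitive implementation. dlgradient only differentiates a scalar, so the sketch calls it once per element of the relu1 output and stacks the results into a matrix. The names layersBefore, gradOfElement and J are made up for illustration, and the layers are an untrained stand-in with a random input image so that the snippet runs on its own. To use the trained weights, you would build layersBefore from net.Layers(1:4) instead, assuming your release lets dlnetwork accept the trained input layer.

```matlab
% Sketch: Jacobian of the relu1 output h with respect to all learnables
% before relu1, built one row at a time with dlgradient.

% Stand-in for the first four layers of the network in the question.
% (Untrained here so the sketch runs on its own; 'Normalization','none'
% avoids needing stored input statistics.)
layersBefore = [
    imageInputLayer([28 28 1],'Normalization','none','Name','input')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')];
dlnet = dlnetwork(layerGraph(layersBefore));

% One input image as a formatted dlarray (random stand-in data).
X = dlarray(rand(28,28,1,1,'single'),'SSCB');

h0        = predict(dlnet,X);                             % relu1 output
numOut    = numel(h0);                                    % entries of h
numParams = sum(cellfun(@numel,dlnet.Learnables.Value));  % entries of theta

J = zeros(numOut,numParams,'single');   % row k holds d h(k) / d theta
for k = 1:numOut
    g = dlfeval(@gradOfElement,dlnet,X,k);
    % Flatten the per-parameter gradients into one row, in the order of
    % dlnet.Learnables.
    cols   = cellfun(@(v) extractdata(v(:)),g.Value,'UniformOutput',false);
    J(k,:) = cell2mat(cols).';
end

function g = gradOfElement(dlnet,X,k)
    % Gradient of the k-th relu1 output (a scalar) w.r.t. the learnables.
    h = stripdims(predict(dlnet,X));
    g = dlgradient(h(k),dlnet.Learnables);
end
```

Each row of J can be reshaped back to the layout of h with reshape(J(:,p),size(h0)) for a given parameter index p. The loop runs one backward pass per output element, so it is slow for large layers, and for several images you would repeat it per image.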