Deep learning: predict and forward give very inconsistent results with batch normalisation
Hi there,
I have a DNN with batch normalization layers. I called [Uf, statef] = forward(dnn_net, XT); the statef returned by forward should contain the learned mean and variance of the BN layers. I then set dnn_net.State = statef so that the network state holds the learned mean and variance from the forward call. Next I called Up = predict(dnn_net, XT) on the same data XT used in the forward call. Comparing the two outputs with mse(Up, Uf) gives a large error (e.g. 2.xx), which I did not expect. Please help.
I have created a simple test script to show the problem below:
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2);% 2 input features
for i = 1:numLayer-1
layers = [
layers
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
geluLayer
];
end
layers = [
layers
fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net);% expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output using forward; statef contains the
% batch normalization state (learned means and variances).
[Uf,statef] = forward(dnn_net,XT);
dnn_net.State = statef; % update dnn_net.State with the BN layers' learned means and variances
Up = predict(dnn_net,XT);
% The DNN outputs from predict and forward should be the same in
% this case (because dnn_net.State is updated with the learned mean/var from the forward call).
% However, that is not the case. The error is quite large.
plot(Uf(1,:),'r-');hold on; plot(Up(1,:),'b-');
err = mse(Uf(1,:),Up(1,:))
Answers (1)
Matt J
25 Aug 2024
Edited: Matt J on 25 Aug 2024
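To get agreement, you need to allow the trained mean and variance to converge. Each forward call only moves the running statistics stored in State part of the way toward the statistics of the current batch, so repeat the update until they settle: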
numLayer = 9;
numNeurons = 80;
layers = featureInputLayer(2);% 2 input features
for i = 1:numLayer-1
layers = [
layers
fullyConnectedLayer(numNeurons)
batchNormalizationLayer
geluLayer
];
end
layers = [
layers
fullyConnectedLayer(4)];
dnn_net = dlnetwork(layers,'Initialize',true);
dnn_net = expandLayers(dnn_net);% expand the layers in residual blocks
dnn_net = dlupdate(@double,dnn_net);
X = 2*pi*dlarray(randn(1,1000),"CB");
T = 2*pi*dlarray(randn(1,1000),"CB");
XT = cat(1,X,T);
% Compute the DNN output using forward; statef contains the
% batch normalization state (learned means and variances).
% Repeat the call so that the running statistics can converge.
for i=1:500
    [Uf,statef] = forward(dnn_net,XT);
    dnn_net.State = statef; % update dnn_net.State with the BN layers' learned means and variances
end
Up = predict(dnn_net,XT);
% Once the running mean/var in dnn_net.State have converged to the batch
% statistics, the DNN outputs from predict and forward should agree on
% this data, and the error should be small.
plot(Uf(1,:),'r-');hold on; plot(Up(1,:),'--b');
err = mse(Uf(1,:),Up(1,:))
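A rough way to see why one forward call is not enough is to simulate the state update for a single batch normalization layer. The sketch below needs no toolbox and assumes the running mean is updated as new = decay*batchMean + (1-decay)*old, with decay = 0.1 (which I believe is the default MeanDecay of batchNormalizationLayer, so please check the documentation for your release) and a running mean that starts at zero. The value of batchMean is hypothetical.

```matlab
% Toolbox-free sketch of the moving-average update that is assumed to be
% applied to the batch normalization state on each forward call.
decay     = 0.1;   % assumed default of MeanDecay
batchMean = 3.7;   % hypothetical mean of the batch seen by one BN layer
runMean   = 0;     % running mean, assumed to start at zero

numIter = 500;
gap = zeros(1,numIter);   % |running mean - batch mean| after each update
for n = 1:numIter
    % each update moves the running mean a fraction "decay" of the way
    % toward the mean of the current batch
    runMean = decay*batchMean + (1-decay)*runMean;
    gap(n)  = abs(runMean - batchMean);
end

fprintf('gap after %d update(s): %.3g\n', 1, gap(1));
fprintf('gap after %d update(s): %.3g\n', numIter, gap(end));
```

Under these assumptions, a single update closes only 10% of the gap, so predict (which uses the running statistics) still normalizes with values far from the batch statistics that forward used. The gap shrinks by a factor of (1-decay) per update, which is why repeating the forward call on the same data brings the two outputs together.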