How to implement PyTorch's Linear layer in MATLAB?
9 views (last 30 days)
Hello,
The problem is that PyTorch's Linear does not flatten its inputs, whereas MATLAB's fullyConnectedLayer does, so the two are not equivalent.
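To make the mismatch concrete, here is a quick sketch (sizes arbitrary) showing that fullyConnectedLayer mixes every element of its input rather than acting only on the last dimension:
X = dlarray(rand(3,3,2),'SSCB'); % a 3x3x2 input, batch of 1
net = dlnetwork(fullyConnectedLayer(5), X); % initialized from the example input
Y = predict(net,X);
size(Y) % 5x1 -- all 18 input elements were flattened and mixed together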
Thx,
J
0 Comments
Answers (4)
Matt J
11 Feb 2023
Edited: Matt J on 11 Feb 2023
One possibility might be to express the linear layer as a cascade of fullyConnectedLayer followed by a functionLayer. The functionLayer can reshape the flattened output back to the form you want:
layer = functionLayer(@(X)reshape(X,[h,w,c])); % restore the h-by-w-by-c shape
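A full cascade might look like the following sketch (the 8x8x3 input size and the h, w, c values here are made up for illustration; the 'Formattable' variant also preserves the batch dimension):
h=4; w=4; c=3; % desired output shape (hypothetical)
layers = [
    imageInputLayer([8 8 3],'Normalization','none') % hypothetical input size
    fullyConnectedLayer(h*w*c) % flattens the input, then applies the weights
    functionLayer(@(X) dlarray(reshape(stripdims(X),h,w,c,[]),'SSCB'), ...
        'Formattable',true) % un-flatten, keeping the batch dimension
    ];
net = dlnetwork(layers);
Y = predict(net, dlarray(rand(8,8,3,5),'SSCB')); % batch of 5
size(Y) % 4x4x3x5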
9 Comments
Matt J
13 Feb 2023
Edited: Matt J on 13 Feb 2023
"This solution sums all channels together."
No, it won't. (Keep in mind that this is the 3rd solution I've proposed as information about your aims has come out.) After the reshaping, each channel is contained in its own column of X, and because the filter you apply to X is (H*W)x1xN, there is no way for the filter to combine elements from different columns.
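A quick way to see this with plain arrays (sizes hypothetical): a column-height filter convolved over a matrix acts on one column at a time.
Z = rand(4,2); % two "channels" stored as columns
f = rand(4,1); % a (H*W)x1 filter
conv2(Z,f,'valid') % 1x2 result: each output element depends on only one column
Z(:,1).'*flipud(f) % reproduces the first output element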
Matt J
13 Feb 2023
Edited: Matt J on 13 Feb 2023
Another possible way to interpret your question is that you are trying to apply pagemtimes to the input X with a non-learnable matrix A, where the different channels of X are the pages. That can also be done with a functionLayer, as illustrated below, both with normal arrays and with dlarrays:
A=rand(4,3); %non-learnable matrix A
xdata=rand(3,3,2); %input layer data with 2 channels
multLayer=functionLayer(@(X) dlarray( pagemtimes(A,stripdims(X)) ,dims(X)) );
X=dlarray(xdata,'SSC');
Y=multLayer.predict(X)
%% Verify agreement with normal pagemtimes
ydata=pagemtimes(A,xdata)
3 件のコメント
Matt J
13 Feb 2023
Edited: Matt J on 13 Feb 2023
The modification for the case where A is learnable is shown below. I am using a pre-declared A here only so that I can demonstrate and test the result. In a real scenario, you wouldn't supply weights to the convolution2dLayer.
X=dlarray(rand(3,3,2),'SSC'); A=rand(4,3);
[h,w,c]=size(X);
L1=functionLayer( @(z) z(:,:) ); % flatten: channels laid out side-by-side as columns
Lconv=convolution2dLayer([h,1],4,'Weights',permute(A,[2,3,4,1])); % one [h,1] filter per row of A
L2=functionLayer(@(z)recoverShape(z,w,c) ,'Formattable',1); % restore the 'SSC' shape
net=dlnetwork([L1,Lconv,L2],X);
Yfinal=net.predict(X)
And as before, we can compare to the result of a plain-vanilla pagemtimes operation and see that it gives the same result:
Ycheck=pagemtimes(A, extractdata(X))
function out=recoverShape(z,w,c)
% Move the filter outputs back into the row dimension and restore the 'SSC' format
z=permute( stripdims(z), [3,2,1]);
out=dlarray(reshape(z,[],w,c),'SSC');
end
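In the real (learnable) scenario mentioned above, the construction would simply omit the pre-declared weights, e.g.:
Lconv = convolution2dLayer([h,1],4); % weights are now left for training to learn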
Matt J
14 Feb 2023
Edited: Matt J on 14 Feb 2023
Another approach is to write your own custom layer for channel-wise matrix multiplication. I have attached a possible version of this:
X=rand(3,3,2);
L=pagemtimesLayer(4); %Custom layer - premultiplies channels by 4-row learnable matrix A
L=initialize(L, X);
Ypred=L.predict(X)
Ycheck=pagemtimes(L.A,X) %Check agreement with a direct call to pagemtimes()
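The attached pagemtimesLayer is not reproduced here, but a minimal version of such a custom layer might look like the sketch below (class and property names assumed; the actual attachment may differ, e.g. by defining its own backward()):
classdef pagemtimesLayer < nnet.layer.Layer
    properties (Learnable)
        A % learnable matrix applied to every channel (page) of the input
    end
    properties
        OutputRows % number of rows of A
    end
    methods
        function layer = pagemtimesLayer(m)
            layer.OutputRows = m;
        end
        function layer = initialize(layer,X)
            % X here is example input data, matching the usage above
            % (the documented signature takes a networkDataLayout instead)
            if isempty(layer.A)
                h = size(X,1);
                layer.A = randn(layer.OutputRows,h)/sqrt(h);
            end
        end
        function Y = predict(layer,X)
            Y = pagemtimes(layer.A,X); % channel-wise matrix multiplication
        end
    end
end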
8 Comments
Matt J
15 Feb 2023
That sounds right.
Although, part of me questions whether it was the best design for TMW to make the user responsible for summing over batched input in the backward() method, since that dimension should always be handled the same way.
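For reference, in a layer like pagemtimesLayer above, a hand-written backward() might look like this sketch (the signature follows the custom-layer template; the explicit sum over the trailing channel/batch dimensions is the responsibility in question):
function [dLdX,dLdA] = backward(layer,X,~,dLdY,~)
% Forward pass is Y = pagemtimes(A,X), so by the chain rule:
dLdX = pagemtimes(layer.A,'transpose',dLdY,'none'); % dL/dX = A.'*(dL/dY) per page
dLdA = sum(pagemtimes(dLdY,'none',X,'transpose'),[3 4]); % summed over channel and batch dims
end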