How does selfAttentionLayer work, and how can it be validated with brief code?

cui,xingxing, 11 January 2024
Edited: cui,xingxing, about 13 hours ago
How does selfAttentionLayer work in detail, step by step? Can its working process be reproduced simply from the formulas in the paper, so as to verify the correctness and consistency of selfAttentionLayer?
Official description:
A self-attention layer computes single-head or multihead self-attention of its input.
The layer:
  1. Computes the queries, keys, and values from the input
  2. Computes the scaled dot-product attention across heads using the queries, keys, and values
  3. Merges the results from the heads
  4. Performs a linear transformation on the merged result
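For reference, the four steps above can be written as formulas. This is a sketch in the usual scaled dot-product attention notation, transposed to a channels-by-time layout where the input X is C×T. The symbols (X, W_Q, b_Q, h, d_k, A_i, H_i and so on) are notation chosen here for illustration, not names taken from the MATLAB documentation. Q_i, K_i and V_i denote the row blocks belonging to head i, d_k is the number of query channels per head, and the softmax normalizes each column.

```latex
\begin{aligned}
Q &= W_Q X + b_Q, \qquad K = W_K X + b_K, \qquad V = W_V X + b_V \\
A_i &= \operatorname{softmax}\!\left(\frac{K_i^{\top} Q_i}{\sqrt{d_k}}\right), \qquad
H_i = V_i A_i, \qquad i = 1,\dots,h \\
H &= \begin{bmatrix} H_1 \\ \vdots \\ H_h \end{bmatrix} \\
Y &= W_O H + b_O
\end{aligned}
```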

Accepted Answer

cui,xingxing, 11 January 2024
Edited: cui,xingxing, about 13 hours ago
Here is a simple code workflow I wrote, using an input with only two dimensions, "CT" (channel by time), to illustrate how each step works.
The comment after each variable gives that variable's dimensions.
%% Verify that the selfAttentionLayer computation matches a manual calculation
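As a small warm-up, here is a sketch of how the channels are divided among the heads, using the same sizes as the script in this answer (48 query/key channels, 12 value channels, 6 heads). It assumes, as the script does, that each head takes a contiguous block of rows. The variable names here are illustrative only.

```matlab
% Sizes match the script in this answer; names are illustrative.
numHeads = 6;
numQueryChannels = 48;   % N1
numValueChannels = 12;   % N2

qPerHead = numQueryChannels/numHeads;   % 8 query/key channels per head
vPerHead = numValueChannels/numHeads;   % 2 value channels per head

for i = 1:numHeads
    % Contiguous row ranges assigned to head i (an assumption of this sketch)
    idxQ = qPerHead*(i-1)+1 : qPerHead*i;
    idxV = vPerHead*(i-1)+1 : vPerHead*i;
    fprintf("head %d: q/k rows %d-%d, v rows %d-%d\n", ...
        i, idxQ(1), idxQ(end), idxV(1), idxV(end));
end
```

Both channel counts must be divisible by the number of heads for this split to work.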
XTrain = dlarray(rand(10,20));% CT
numClasses = 4;
numHeads = 6;
queryDims = 48; % N1=48
layers = [inputLayer(size(XTrain),"CT");
selfAttentionLayer(numHeads,queryDims,NumValueChannels=12,OutputSize=15,Name="sa");
layerNormalizationLayer;
fullyConnectedLayer(numClasses);
softmaxLayer];
net = dlnetwork(layers);
% analyzeNetwork(net)
XTrain = dlarray(XTrain,"CT");
[act1,act2] = predict(net,XTrain,Outputs=["input","sa"]);
act1 = extractdata(act1);% CT
act2 = extractdata(act2);% CT
% layer params
layerSA = net.Layers(2);
QWeights = layerSA.QueryWeights; % N1*C
KWeights = layerSA.KeyWeights;% N1*C
VWeights = layerSA.ValueWeights;% N2*C
outputW = layerSA.OutputWeights;% N3*N2
Qbias = layerSA.QueryBias; % N1*1
Kbias = layerSA.KeyBias;% N1*1
Vbias = layerSA.ValueBias; % N2*1
outputB = layerSA.OutputBias;% N3*1
% step1
q = QWeights*act1+Qbias; % N1*T
k = KWeights*act1+Kbias;% N1*T
v = VWeights*act1+Vbias;% N2*T
% step2,multiple heads
numChannelsQPerHeads = size(q,1)/numHeads;% 1*1
numChannelsVPerHeads = size(v,1)/numHeads;% 1*1
attentionM = cell(1,numHeads);
for i = 1:numHeads
    idxQRange = numChannelsQPerHeads*(i-1)+1:numChannelsQPerHeads*i;
    idxVRange = numChannelsVPerHeads*(i-1)+1:numChannelsVPerHeads*i;
    qi = q(idxQRange,:);% diQ*T
    ki = k(idxQRange,:);% diQ*T
    vi = v(idxVRange,:);% diV*T
    % attention
    dk = size(qi,1);% 1*1
    % T*T; note that MATLAB's internal code uses k'*q, not q'*k, so the
    % softmax over dim 1 normalizes across keys for each query column
    attentionScores = mysoftmax(ki'*qi./sqrt(dk));
    attentionM{i} = vi*attentionScores; % diV*T
end
%step3,merge attentionM
attention = cat(1,attentionM{:}); % N2*T,N2 = diV*numHeads
%step4,output linear projection
act_ = outputW*attention+outputB;% N3*T
act2(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
act_(1,:)
ans = 1×20
-0.5919 -0.5888 -0.5905 -0.5916 -0.5902 -0.5956 -0.5936 -0.5910 -0.5906 -0.5922 -0.5943 -0.5926 -0.5915 -0.5920 -0.5947 -0.5935 -0.5925 -0.5932 -0.5917 -0.5884
The first rows of act2 (layer output) and act_ (manual calculation) match. I have reproduced the layer's working process in the simplest way I could, and I hope it helps others.
function out = mysoftmax(X,dim)
    arguments
        X
        dim = 1;
    end
    % X = X-max(X,[],dim); % prevents exp(X) from overflowing to Inf when X is large
    X = exp(X);
    out = X./sum(X,dim);
end
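The commented-out line in mysoftmax subtracts the column maximum before exponentiating. Here is a minimal sketch of why that matters, using made-up values chosen to overflow. The anonymous functions naiveSoftmax and stableSoftmax are illustrative names, not part of any toolbox.

```matlab
% Made-up scores; the first column is large enough that exp overflows to Inf.
X = [1000 1; 1001 2];

naiveSoftmax  = @(X) exp(X)./sum(exp(X),1);
stableSoftmax = @(X) exp(X - max(X,[],1))./sum(exp(X - max(X,[],1)),1);

naive  = naiveSoftmax(X);    % first column is NaN because Inf/Inf is NaN
stable = stableSoftmax(X);   % finite values; each column sums to 1

disp(any(isnan(naive),"all"))
disp(sum(stable,1))
```

Subtracting the maximum does not change the mathematical result, because the common factor cancels between numerator and denominator. With the small attention scores in the script above, the two versions should agree to within rounding.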

Release

R2023b
