メインコンテンツ

attention

ドット積アテンション

R2022b 以降

    説明

    attention 演算は、重み付き乗算を使用して入力の一部に焦点を当てます。

    Y = attention(queries,keys,values,numHeads) は、アテンション ヘッド数 numHeads を使用して、指定されたクエリ、キー、および値にドット積アテンション演算を適用します。クエリの入力引数は形式を整えた dlarray オブジェクトでなければなりません。

    [Y,weights] = attention(queries,keys,values,numHeads) は、ドット積アテンション演算を適用し、アテンションの重みも返します。

    [Y,weights] = attention(queries,keys,values,numHeads,DataFormat=FMT) は、形式を整えていない dlarray オブジェクト queries に対し、FMT で指定された形式でドット積アテンション演算を適用します。たとえば、DataFormat="CBT" は、データを "CBT" (チャネル、バッチ、時間) の形式で指定します。

    [Y,weights] = attention(queries,keys,values,numHeads,Name=Value) は、1 つ以上の名前と値の引数を使用して追加オプションを指定します。たとえば、DropoutProbability=0.01 は、ドロップアウト確率を 0.01 に指定します。

    すべて折りたたむ

    クエリ、キー、および値のサイズを指定します。

    querySize = 100;
    valueSize = 120;
    numQueries = 64;
    numValues = 80;
    numObservations = 32;

    クエリ、キー、および値を含むランダム配列を作成します。クエリには、dlarray の形式である "CBT" (チャネル、バッチ、時間) を指定します。

    queries = dlarray(rand(querySize,numObservations, numQueries),"CBT");
    keys = dlarray(rand(querySize,numObservations, numValues));
    values = dlarray(rand(valueSize,numObservations, numValues));

    アテンション ヘッド数を指定します。

    numHeads = 5;

    アテンション演算を適用します。

    [Y,weights] = attention(queries,keys,values,numHeads);

    出力のサイズと形式を表示します。

    size(Y)
    ans = 1×3
    
       120    32    64
    
    
    dims(Y)
    ans = 
    'CBT'
    

    重みのサイズと形式を表示します。

    size(weights)
    ans = 1×4
    
        80    64     5    32
    
    
    dims(weights)
    ans =
    
      0×0 empty char array
    

    attention 関数を使用すると、入力の一部に焦点を当てたマルチヘッド セルフ アテンション演算 [1] を実装できます。

    この例のマルチヘッド セルフ アテンション関数のセクションにリストされている関数 multiheadSelfAttention を作成します。multiheadSelfAttention 関数は、データ X、ヘッドの数、およびクエリ、キー、値、出力データに関する学習可能な重みを入力として受け取り、マルチヘッド アテンションの値を返します。

    入力 X は形式を整えていない dlarray オブジェクトでなければなりません。最初の次元は入力チャネルに対応し、2 番目の次元は時間次元または空間次元に対応し、3 番目の次元はバッチ次元に対応します。

    シーケンス データの配列を作成します。

    numChannels = 10;
    numObservations = 128;
    numTimeSteps = 100;
    
    X = rand(numChannels,numObservations,numTimeSteps);
    X = dlarray(X);
    size(X)
    ans = 1×3
    
        10   128   100
    
    

    マルチヘッド アテンションのヘッド数を指定します。

    numHeads = 8;

    マルチヘッド アテンションの学習可能なパラメーターを初期化します。

    • クエリ、キー、および値に関する学習可能な重みは、(numChannels*numHeads)numChannels 列の配列でなければなりません。

    • 出力に関する学習可能な重みは、(numChannels*numHeads)(numChannels*numHeads) 列の配列でなければなりません。

    outputSize = numChannels*numHeads;
    
    WQ = rand(outputSize,numChannels);
    WK = rand(outputSize,numChannels);
    WV = rand(outputSize,numChannels);
    WO = rand(outputSize,outputSize);

    マルチヘッド セルフ アテンション演算を適用します。

    Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO);

    出力のサイズを表示します。出力のサイズは、(numChannels*numHeads)×numObservations×(numTimeSteps) になります。

    size(Y)
    ans = 1×3
    
        80   128   100
    
    

    マルチヘッド セルフ アテンション関数

    multiheadSelfAttention 関数は、データ X、ヘッドの数、およびクエリ、キー、値、出力データに関する学習可能な重みを入力として受け取り、マルチヘッド アテンションの値を返します。

    • 入力 X は形式を整えていない dlarray オブジェクトでなければなりません。最初の次元は入力チャネルに対応し、2 番目の次元は時間次元または空間次元に対応し、3 番目の次元はバッチ次元に対応します。

    • クエリ、キー、および値に関する学習可能な重み行列は、(numChannels*numHeads)numChannels 列の行列でなければなりません。

    • 出力に関する学習可能な重み行列は、(numChannels*numHeads)(numChannels*numHeads) 列の行列でなければなりません。

    function Y = multiheadSelfAttention(X,numHeads,WQ,WK,WV,WO)
    
    queries = pagemtimes(WQ,X);
    keys = pagemtimes(WK,X);
    values = pagemtimes(WV,X);
    
    A = attention(queries,keys,values,numHeads,DataFormat="CBT");
    
    Y = pagemtimes(WO,A);
    
    end

    attention 関数を使用すると、入力に Luong アテンション演算を適用する関数を作成できます。Luong アテンション演算を適用する luongAttention 関数を作成します。この関数は例の最後にリストされています。

    配列のサイズを指定します。

    numHiddenUnits = 100;
    latentSize = 16;

    入力データを含むランダム配列を作成します。

    hiddenState = dlarray(rand(numHiddenUnits,1));
    Z = dlarray(rand(latentSize,1));
    weights = dlarray(rand(numHiddenUnits,latentSize));

    luongAttention 関数を適用します。

    [context,scores] = luongAttention(hiddenState,Z,weights);

    出力のサイズを表示します。

    size(context)
    ans = 1×2
    
        16     1
    
    
    size(scores)
    ans = 1×2
    
         1     1
    
    

    Luong アテンション関数

    luongAttention 関数は、Luong の "一般的な" スコアリング [2] に従って、コンテキスト ベクトルとアテンション スコアを返します。この演算は、ドット積アテンションでクエリ、キー、および値に隠れ状態、重み付き潜在表現、および潜在表現をそれぞれ指定することと等価です。

    function [context,scores] = luongAttention(hiddenState,Z,weights)
    
    numHeads = 1;
    queries = hiddenState;
    keys = pagemtimes(weights,Z);
    values = Z;
    
    [context,scores] = attention(queries,keys,values,numHeads, ...
        Scale=1, ...
        DataFormat="CBT");
    
    end

    入力引数

    すべて折りたたむ

    クエリ。dlarray オブジェクトとして指定します。

    queries には、最大 1 つの "S" (空間) 次元または "T" (時間) 次元を含めることができます。"U" (指定なし) というラベルが付いた queries 内のすべての次元はシングルトンでなければなりません。queries が、形式を整えていない dlarray オブジェクトである場合、DataFormat オプションを使用してデータ形式を指定します。

    keys"C" (チャネル) 次元のサイズは、queries の対応する次元のサイズと一致しなければなりません。

    querieskeys、および values"B" (バッチ) 次元のサイズは一致しなければなりません。

    キー。dlarray オブジェクトまたは数値配列として指定します。

    keys が、形式を整えた dlarray オブジェクトである場合、その形式は queries の形式と一致しなければなりません。keys が、形式を整えた dlarray オブジェクトでない場合、この関数は queries と同じ形式を使用します。

    keys"S" (空間) 次元または "T" (時間) 次元のサイズは、values の対応する次元のサイズと一致しなければなりません。

    keys"C" (チャネル) 次元のサイズは、queries の対応する次元のサイズと一致しなければなりません。

    querieskeys、および values"B" (バッチ) 次元のサイズは一致しなければなりません。

    値。dlarray オブジェクトまたは数値配列として指定します。

    values が、形式を整えた dlarray オブジェクトである場合、その形式は queries の形式と一致しなければなりません。そうでない場合、この関数は queries と同じ形式を使用します。

    keys"S" (空間) 次元または "T" (時間) 次元のサイズは、values の対応する次元のサイズと一致しなければなりません。

    querieskeys、および values"B" (バッチ) 次元のサイズは一致しなければなりません。

    ヘッドの数。正の整数として指定します。

    各ヘッドは入力に対して個別の線形変換を実行し、アテンションの重みを独立して計算します。層はこれらのアテンション重みを使用して入力表現の重み付き和を計算し、コンテキスト ベクトルを生成します。ヘッドの数を増やすと、モデルはさまざまな種類の依存関係を捉え、入力のさまざまな部分に同時に注意を向けることができるようになります。ヘッドの数を減らすと、層の計算コストを削減できます。

    numHeads の値は、querieskeys、および values"C" (チャネル) 次元のサイズを均等に分割しなければなりません。

    名前と値の引数

    すべて折りたたむ

    オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。

    R2021a より前では、コンマを使用して名前と値をそれぞれ区切り、Name を引用符で囲みます。

    例: attention(queries,keys,values,numHeads,DataFormat="CBT") は、形式を整えていないデータに対してアテンション演算を適用し、データ形式 "CBT" (チャネル、バッチ、時間) を指定します。

    データの次元の説明。文字ベクトルまたは string スカラーとして指定します。

    データ形式は文字列で、各文字は対応するデータ次元のタイプを表します。

    各文字は以下のとおりです。

    • "S" — 空間

    • "C" — チャネル

    • "B" — バッチ

    • "T" — 時間

    • "U" — 指定なし

    たとえば、シーケンスのバッチを表し、1 番目、2 番目、および 3 番目の次元がそれぞれチャネル、観測値、およびタイム ステップに対応する配列があるとします。データは "CBT" (チャネル、バッチ、時間) の形式で記述できます。

    "S" または "U" のラベルが付いた次元については、複数回指定できます。ラベル "C""B"、および "T" はそれぞれ 1 回まで使用できます。ソフトウェアは、2 番目の次元の後ろにある大きさが 1 の "U" 次元を無視します。

    入力データが、形式を整えた dlarray オブジェクトでない場合は、DataFormat オプションを指定しなければなりません。

    詳細については、深層学習のデータ形式を参照してください。

    データ型: char | string

    スケーリングされたドット積アテンションの乗法係数[1]。次のいずれかの値として指定します。

    • "auto" — ドット積を λ=1dk で乗算します。ここで、dk はキーのチャネル数をヘッドの数で割った値を表します。

    • 数値スカラー — 指定されたスケール係数でドット積を乗算します。

    データ型: single | double | char | string

    入力のどの要素がパディング値に対応するかを示すマスク。dlarray オブジェクト、logical 配列、またはバイナリ値の数値配列として指定します。

    この関数は、PaddingMask 内の対応する要素がそれぞれ 01 である場合に、入力データのキーと値のペアの要素へのアテンションを禁止および許可します。

    PaddingMask が、形式を整えた dlarray オブジェクトである場合、その形式は keys の形式と一致しなければなりません。PaddingMask が、形式を整えた dlarray オブジェクトでない場合、この関数は keys と同じ形式を使用します。PaddingMask"S" (空間)、"T" (時間)、および "B" (バッチ) の各次元のサイズは、keys および values の対応する次元のサイズと一致しなければなりません。

    パディング マスクには任意の数のチャネルを含めることができます。ソフトウェアは、最初のチャネルの値をパディング値の表示にのみ使用します。

    既定値は、keys と同じサイズの 1 から成る logical 配列です。

    attention 演算を適用するときに含める要素を示すアテンション マスク。次のいずれかの値として指定します。

    • "none" — 位置に関係なく、要素に注意を払うことを抑制しません。AttentionMask"none" である場合、ソフトウェアはパディング マスクのみを使用して注意を抑制します。

    • "causal" — 入力されたクエリの "S" (空間) または "T" (時間) 次元の位置 m にある要素が、入力されたキーと値の対応する次元において、位置 n (nm より大きい) にある要素に注意を払うことを抑制します。このオプションは自己回帰モデルに使用します。

    • 論理配列または数値配列 — 指定された配列内の対応する要素が 0 である場合、入力されたキーと値の要素に注意を払うことを抑制します。指定された配列は、NkNq 列の行列、または Nk×Nq×numObservations の配列でなければなりません。Nk は入力されたキーの "S" (空間) 次元または "T" (時間) 次元のサイズ、Nq は入力されたクエリの対応する次元のサイズ、numObservations は入力されたクエリの "B" 次元のサイズです。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | logical | char | string

    アテンションの重みのドロップアウトの確率。範囲 [0, 1) のスカラーとして指定します。

    データ型: single | double

    出力引数

    すべて折りたたむ

    アテンション演算の結果。dlarray オブジェクトとして返されます。

    queries が、形式を整えた dlarray オブジェクトである場合、Y は、queries と同じ次元ラベルをもつ形式を整えた dlarray オブジェクトになります。Y"C" (チャネル) 次元のサイズは、values の対応する次元のサイズと同じです。Y"S" (空間) 次元または "T" 次元のサイズは、queries の対応する次元のサイズと同じです。

    queries が、形式を整えた dlarray オブジェクトでない場合、Y は、形式を整えていない dlarray オブジェクトになります。

    アテンションの重み。形式を整えていない dlarray オブジェクトとして返されます。

    weights は、Nk×Nq×numHeads×numObservations の配列です。ここで、Nkkeys"S" (空間) 次元または "T" (時間) 次元のサイズ、Nqqueries 内の対応する次元のサイズ、numObservationsqueries 内の "B" (バッチ) 次元のサイズです。

    アルゴリズム

    すべて折りたたむ

    参照

    [1] Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." Advances in neural information processing systems 30 (December 2017): 6000-6010. https://papers.nips.cc/paper/7181-attention-is-all-you-need.

    [2] Luong, Minh-Thang, Hieu Pham, and Christopher D. Manning. "Effective approaches to attention-based neural machine translation." arXiv preprint arXiv:1508.04025 (2015).

    拡張機能

    すべて展開する

    バージョン履歴

    R2022b で導入