メインコンテンツ

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

attentionLayer

ドット積注意層

R2024a 以降

    説明

    ドット積注意層は、重み付き乗算演算を使用して入力の一部に焦点を当てます。

    作成

    説明

    layer = attentionLayer(numHeads) はドット積注意層を作成し、NumHeads プロパティを設定します。

    layer = attentionLayer(numHeads,Name=Value) は、1 つ以上の名前と値の引数を使用して、ScaleHasPaddingMaskInputHasScoresOutputAttentionMaskDropoutProbability、および Name の各プロパティも設定します。

    プロパティ

    すべて展開する

    アテンション

    ヘッドの数。正の整数として指定します。

    各ヘッドは入力に対して個別の線形変換を実行し、アテンションの重みを独立して計算します。層はこれらのアテンション重みを使用して入力表現の重み付き和を計算し、コンテキスト ベクトルを生成します。ヘッドの数を増やすと、モデルはさまざまな種類の依存関係を捉え、入力のさまざまな部分に同時に注意を向けることができるようになります。ヘッドの数を減らすと、層の計算コストを削減できます。

    NumHeads の値は、入力されたクエリ、キー、および値のチャネル次元のサイズを均等に分割しなければなりません。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

    クエリとキーのドット積をスケーリングするための乗法係数。次のいずれかの値として指定します。

    • "auto" — ドット積を 1/sqrt(D) で乗算します。ここで、D はキーのチャネル数を NumHeads で割った値です。

    • 数値スカラー — ドット積を指定されたスカラーで乗算します。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | char | string | cell

    層にパディング マスクを表す入力があるかどうかを示すフラグ。0 (false) または 1 (true) として指定します。

    HasPaddingMaskInput プロパティが 0 (false) の場合、層は、それぞれ入力されたクエリ、キー、および値に対応する "query""key"、および "value" という名前の 3 つの入力をもちます。この場合、層はすべての要素をデータとして扱います。

    HasPaddingMaskInput プロパティが 1 (true) の場合、層は、パディング マスクに対応する "mask" という名前の追加の入力をもちます。この場合、パディング マスクは 1 と 0 から成る配列になります。層は、クエリ、キー、値の要素について、マスク内の対応する要素が 1 の場合は使用し、0 の場合は無視します。

    パディング マスクの形式は入力されたキーの形式と一致していなければなりません。パディング マスクの "S" (空間)、"T" (時間)、および "B" (バッチ) の次元のサイズは、キーと値の対応する次元のサイズと一致していなければなりません。

    パディング マスクには任意の数のチャネルを含めることができます。ソフトウェアは、パディング値を示すために最初のチャネルの値のみを使用します。

    層にスコア (アテンションの重みとも呼ばれる) を表す出力があるかどうかを示すフラグ。0 (false) または 1 (true) として指定します。

    HasScoresOutput プロパティが 0 (false) の場合、層は、出力データに対応する "out" という名前の 1 つの出力をもちます。

    HasScoresOutput プロパティが 1 (true) の場合、層は、それぞれ出力データおよびアテンション スコアに対応する "out" および "scores" という名前の 2 つの入力をもちます。

    attention 演算を適用するときに含める要素を示すアテンション マスク。次のいずれかの値として指定します。

    • "none" — 位置に関係なく、要素に注意を払うことを抑制しません。AttentionMask"none" である場合、ソフトウェアはパディング マスクのみを使用して注意を抑制します。

    • "causal" — 入力されたクエリの "S" (空間) または "T" (時間) 次元の位置 m にある要素が、入力されたキーと値の対応する次元において、位置 n (nm より大きい) にある要素に注意を払うことを抑制します。このオプションは自己回帰モデルに使用します。

    • 論理配列または数値配列 — 指定された配列内の対応する要素が 0 である場合、入力されたキーと値の要素に注意を払うことを抑制します。指定された配列は、NkNq 列の行列、または Nk×Nq×numObservations の配列でなければなりません。Nk は入力されたキーの "S" (空間) 次元または "T" (時間) 次元のサイズ、Nq は入力されたクエリの対応する次元のサイズ、numObservations は入力されたクエリの "B" 次元のサイズです。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | logical | char | string

    アテンション スコアをドロップアウトする確率。範囲 [0, 1) のスカラーとして指定します。

    学習中、ソフトウェアは指定された確率を使用して、アテンション スコアの値をランダムにゼロに設定します。これらのドロップアウトにより、モデルが特定の依存関係に過度に依存することを防ぎ、より堅牢で一般化可能な表現を学習できるようになります。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

    層の名前。文字ベクトルまたは string スカラーとして指定します。Layer 配列入力の場合、trainnet 関数および dlnetwork 関数は、名前のない層に自動的に名前を割り当てます。

    AttentionLayer オブジェクトは、このプロパティを文字ベクトルとして格納します。

    データ型: char | string

    層への入力の数。3 または 4 として返されます。

    HasPaddingMaskInput プロパティが 0 (false) の場合、層は、それぞれ入力されたクエリ、キー、および値に対応する "query""key"、および "value" という名前の 3 つの入力をもちます。この場合、層はすべての要素をデータとして扱います。

    HasPaddingMaskInput プロパティが 1 (true) の場合、層は、パディング マスクに対応する "mask" という名前の追加の入力をもちます。この場合、パディング マスクは 1 と 0 から成る配列になります。層は、クエリ、キー、値の要素について、マスク内の対応する要素が 1 の場合は使用し、0 の場合は無視します。

    パディング マスクの形式は入力されたキーの形式と一致していなければなりません。パディング マスクの "S" (空間)、"T" (時間)、および "B" (バッチ) の次元のサイズは、キーと値の対応する次元のサイズと一致していなければなりません。

    パディング マスクには任意の数のチャネルを含めることができます。ソフトウェアは、パディング値を示すために最初のチャネルの値のみを使用します。

    データ型: double

    層の入力名。文字ベクトルの cell 配列として返されます。

    HasPaddingMaskInput プロパティが 0 (false) の場合、層は、それぞれ入力されたクエリ、キー、および値に対応する "query""key"、および "value" という名前の 3 つの入力をもちます。この場合、層はすべての要素をデータとして扱います。

    HasPaddingMaskInput プロパティが 1 (true) の場合、層は、パディング マスクに対応する "mask" という名前の追加の入力をもちます。この場合、パディング マスクは 1 と 0 から成る配列になります。層は、クエリ、キー、値の要素について、マスク内の対応する要素が 1 の場合は使用し、0 の場合は無視します。

    パディング マスクの形式は入力されたキーの形式と一致していなければなりません。パディング マスクの "S" (空間)、"T" (時間)、および "B" (バッチ) の次元のサイズは、キーと値の対応する次元のサイズと一致していなければなりません。

    パディング マスクには任意の数のチャネルを含めることができます。ソフトウェアは、パディング値を示すために最初のチャネルの値のみを使用します。

    AttentionLayer オブジェクトは、このプロパティを文字ベクトルの cell 配列として格納します。

    この プロパティ は読み取り専用です。

    層の出力の数。

    HasScoresOutput プロパティが 0 (false) の場合、層は、出力データに対応する "out" という名前の 1 つの出力をもちます。

    HasScoresOutput プロパティが 1 (true) の場合、層は、それぞれ出力データおよびアテンション スコアに対応する "out" および "scores" という名前の 2 つの入力をもちます。

    データ型: double

    この プロパティ は読み取り専用です。

    層の出力名。

    HasScoresOutput プロパティが 0 (false) の場合、層は、出力データに対応する "out" という名前の 1 つの出力をもちます。

    HasScoresOutput プロパティが 1 (true) の場合、層は、それぞれ出力データおよびアテンション スコアに対応する "out" および "scores" という名前の 2 つの入力をもちます。

    AttentionLayer オブジェクトは、このプロパティを文字ベクトルの cell 配列として格納します。

    すべて折りたたむ

    10 個のヘッドをもつドット積注意層を作成します。

    layer = attentionLayer(10)
    layer = 
      AttentionLayer with properties:
    
                       Name: ''
                  NumInputs: 3
                 InputNames: {'query'  'key'  'value'}
                   NumHeads: 10
                      Scale: 'auto'
              AttentionMask: 'none'
         DropoutProbability: 0
        HasPaddingMaskInput: 0
            HasScoresOutput: 0
    
       Learnable Parameters
        No properties.
    
       State Parameters
        No properties.
    
      Show all properties
    
    

    クロスアテンションによるシンプルなニューラル ネットワークを作成します。

    numChannels = 256;
    numHeads = 8;
    
    net = dlnetwork;
    
    layers = [
        sequenceInputLayer(1,Name="query")
        fullyConnectedLayer(numChannels)
        attentionLayer(numHeads,Name="attention")
        fullyConnectedLayer(numChannels,Name="fc-out")];
    
    net = addLayers(net,layers);
    
    layers = [
        sequenceInputLayer(1, Name="key-value")
        fullyConnectedLayer(numChannels,Name="fc-key")];
    
    net = addLayers(net,layers);
    net = connectLayers(net,"fc-key","attention/key");
    
    net = addLayers(net, fullyConnectedLayer(numChannels,Name="fc-value"));
    net = connectLayers(net,"key-value","fc-value");
    net = connectLayers(net,"fc-value","attention/value");

    ネットワークをプロットで表示します。

    figure
    plot(net)

    Figure contains an axes object. The axes object contains an object of type graphplot.

    アルゴリズム

    すべて展開する

    参照

    [1] Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. "Attention is all you need." In Advances in Neural Information Processing Systems, Vol. 30. Curran Associates, Inc., 2017. https://papers.nips.cc/paper/7181-attention-is-all-you-need.

    拡張機能

    すべて展開する

    バージョン履歴

    R2024a で導入