Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

アテンションを使用した sequence-to-sequence 変換

この例では、アテンションを使用した再帰型 sequence-to-sequence 符号化器-復号化器モデルを用いて数字の文字列をローマ数字に変換する方法を説明します。

再帰型符号化器-復号化器モデルは、要旨のテキストの要約やニューラル機械翻訳のようなタスクにおける有効性が証明されています。このモデルは、通常、LSTM のような再帰層で入力データを処理する "符号化器" と、第 2 の再帰層で符号化された入力を目的の出力にマッピングする "復号化器" で構成されます。"アテンション メカニズム" を組み込んだモデルでは、変換を生成しながら、符号化された入力の一部に復号化器の焦点を当てることが可能です。

符号化器モデルについて、この例では 1 つの埋め込みとそれに続く 1 つの LSTM 演算で構成されるシンプルなネットワークを使用します。埋め込みとは、categorical トークンを数値ベクトルに変換する手法です。

復号化器モデルについて、この例では 1 つの LSTM 演算と 1 つのアテンション メカニズムを含むネットワークを使用します。アテンション メカニズムにより、復号化器が符号化器の出力の特定部分に "注意を払う" ことができます。

学習データの読み込み

"romanNumerals.csv" から数字とローマ数字のペアをダウンロードします。

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    TextType="string", ...
    ReadVariableNames=false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

データを学習用とテスト用の区画に 50% ずつ分割します。

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

数字とローマ数字のペアの一部を表示します。

head(dataTrain)
    Source       Target   
    ______    ____________

    "437"     "CDXXXVII"  
    "431"     "CDXXXI"    
    "102"     "CII"       
    "862"     "DCCCLXII"  
    "738"     "DCCXXXVIII"
    "527"     "DXXVII"    
    "401"     "CDI"       
    "184"     "CLXXXIV"   

データの前処理

例の最後にリストされている関数 transformText を使用し、テキスト データの前処理を行います。関数 transformText は、テキストを文字に分割して開始と停止のトークンを追加し、変換のために入力テキストの前処理とトークン化を行います。テキストを文字ではなく単語に分割してテキストを変換するには、最初の手順をスキップします。

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain.Source;
documentsSource = transformText(strSource,startToken,stopToken);

wordEncoding オブジェクトを作成します。このオブジェクトは、ボキャブラリを使用し、トークンと数値インデックスを相互にマッピングします。

encSource = wordEncoding(documentsSource);

単語符号化を使用し、ソースのテキスト データを数値のシーケンスに変換します。

sequencesSource = doc2sequence(encSource,documentsSource,PaddingDirection="none");

同じ手順でターゲット データをシーケンスに変換します。

strTarget = dataTrain.Target;
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget,documentsTarget,PaddingDirection="none");

シーケンスを長さで並べ替えます。シーケンスの長さを増やし、並べ替えたシーケンスを使用して学習させると、バッチ内のシーケンスのシーケンス長がほぼ同じになり、小さいシーケンスのバッチを長いシーケンスのバッチよりも先にモデルの更新に確実に使用できます。

sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

ソース データとターゲット データを含む arrayDatastore オブジェクトを作成し、関数 combine を使用してそれらを結合します。

sequencesSourceDs = arrayDatastore(sequencesSource,OutputType="same");
sequencesTargetDs = arrayDatastore(sequencesTarget,OutputType="same");

sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);

モデル パラメーターの初期化

モデル パラメーターを初期化します。符号化器と復号化器のそれぞれについて、128 の埋め込み次元、100 個の隠れユニットを持つ LSTM 層、および確率 0.05 でランダムにドロップアウトするドロップアウト層を指定します。

embeddingDimension = 128;
numHiddenUnits = 100;
dropout = 0.05;

符号化器モデル パラメーターの初期化

関数 initializeGaussian を使用し、ガウスで符号化埋め込みの重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。平均値を 0、標準偏差を 0.01 に指定します。詳細については、ガウスによる初期化を参照してください。

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、Glorot の初期化を参照してください。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、直交初期化を参照してください。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ユニット忘却ゲートによる初期化を参照してください。

符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm.Bias = initializeUnitForgetGate(numHiddenUnits);

復号化器モデル パラメーターの初期化

関数 initializeGaussian を使用し、ガウスで符号化埋め込みの重みを初期化します。平均値を 0、標準偏差を 0.01 に指定します。

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

関数 initializeGlorot を使用し、Glorot 初期化子でアテンション メカニズムの重みを初期化します。

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attention.Weights = initializeGlorot(sz,numOut,numIn);

復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。

復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm.Bias = initializeUnitForgetGate(numHiddenUnits);

復号化器の全結合演算に関する学習可能なパラメーターを初期化します。

  • Glorot 初期化子を使用して重みを初期化します。

  • 関数 initializeZeros を使用し、ゼロでバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ゼロでの初期化を参照してください。

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

モデルの関数の定義

この例の最後にリストされている関数 modelEncoder および modelDecoder を作成し、符号化器および復号化器モデルの出力をそれぞれ計算します。

この例の符号化器モデル関数の節にリストされている関数 modelEncoder は、入力データ、モデル パラメーター、学習の正しい出力の判断に使用するオプションのマスクを受け取り、モデルの出力と LSTM の隠れ状態を返します。

この例の復号化器モデル関数の節にリストされている関数 modelDecoder は、入力データ、モデル パラメーター、コンテキスト ベクトル、LSTM の初期隠れ状態、符号化器の出力、およびドロップアウトの確率を受け取り、復号化器の出力、更新されたコンテキスト ベクトル、更新された LSTM 状態、およびアテンション スコアを返します。

モデル損失関数の定義

この例のモデル損失関数のセクションにリストされている関数 modelLoss を作成します。この関数は、符号化器と復号化器のモデル パラメーター、入力データのミニバッチと入力データに対応するパディング マスク、およびドロップアウトの確率を受け取り、モデル内の学習可能なパラメーターについての損失とその損失の勾配を返します。

学習オプションの指定

ミニバッチ サイズを 32、学習率を 0.001 として 100 エポック学習させます。

miniBatchSize = 32;
numEpochs = 100;
learnRate = 0.001;

Adam のオプションを初期化します。

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

モデルの学習

カスタム学習ループを使用してモデルに学習させます。minibatchqueue を使用し、学習中にイメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、ミニバッチ内のすべてのシーケンスの長さを求め、ソース シーケンスとターゲット シーケンスについて、それぞれのシーケンスを最長のシーケンスと同じ長さにパディングします。

  • パディングされたシーケンスの 2 番目と 3 番目の次元を入れ替えます。

  • 基になるデータ型が single である、ミニバッチ変数の書式化されていない dlarray オブジェクトを返します。他のすべての出力は、データ型 single の配列です。

  • GPU が利用できる場合、GPU で学習を行います。利用可能な場合は、GPU 上のすべてのミニバッチ変数を返します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

minibatchqueue オブジェクトは、ミニバッチごとに、ソース シーケンス、ターゲット シーケンス、ミニバッチ内のすべてのソース シーケンスの長さ、およびターゲット シーケンスのシーケンス マスクの 4 つの出力引数を返します。

numMiniBatchOutputs = 4;

mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));

関数 adamupdate 用の値を初期化します。

trailingAvg = [];
trailingAvgSq = [];

学習の進行状況モニター用に合計反復回数を計算します。

numObservationsTrain = numel(sequencesSource);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;

学習の進行状況モニターを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics="Loss", ...
    Info="Epoch", ...
    XLabel="Iteration");

モデルに学習させます。各ミニバッチで次を行います。

  • パディングされたシーケンスのミニバッチを読み取ります。

  • 損失と勾配を計算します。

  • 関数 adamupdate を使用して符号化器および復号化器のモデル パラメーターを更新します。

  • 学習の進行状況モニターを更新します。

  • 学習の進行状況モニターの Stop プロパティが true の場合、学習を停止します。停止ボタンをクリックすると、学習モニターの Stop プロパティが 1 に変わります。

epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    reset(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        [X,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq);

        % Compute loss and gradients.
        [loss,gradients] = dlfeval(@modelLoss,parameters,X,T,sequenceLengthsSource,...
            maskSequenceTarget,dropout);

        % Update parameters using adamupdate.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);

        % Normalize loss by sequence length.
        loss = loss ./ size(T,3);

        % Update the training progress monitor. 
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100*iteration/numIterations;
    end
end

変換の生成

学習済みのモデルを使用して新しいデータの変換を生成するには、学習時と同じ手順を使用してテキスト データを数字のシーケンスに変換し、そのシーケンスを符号化器-復号化器モデルに入力し、トークン インデックスを使用して結果のシーケンスをテキストに変換し直します。

学習時と同じ手順を使用してテキスト データを前処理します。この例の最後にリストされている関数 transformText を使用し、テキストを文字に分割して開始と停止のトークンを追加します。

strSource = dataTest.Source;
strTarget = dataTest.Target;

関数 modelPredictions を使用してテキストを変換します。

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

テスト ソース テキスト、ターゲット テキスト、および変換結果を格納する table を作成します。

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

ランダムに選択した変換結果を表示します。

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source      Target        Translated 
    ______    ___________    ____________

    "996"     "CMXCVI"       "CMMXCVI"   
    "576"     "DLXXVI"       "DCLXXVI"   
    "86"      "LXXXVI"       "DCCCLXV"   
    "23"      "XXIII"        "CCCCXIII"  
    "99"      "XCIX"         "CMMXIX"    
    "478"     "CDLXXVIII"    "DCCCLXXVII"
    "313"     "CCCXIII"      "CCCXIII"   
    "60"      "LX"           "DLX"       
    "864"     "DCCCLXIV"     "DCCCLIV"   
    "280"     "CCLXXX"       "CCCCLX"    
    "792"     "DCCXCII"      "DCCCIII"   
    "959"     "CMLIX"        "CMLXI"     
    "283"     "CCLXXXIII"    "CCCCLXXIII"
    "356"     "CCCLVI"       "CCCCVI"    
    "534"     "DXXXIV"       "DCCXXIV"   
    "721"     "DCCXXI"       "DCCCII"    
      ⋮

テキスト変換関数

関数 transformText は、テキストを文字に分割して開始と停止のトークンを追加し、変換のために入力テキストの前処理とトークン化を行います。テキストを文字ではなく単語に分割してテキストを変換するには、最初の手順をスキップします。

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,CustomTokens=[startToken stopToken]);

end

ミニバッチ前処理関数

例の「モデルの学習」の節で説明されている関数 preprocessMiniBatch は、学習用にデータを前処理します。関数は、次の手順でデータを前処理します。

  1. ミニバッチ内のすべてのソース シーケンスとターゲット シーケンスの長さを判定します。

  2. 関数 padsequences を使用し、ミニバッチ内で最長のシーケンスと同じ長さになるようにシーケンスをパディングします。

  3. シーケンスの最後の 2 つの次元を入れ替えます。

function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize)

sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource);

X = padsequences(sequencesSource,2,PaddingValue=inputSize);
X = permute(X,[1 3 2]);

[T,maskTarget] = padsequences(sequencesTarget,2,PaddingValue=outputSize);
T = permute(T,[1 3 2]);
maskTarget = permute(maskTarget,[1 3 2]);

end

モデル損失関数

関数 modelLoss は、符号化器と復号化器のモデル パラメーター、入力データのミニバッチと入力データに対応するパディング マスク、およびドロップアウトの確率を受け取り、モデルの学習可能パラメーターについての損失と損失の勾配を返します。

function [loss,gradients] = modelLoss(parameters,X,T,...
    sequenceLengthsSource,maskTarget,dropout)

% Forward through encoder.
[Z,hiddenState] = modelEncoder(parameters.encoder,X,sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
Y = decoderPredictions(parameters.decoder,Z,T,hiddenState,dropout,...
    doTeacherForcing,sequenceLength);

% Masked loss.
Y = Y(:,:,1:end-1);
T = extractdata(gather(T(:,:,2:end)));
T = onehotencode(T,1,ClassNames=1:size(Y,1));

maskTarget = maskTarget(:,:,2:end);
maskTarget = repmat(maskTarget,[size(Y,1),1,1]);

loss = crossentropy(Y,T,Mask=maskTarget,Dataformat="CBT");

% Update gradients.
gradients = dlgradient(loss,parameters);

end

符号化器モデル関数

関数 modelEncoder は、入力データ、モデル パラメーター、学習の正しい出力の判断に使用するオプションのマスクを受け取り、モデル出力と LSTM の隠れ状態を返します。

sequenceLengths が空の場合、この関数が出力をマスクすることはありません。予測に関数 modelEncoder を使用する場合は、sequenceLengths に空の値を指定します。

function [Z,hiddenState] = modelEncoder(parameters,X,sequenceLengths)

% Embedding.
weights = parameters.emb.Weights;
Z = embed(X,weights,DataFormat="CBT");

% LSTM.
inputWeights = parameters.lstm.InputWeights;
recurrentWeights = parameters.lstm.RecurrentWeights;
bias = parameters.lstm.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

[Z,hiddenState] = lstm(Z,initialHiddenState,initialCellState,inputWeights, ...
    recurrentWeights,bias,DataFormat="CBT");

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(Z,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = Z(:,n,sequenceLengths(n));
    end
end

end

復号化器モデル関数

関数 modelDecoder は、入力データ、モデル パラメーター、コンテキスト ベクトル、LSTM の初期隠れ状態、符号化器の出力、およびドロップアウトの確率を受け取り、復号化器の出力、更新されたコンテキスト ベクトル、更新された LSTM 状態、およびアテンション スコアを返します。

function [Y,context,hiddenState,attentionScores] = modelDecoder(parameters,X,context, ...
    hiddenState,Z,dropout)

% Embedding.
weights = parameters.emb.Weights;
X = embed(X,weights,DataFormat="CBT");

% RNN input.
sequenceLength = size(X,3);
Y = cat(1, X, repmat(context,[1 1 sequenceLength]));

% LSTM.
inputWeights = parameters.lstm.InputWeights;
recurrentWeights = parameters.lstm.RecurrentWeights;
bias = parameters.lstm.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

[Y,hiddenState] = lstm(Y,hiddenState,initialCellState, ...
    inputWeights,recurrentWeights,bias,DataFormat="CBT");

% Dropout.
mask = rand(size(Y),"like",Y) > dropout;
Y = Y.*mask;

% Attention.
weights = parameters.attention.Weights;
[context,attentionScores] = luongAttention(hiddenState,Z,weights);

% Concatenate.
Y = cat(1, Y, repmat(context,[1 1 sequenceLength]));

% Fully connect.
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
Y = fullyconnect(Y,weights,bias,DataFormat="CBT");

% Softmax.
Y = softmax(Y,DataFormat="CBT");

end

Luong アテンション関数

関数 luongAttention は、Luong の "一般的な" スコアリング [1] に従って、コンテキスト ベクトルとアテンション スコアを返します。これは、ドット積アテンションでクエリ、キー、および値に隠れ状態、重み付き潜在表現、および潜在表現をそれぞれ指定することと等価です。

function [context,attentionScores] = luongAttention(hiddenState,Z,weights)

numHeads = 1;
queries = hiddenState;
keys = pagemtimes(weights,Z);
values = Z;

[context,attentionScores] = attention(queries,keys,values,numHeads, ...
    Scale=1, ...
    DataFormat="CBT");

end

復号化器モデル予測関数

関数 decoderModelPredictions は、入力シーケンス、ターゲット シーケンス、隠れ状態、ドロップアウトの確率、教師強制を有効にするためのフラグ、シーケンス長を所与として、予測されたシーケンス Y を返します。

function Y = decoderPredictions(parameters,Z,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
T = dlarray(T);

% Initialize context.
miniBatchSize = size(T,2);
numHiddenUnits = size(Z,1);
context = zeros([numHiddenUnits miniBatchSize],"like",Z);

if doTeacherForcing
    % Forward through decoder.
    Y = modelDecoder(parameters,T,context,hiddenState,Z,dropout);
else
    % Get first time step for decoder.
    decoderInput = T(:,:,1);

    % Initialize output.
    numClasses = numel(parameters.fc.Bias);
    Y = zeros([numClasses miniBatchSize sequenceLength],"like",decoderInput);

    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [Y(:,:,t), context, hiddenState] = modelDecoder(parameters,decoderInput,context, ...
            hiddenState,Z,dropout);

        % Update decoder input.
        [~, decoderInput] = max(Y(:,:,t),[],1);
    end
end

end

テキスト変換関数

関数 translateText は、ミニバッチを反復することによってテキストの配列を変換します。この関数は、モデル パラメーター、入力 string 配列、最大シーケンス長、ミニバッチ サイズ、ソースおよびターゲットの単語符号化オブジェクト、開始トークンおよび停止トークン、出力を集計するための区切り記号を入力として受け取ります。

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    PaddingDirection="right", ...
    PaddingValue=encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
X = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);

    % Encode using model encoder.
    sequenceLengths = [];
    [Z, hiddenState] = modelEncoder(parameters.encoder,X(:,idxMiniBatch,:),sequenceLengths);

    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    Y = decoderPredictions(parameters.decoder,Z,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(Y),[],1);

    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;

    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))

        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));

        t = t + 1;
    end
end

end

参考文献

[1] Luong, Minh-Thang, Hieu Pham, and Christopher D. Manning. "Effective approaches to attention-based neural machine translation." arXiv preprint arXiv:1508.04025 (2015).

参考

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | | | | (Text Analytics Toolbox)

関連するトピック