Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

アテンションを使用した sequence-to-sequence 変換

この例では、アテンションを使用した再帰型 sequence-to-sequence 符号化器-復号化器モデルを用いて数字の文字列をローマ数字に変換する方法を説明します。

再帰型符号化器-復号化器モデルは、要旨のテキストの要約やニューラル機械翻訳のようなタスクにおける有効性が証明されています。このモデルは、通常、LSTM のような再帰層で入力データを処理する "符号化器" と、第 2 の再帰層で符号化された入力を目的の出力にマッピングする "復号化器" で構成されます。"アテンション メカニズム" を組み込んだモデルでは、変換を生成しながら、符号化された入力の一部に復号化器の焦点を当てることが可能です。

符号化器モデルについて、この例では 1 つの埋め込みとそれに続く 2 つの LSTM 演算で構成されるシンプルなネットワークを使用します。埋め込みとは、categorical トークンを数値ベクトルに変換する手法です。

復号化器モデルについて、この例では 2 つの LSTM から成る符号化器と非常によく似たネットワークを使用します。ただし、重要な違いは、復号化器にはアテンション メカニズムが含まれることです。アテンション メカニズムにより、復号化器が符号化器の出力の特定部分に "注意を払う" ことができます。

学習データの読み込み

"romanNumerals.csv" から数字とローマ数字のペアをダウンロードします。

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    'TextType','string', ...
    'ReadVariableNames',false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

データを学習用とテスト用の区画に 50% ずつ分割します。

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

数字とローマ数字のペアの一部を表示します。

head(dataTrain)
ans=8×2 table
    Source      Target  
    ______    __________

    "228"     "CCXXVIII"
    "267"     "CCLXVII" 
    "294"     "CCXCIV"  
    "179"     "CLXXIX"  
    "396"     "CCCXCVI" 
    "2"       "II"      
    "4"       "IV"      
    "270"     "CCLXX"   

データの前処理

例の最後にリストされている関数 transformText を使用し、テキスト データの前処理を行います。関数 transformText は、テキストを文字に分割して開始と停止のトークンを追加し、変換のために入力テキストの前処理とトークン化を行います。テキストを文字ではなく単語に分割してテキストを変換するには、最初の手順をスキップします。

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain{:,1};
documentsSource = transformText(strSource,startToken,stopToken);

wordEncoding オブジェクトを作成します。このオブジェクトは、ボキャブラリを使用し、トークンと数値インデックスを相互にマッピングします。

encSource = wordEncoding(documentsSource);

単語符号化を使用し、ソースのテキスト データを数値のシーケンスに変換します。

sequencesSource = doc2sequence(encSource, documentsSource,'PaddingDirection','none');

同じ手順でターゲット データをシーケンスに変換します。

strTarget = dataTrain{:,2};
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget, documentsTarget,'PaddingDirection','none');

モデル パラメーターの初期化

モデル パラメーターを初期化します。符号化器と復号化器のそれぞれについて、128 の埋め込み次元、200 個の隠れユニットをもつ 2 つの LSTM 層、および確率 0.05 でランダムにドロップアウトするドロップアウト層を指定します。

embeddingDimension = 128;
numHiddenUnits = 200;
dropout = 0.05;

符号化器モデル パラメーターの初期化

関数 initializeGaussian を使用し、ガウスで符号化埋め込みの重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。平均値を 0、標準偏差を 0.01 に指定します。詳細については、ガウスによる初期化を参照してください。

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

符号化器の LSTM 演算に関する学習可能なパラメーターを次のように初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、Glorot の初期化を参照してください。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、直交初期化を参照してください。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ユニット忘却ゲートによる初期化を参照してください。

最初の符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

2 番目の符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

復号化器モデル パラメーターの初期化

関数 initializeGaussian を使用し、ガウスで符号化埋め込みの重みを初期化します。平均値を 0、標準偏差を 0.01 に指定します。

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

関数 initializeGlorot を使用し、Glorot 初期化子でアテンション メカニズムの重みを初期化します。

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);

復号化器の LSTM 演算に関する学習可能なパラメーターを次のように初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。

最初の復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

2 番目の復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

復号化器の全結合演算に関する学習可能なパラメーターを初期化します。

  • Glorot 初期化子を使用して重みを初期化します。

  • 関数 initializeZeros を使用し、ゼロでバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ゼロでの初期化を参照してください。

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

モデルの関数の定義

この例の最後にリストされている関数 modelEncoder および modelDecoder を作成し、符号化器および復号化器モデルの出力をそれぞれ計算します。

この例の符号化器モデル関数の節にリストされている関数 modelEncoder は、入力データ、モデル パラメーター、学習の正しい出力の判断に使用するオプションのマスクを受け取り、モデルの出力と LSTM の隠れ状態を返します。

この例の復号化器モデル関数の節にリストされている関数 modelDecoder は、入力データ、モデル パラメーター、コンテキスト ベクトル、LSTM の初期隠れ状態、符号化器の出力、およびドロップアウトの確率を受け取り、復号化器の出力、更新されたコンテキスト ベクトル、更新された LSTM 状態、およびアテンション スコアを返します。

モデル勾配関数の定義

この例のモデル勾配関数の節にリストされている関数 modelGradients を作成します。この関数は、符号化器と復号化器のモデル パラメーター、入力データのミニバッチと入力データに対応するパディング マスク、およびドロップアウトの確率を受け取り、学習可能なモデル パラメーターについての損失の勾配と、対応する損失を返します。

学習オプションの指定

ミニバッチ サイズを 32、学習率を 0.002 として 75 エポック学習させます。

miniBatchSize = 32;
numEpochs = 75;
learnRate = 0.002;

Adam のオプションを初期化します。

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

モデルの学習

カスタム学習ループを使用してモデルに学習させます。

シーケンス長の昇順に並べ替えたシーケンスで学習させます。その結果、バッチ内のシーケンスのシーケンス長がほぼ同じになり、小さいシーケンスのバッチを長いシーケンスのバッチよりも先にモデルの更新に確実に使用できます。

シーケンスを長さで並べ替えます。

sequenceLengths = cellfun(@(sequence) size(sequence,2), sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

学習の進行状況プロットを初期化します。

figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])

xlabel("Iteration")
ylabel("Loss")
grid on

関数 adamupdate 用の値を初期化します。

trailingAvg = [];
trailingAvgSq = [];

モデルに学習させます。各ミニバッチで次を行います。

  • シーケンスのミニバッチを読み取り、パディングを追加します。

  • データを dlarray に変換。

  • 損失と勾配を計算。

  • 関数 adamupdate を使用して符号化器および復号化器のモデル パラメーターを更新。

  • 学習の進行状況プロットを更新します。

numObservations = numel(sequencesSource);
numIterationsPerEpoch = floor(numObservations/miniBatchSize);

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
        
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        
        % Read mini-batch of data and pad.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        [X, sequenceLengthsSource] = padSequences(sequencesSource(idx), inputSize);
        [T, sequenceLengthsTarget] = padSequences(sequencesTarget(idx), outputSize);

        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X);
        
        % Compute loss and gradients.
        [gradients, loss] = dlfeval(@modelGradients, parameters, dlX, T, ...
            sequenceLengthsSource, sequenceLengthsTarget, dropout);
        
        % Update parameters using adamupdate.
        [parameters, trailingAvg, trailingAvgSq] = adamupdate(parameters, gradients, trailingAvg, trailingAvgSq, ...
            iteration, learnRate, gradientDecayFactor, squaredGradientDecayFactor);
        
        % Display the training progress.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        addpoints(lineLossTrain,iteration,double(gather(loss)))
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

変換の生成

学習済みのモデルを使用して新しいデータの変換を生成するには、学習時と同じ手順を使用してテキスト データを数字のシーケンスに変換し、そのシーケンスを符号化器-復号化器モデルに入力し、トークン インデックスを使用して結果のシーケンスをテキストに変換し直します。

学習時と同じ手順を使用してテキスト データを前処理します。この例の最後にリストされている関数 transformText を使用し、テキストを文字に分割して開始と停止のトークンを追加します。

strSource = dataTest{:,1};
strTarget = dataTest{:,2};

関数 modelPredictions を使用してテキストを変換します。

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

テスト ソース テキスト、ターゲット テキスト、および変換結果を格納する table を作成します。

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

ランダムに選択した変換結果を表示します。

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source      Target      Translated
    ______    __________    __________

    "936"     "CMXXXVI"     "CMXXXVI" 
    "423"     "CDXXIII"     "CDXXIII" 
    "981"     "CMLXXXI"     "CMLXXXIX"
    "200"     "CC"          "CC"      
    "224"     "CCXXIV"      "CCXXIV"  
    "56"      "LVI"         "DLVI"    
    "330"     "CCCXXX"      "CCCXXX"  
    "336"     "CCCXXXVI"    "CCCXXXVI"
    "524"     "DXXIV"       "DXXIV"   
    "860"     "DCCCLX"      "DCCCLX"  
    "318"     "CCCXVIII"    "CCCXVIII"
    "902"     "CMII"        "CMII"    
    "681"     "DCLXXXI"     "DCLXXXI" 
    "299"     "CCXCIX"      "CCXCIX"  
    "931"     "CMXXXI"      "CMXXXIX" 
    "859"     "DCCCLIX"     "DCCCLIX" 
      ⋮

テキスト変換関数

関数 transformText は、テキストを文字に分割して開始と停止のトークンを追加し、変換のために入力テキストの前処理とトークン化を行います。テキストを文字ではなく単語に分割してテキストを変換するには、最初の手順をスキップします。

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,'CustomTokens',[startToken stopToken]);

end

シーケンス パディング関数

関数 padSequences はシーケンスのミニバッチとパディング値を受け取り、パディングされたシーケンスと対応するパディング マスクを返します。

function [X, sequenceLengths] = padSequences(sequences, paddingValue)

% Initialize mini-batch with padding.
numObservations = size(sequences,1);
sequenceLengths = cellfun(@(x) size(x,2), sequences);
maxLength = max(sequenceLengths);
X = repmat(paddingValue, [1 numObservations maxLength]);

% Insert sequences.
for n = 1:numObservations
    X(1,n,1:sequenceLengths(n)) = sequences{n};
end

end

モデル勾配関数

関数 modelGradients は、符号化器と復号化器のモデル パラメーター、入力データのミニバッチと入力データに対応するパディング マスク、およびドロップアウトの確率を受け取り、学習可能なモデル パラメーターについての損失の勾配と、対応する損失を返します。

function [gradients, loss] = modelGradients(parameters, dlX, T, ...
    sequenceLengthsSource, sequenceLengthsTarget, dropout)

% Forward through encoder.
[dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX, sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
dlY = decoderPredictions(parameters.decoder,dlZ,T,hiddenState,dropout,...
     doTeacherForcing,sequenceLength);

% Masked loss.
dlY = dlY(:,:,1:end-1);
T = T(:,:,2:end);
T = onehotencode(T,1,'ClassNames',1:size(dlY,1));
loss = maskedCrossEntropy(dlY,T,sequenceLengthsTarget-1);

% Update gradients.
gradients = dlgradient(loss, parameters);

% For plotting, return loss normalized by sequence length.
loss = extractdata(loss) ./ sequenceLength;

end

符号化器モデル関数

関数 modelEncoder は、入力データ、モデル パラメーター、学習の正しい出力の判断に使用するオプションのマスクを受け取り、モデル出力と LSTM の隠れ状態を返します。

sequenceLengths が空の場合、この関数が出力をマスクすることはありません。予測に関数 modelEncoder を使用する場合は、sequenceLengths に空の値を指定します。

function [dlZ, hiddenState] = modelEncoder(parametersEncoder, dlX, sequenceLengths)

% Embedding.
weights = parametersEncoder.emb.Weights;
dlZ = embed(dlX,weights,'DataFormat','CBT');

% LSTM 1.
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

dlZ = lstm(dlZ, initialHiddenState, initialCellState, inputWeights, ...
    recurrentWeights, bias, 'DataFormat', 'CBT');

% LSTM 2.
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[dlZ, hiddenState] = lstm(dlZ,initialHiddenState, initialCellState, ...
    inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(dlZ,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = dlZ(:,n,sequenceLengths(n));
    end
end

end

復号化器モデル関数

関数 modelDecoder は、入力データ、モデル パラメーター、コンテキスト ベクトル、LSTM の初期隠れ状態、符号化器の出力、およびドロップアウトの確率を受け取り、復号化器の出力、更新されたコンテキスト ベクトル、更新された LSTM 状態、およびアテンション スコアを返します。

function [dlY, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder, dlX, context, ...
    hiddenState, dlZ, dropout)

% Embedding.
weights = parametersDecoder.emb.Weights;
dlX = embed(dlX, weights,'DataFormat','CBT');

% RNN input.
sequenceLength = size(dlX,3);
dlY = cat(1, dlX, repmat(context, [1 1 sequenceLength]));

% LSTM 1.
inputWeights = parametersDecoder.lstm1.InputWeights;
recurrentWeights = parametersDecoder.lstm1.RecurrentWeights;
bias = parametersDecoder.lstm1.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

dlY = lstm(dlY, hiddenState, initialCellState, inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Dropout.
mask = ( rand(size(dlY), 'like', dlY) > dropout );
dlY = dlY.*mask;

% LSTM 2.
inputWeights = parametersDecoder.lstm2.InputWeights;
recurrentWeights = parametersDecoder.lstm2.RecurrentWeights;
bias = parametersDecoder.lstm2.Bias;

[dlY, hiddenState] = lstm(dlY, hiddenState, initialCellState,inputWeights, recurrentWeights, bias, 'DataFormat', 'CBT');

% Attention.
weights = parametersDecoder.attn.Weights;
[attentionScores, context] = attention(hiddenState, dlZ, weights);

% Concatenate.
dlY = cat(1, dlY, repmat(context, [1 1 sequenceLength]));

% Fully connect.
weights = parametersDecoder.fc.Weights;
bias = parametersDecoder.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CBT');

% Softmax.
dlY = softmax(dlY,'DataFormat','CBT');

end

アテンション関数

関数 attention は、Luong の "一般的な" スコアリングに従ったアテンション スコアと、更新されたコンテキスト ベクトルを返します。各タイム ステップにおけるエネルギーは、隠れ状態および学習可能なアテンションの重みと、符号化器の出力のドット積です。

function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights)

% Initialize attention energies.
[miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3);
attentionEnergies = zeros([sequenceLength miniBatchSize],'like',hiddenState);

% Attention energies.
hWX = hiddenState .* pagemtimes(weights,encoderOutputs);
for tt = 1:sequenceLength
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Attention scores.
attentionScores = softmax(attentionEnergies, 'DataFormat', 'CB');

% Context.
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = pagemtimes(encoderOutputs,attentionScores);
context = squeeze(context);

end

マスクされた交差エントロピー損失

関数 maskedCrossEntropy はシーケンス長の指定されたベクトルを使用して、指定された入力シーケンスとターゲット シーケンスの間の損失を、パディングを含むタイム ステップを無視して計算します。

function loss = maskedCrossEntropy(dlY,T,sequenceLengths)

% Initialize loss.
loss = 0;

% Loop over mini-batch.
miniBatchSize = size(dlY,2);
for n = 1:miniBatchSize
    idx = 1:sequenceLengths(n);
    loss = loss + crossentropy(dlY(:,n,idx), T(:,n,idx),'DataFormat','CBT');
end

% Normalize.
loss = loss / miniBatchSize;

end

復号化器モデル予測関数

関数 decoderModelPredictions は、入力シーケンス、ターゲット シーケンス、隠れ状態、ドロップアウトの確率、教師強制を有効にするためのフラグ、シーケンス長を所与として、予測されたシーケンス dlY を返します。

function dlY = decoderPredictions(parametersDecoder,dlZ,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
dlT = dlarray(T);

% Initialize context.
miniBatchSize = size(dlT,2);
numHiddenUnits = size(dlZ,1);
context = zeros([numHiddenUnits miniBatchSize],'like',dlZ);

if doTeacherForcing
    % Forward through decoder.
    dlY = modelDecoder(parametersDecoder, dlT, context, hiddenState, dlZ, dropout);
else
    % Get first time step for decoder.
    decoderInput = dlT(:,:,1);
    
    % Initialize output.
    numClasses = numel(parametersDecoder.fc.Bias);
    dlY = zeros([numClasses miniBatchSize sequenceLength],'like',decoderInput);
    
    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [dlY(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder, decoderInput, context, ...
            hiddenState, dlZ, dropout);
        
        % Update decoder input.
        [~, decoderInput] = max(dlY(:,:,t),[],1);
    end
end

end

テキスト変換関数

関数 translateText は、ミニバッチを反復することによってテキストの配列を変換します。この関数は、モデル パラメーター、入力 string 配列、最大シーケンス長、ミニバッチ サイズ、ソースおよびターゲットの単語符号化オブジェクト、開始トークンおよび停止トークン、出力を集計するための区切り記号を入力として受け取ります。

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    'PaddingDirection','right', ...
    'PaddingValue',encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
dlX = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);
    
    % Encode using model encoder.
    sequenceLengths = [];
    [dlZ, hiddenState] = modelEncoder(parameters.encoder, dlX(:,idxMiniBatch,:), sequenceLengths);
        
    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    dlY = decoderPredictions(parameters.decoder,dlZ,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(dlY), [], 1);
    
    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;
     
    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))
    
        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));
        
        t = t + 1;
    end
end

end

参考

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | | | | (Text Analytics Toolbox)

関連するトピック