Main Content

アテンションを使用した sequence-to-sequence 変換

この例では、アテンションを使用した再帰型 sequence-to-sequence 符号化器-復号化器モデルを用いて数字の文字列をローマ数字に変換する方法を説明します。

再帰型符号化器-復号化器モデルは、要旨のテキストの要約やニューラル機械翻訳のようなタスクにおける有効性が証明されています。このモデルは、通常、LSTM のような再帰層で入力データを処理する "符号化器" と、第 2 の再帰層で符号化された入力を目的の出力にマッピングする "復号化器" で構成されます。"アテンション メカニズム" を組み込んだモデルでは、変換を生成しながら、符号化された入力の一部に復号化器の焦点を当てることが可能です。

符号化器モデルについて、この例では 1 つの埋め込みとそれに続く 2 つの LSTM 演算で構成されるシンプルなネットワークを使用します。埋め込みとは、categorical トークンを数値ベクトルに変換する手法です。

復号化器モデルについて、この例では 2 つの LSTM から成る符号化器と非常によく似たネットワークを使用します。ただし、重要な違いは、復号化器にはアテンション メカニズムが含まれることです。アテンション メカニズムにより、復号化器が符号化器の出力の特定部分に "注意を払う" ことができます。

学習データの読み込み

"romanNumerals.csv" から数字とローマ数字のペアをダウンロードします。

filename = fullfile("romanNumerals.csv");

options = detectImportOptions(filename, ...
    TextType="string", ...
    ReadVariableNames=false);
options.VariableNames = ["Source" "Target"];
options.VariableTypes = ["string" "string"];

data = readtable(filename,options);

データを学習用とテスト用の区画に 50% ずつ分割します。

idx = randperm(size(data,1),500);
dataTrain = data(idx,:);
dataTest = data;
dataTest(idx,:) = [];

数字とローマ数字のペアの一部を表示します。

head(dataTrain)
ans=8×2 table
    Source     Target  
    ______    _________

    "713"     "DCCXIII"
    "752"     "DCCLII" 
    "434"     "CDXXXIV"
    "641"     "DCXLI"  
    "68"      "LXVIII" 
    "87"      "LXXXVII"
    "469"     "CDLXIX" 
    "715"     "DCCXV"  

データの前処理

例の最後にリストされている関数 transformText を使用し、テキスト データの前処理を行います。関数 transformText は、テキストを文字に分割して開始と停止のトークンを追加し、変換のために入力テキストの前処理とトークン化を行います。テキストを文字ではなく単語に分割してテキストを変換するには、最初の手順をスキップします。

startToken = "<start>";
stopToken = "<stop>";

strSource = dataTrain.Source;
documentsSource = transformText(strSource,startToken,stopToken);

wordEncoding オブジェクトを作成します。このオブジェクトは、ボキャブラリを使用し、トークンと数値インデックスを相互にマッピングします。

encSource = wordEncoding(documentsSource);

単語符号化を使用し、ソースのテキスト データを数値のシーケンスに変換します。

sequencesSource = doc2sequence(encSource,documentsSource,PaddingDirection="none");

同じ手順でターゲット データをシーケンスに変換します。

strTarget = dataTrain.Target;
documentsTarget = transformText(strTarget,startToken,stopToken);
encTarget = wordEncoding(documentsTarget);
sequencesTarget = doc2sequence(encTarget,documentsTarget,PaddingDirection="none");

シーケンスを長さで並べ替えます。シーケンスの長さを増やし、並べ替えたシーケンスを使用して学習させると、バッチ内のシーケンスのシーケンス長がほぼ同じになり、小さいシーケンスのバッチを長いシーケンスのバッチよりも先にモデルの更新に確実に使用できます。

sequenceLengths = cellfun(@(sequence) size(sequence,2),sequencesSource);
[~,idx] = sort(sequenceLengths);
sequencesSource = sequencesSource(idx);
sequencesTarget = sequencesTarget(idx);

ソース データとターゲット データを含む arrayDatastore オブジェクトを作成し、関数 combine を使用してそれらを結合します。

sequencesSourceDs = arrayDatastore(sequencesSource,OutputType="same");
sequencesTargetDs = arrayDatastore(sequencesTarget,OutputType="same");

sequencesDs = combine(sequencesSourceDs,sequencesTargetDs);

モデル パラメーターの初期化

モデル パラメーターを初期化します。符号化器と復号化器のそれぞれについて、128 の埋め込み次元、200 個の隠れユニットをもつ 2 つの LSTM 層、および確率 0.05 でランダムにドロップアウトするドロップアウト層を指定します。

embeddingDimension = 128;
numHiddenUnits = 200;
dropout = 0.05;

符号化器モデル パラメーターの初期化

関数 initializeGaussian を使用し、ガウスで符号化埋め込みの重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。平均値を 0、標準偏差を 0.01 に指定します。詳細については、ガウスによる初期化を参照してください。

inputSize = encSource.NumWords + 1;
sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.encoder.emb.Weights = initializeGaussian(sz,mu,sigma);

符号化器の LSTM 演算に関する学習可能なパラメーターを次のように初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、Glorot の初期化を参照してください。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、直交初期化を参照してください。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ユニット忘却ゲートによる初期化を参照してください。

最初の符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.encoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

2 番目の符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.encoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.encoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.encoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

復号化器モデル パラメーターの初期化

関数 initializeGaussian を使用し、ガウスで符号化埋め込みの重みを初期化します。平均値を 0、標準偏差を 0.01 に指定します。

outputSize = encTarget.NumWords + 1;
sz = [embeddingDimension outputSize];
mu = 0;
sigma = 0.01;
parameters.decoder.emb.Weights = initializeGaussian(sz,mu,sigma);

関数 initializeGlorot を使用し、Glorot 初期化子でアテンション メカニズムの重みを初期化します。

sz = [numHiddenUnits numHiddenUnits];
numOut = numHiddenUnits;
numIn = numHiddenUnits;
parameters.decoder.attn.Weights = initializeGlorot(sz,numOut,numIn);

復号化器の LSTM 演算に関する学習可能なパラメーターを次のように初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。

最初の復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits embeddingDimension+numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension + numHiddenUnits;

parameters.decoder.lstm1.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm1.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm1.Bias = initializeUnitForgetGate(numHiddenUnits);

2 番目の復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

sz = [4*numHiddenUnits numHiddenUnits];
numOut = 4*numHiddenUnits;
numIn = numHiddenUnits;

parameters.decoder.lstm2.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.lstm2.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.decoder.lstm2.Bias = initializeUnitForgetGate(numHiddenUnits);

復号化器の全結合演算に関する学習可能なパラメーターを初期化します。

  • Glorot 初期化子を使用して重みを初期化します。

  • 関数 initializeZeros を使用し、ゼロでバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ゼロでの初期化を参照してください。

sz = [outputSize 2*numHiddenUnits];
numOut = outputSize;
numIn = 2*numHiddenUnits;

parameters.decoder.fc.Weights = initializeGlorot(sz,numOut,numIn);
parameters.decoder.fc.Bias = initializeZeros([outputSize 1]);

モデルの関数の定義

この例の最後にリストされている関数 modelEncoder および modelDecoder を作成し、符号化器および復号化器モデルの出力をそれぞれ計算します。

この例の符号化器モデル関数の節にリストされている関数 modelEncoder は、入力データ、モデル パラメーター、学習の正しい出力の判断に使用するオプションのマスクを受け取り、モデルの出力と LSTM の隠れ状態を返します。

この例の復号化器モデル関数の節にリストされている関数 modelDecoder は、入力データ、モデル パラメーター、コンテキスト ベクトル、LSTM の初期隠れ状態、符号化器の出力、およびドロップアウトの確率を受け取り、復号化器の出力、更新されたコンテキスト ベクトル、更新された LSTM 状態、およびアテンション スコアを返します。

モデル損失関数の定義

この例のモデル損失関数の節にリストされている関数 modelLoss を作成します。この関数は、符号化器と復号化器のモデル パラメーター、入力データのミニバッチと入力データに対応するパディング マスク、およびドロップアウトの確率を受け取り、モデル内の学習可能パラメーターについての損失と損失の勾配を返します。

学習オプションの指定

ミニバッチ サイズを 32、学習率を 0.002 として 75 エポック学習させます。

miniBatchSize = 32;
numEpochs = 75;
learnRate = 0.002;

Adam のオプションを初期化します。

gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

モデルの学習

カスタム学習ループを使用してモデルに学習させます。minibatchqueue を使用し、学習中にイメージのミニバッチを処理および管理します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、ミニバッチ内のすべてのシーケンスの長さを求め、ソース シーケンスとターゲット シーケンスについて、それぞれのシーケンスを最長のシーケンスと同じ長さにパディングします。

  • パディングされたシーケンスの 2 番目と 3 番目の次元を入れ替えます。

  • 基になるデータ型が single である、ミニバッチ変数の書式化されていない dlarray オブジェクトを返します。他のすべての出力は、データ型 single の配列です。

  • GPU が利用できる場合、GPU で学習を行います。利用可能な場合は、GPU 上のすべてのミニバッチ変数を返します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリースごとの GPU サポートを参照してください。

minibatchqueue オブジェクトは、ミニバッチごとに、ソース シーケンス、ターゲット シーケンス、ミニバッチ内のすべてのソース シーケンスの長さ、およびターゲット シーケンスのシーケンス マスクの 4 つの出力引数を返します。

numMiniBatchOutputs = 4;

mbq = minibatchqueue(sequencesDs,numMiniBatchOutputs,...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@(x,t) preprocessMiniBatch(x,t,inputSize,outputSize));

学習の進行状況プロットを初期化します。

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

関数 adamupdate 用の値を初期化します。

trailingAvg = [];
trailingAvgSq = [];

モデルに学習させます。各ミニバッチで次を行います。

  • パディングされたシーケンスのミニバッチの読み取り。

  • 損失と勾配の計算。

  • 関数 adamupdate を使用した符号化器および復号化器のモデル パラメーターの更新。

  • 学習の進行状況プロットの更新。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    reset(mbq);
        
    % Loop over mini-batches.
    while hasdata(mbq)
    
        iteration = iteration + 1;
        
        [X,T,sequenceLengthsSource,maskSequenceTarget] = next(mbq);
        
        % Compute loss and gradients.
        [loss,gradients] = dlfeval(@modelLoss,parameters,X,T,sequenceLengthsSource,...
            maskSequenceTarget,dropout);
        
        % Update parameters using adamupdate.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients,trailingAvg,trailingAvgSq,...
            iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress. Normalize loss by sequence length.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(loss);
        loss = loss ./ size(T,3);
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow
    end
end

変換の生成

学習済みのモデルを使用して新しいデータの変換を生成するには、学習時と同じ手順を使用してテキスト データを数字のシーケンスに変換し、そのシーケンスを符号化器-復号化器モデルに入力し、トークン インデックスを使用して結果のシーケンスをテキストに変換し直します。

学習時と同じ手順を使用してテキスト データを前処理します。この例の最後にリストされている関数 transformText を使用し、テキストを文字に分割して開始と停止のトークンを追加します。

strSource = dataTest.Source;
strTarget = dataTest.Target;

関数 modelPredictions を使用してテキストを変換します。

maxSequenceLength = 10;
delimiter = "";

strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter);

テスト ソース テキスト、ターゲット テキスト、および変換結果を格納する table を作成します。

tbl = table;
tbl.Source = strSource;
tbl.Target = strTarget;
tbl.Translated = strTranslated;

ランダムに選択した変換結果を表示します。

idx = randperm(size(dataTest,1),miniBatchSize);
tbl(idx,:)
ans=32×3 table
    Source       Target        Translated 
    ______    ____________    ____________

    "318"     "CCCXVIII"      "CCCXVIII"  
    "461"     "CDLXI"         "DCLII"     
    "71"      "LXXI"          "DCXXI"     
    "520"     "DXX"           "DCX"       
    "490"     "CDXC"          "CDX"       
    "526"     "DXXVI"         "DCXVI"     
    "898"     "DCCCXCVIII"    "DCCCXCVIII"
    "67"      "LXVII"         "DCLXVII"   
    "4"       "IV"            "CLII"      
    "53"      "LIII"          "DCXIII"    
    "116"     "CXVI"          "CXVI"      
    "362"     "CCCLXII"       "CCCLIII"   
    "709"     "DCCIX"         "DCCII"     
    "291"     "CCXCI"         "CCXCI"     
    "390"     "CCCXC"         "CCCX"      
    "336"     "CCCXXXVI"      "CCCXXVI"   
      ⋮

テキスト変換関数

関数 transformText は、テキストを文字に分割して開始と停止のトークンを追加し、変換のために入力テキストの前処理とトークン化を行います。テキストを文字ではなく単語に分割してテキストを変換するには、最初の手順をスキップします。

function documents = transformText(str,startToken,stopToken)

str = strip(replace(str,""," "));
str = startToken + str + stopToken;
documents = tokenizedDocument(str,CustomTokens=[startToken stopToken]);

end

ミニバッチ前処理関数

例の「モデルの学習」の節で説明されている関数 preprocessMiniBatch は、学習用にデータを前処理します。関数は、次の手順でデータを前処理します。

  1. ミニバッチ内のすべてのソース シーケンスとターゲット シーケンスの長さを判定します。

  2. 関数 padsequences を使用し、ミニバッチ内で最長のシーケンスと同じ長さになるようにシーケンスをパディングします。

  3. シーケンスの最後の 2 つの次元を入れ替えます。

function [X,T,sequenceLengthsSource,maskTarget] = preprocessMiniBatch(sequencesSource,sequencesTarget,inputSize,outputSize)

sequenceLengthsSource = cellfun(@(x) size(x,2),sequencesSource);

X = padsequences(sequencesSource,2,PaddingValue=inputSize);
X = permute(X,[1 3 2]);

[T,maskTarget] = padsequences(sequencesTarget,2,PaddingValue=outputSize);
T = permute(T,[1 3 2]);
maskTarget = permute(maskTarget,[1 3 2]);

end

モデル損失関数

関数 modelLoss は、符号化器と復号化器のモデル パラメーター、入力データのミニバッチと入力データに対応するパディング マスク、およびドロップアウトの確率を受け取り、モデルの学習可能パラメーターについての損失と損失の勾配を返します。

function [loss,gradients] = modelLoss(parameters,X,T,...
    sequenceLengthsSource,maskTarget,dropout)

% Forward through encoder.
[Z,hiddenState] = modelEncoder(parameters.encoder,X,sequenceLengthsSource);

% Decoder Output.
doTeacherForcing = rand < 0.5;
sequenceLength = size(T,3);
Y = decoderPredictions(parameters.decoder,Z,T,hiddenState,dropout,...
     doTeacherForcing,sequenceLength);

% Masked loss.
Y = Y(:,:,1:end-1);
T = extractdata(gather(T(:,:,2:end)));
T = onehotencode(T,1,ClassNames=1:size(Y,1));

maskTarget = maskTarget(:,:,2:end); 
maskTarget = repmat(maskTarget,[size(Y,1),1,1]);

loss = crossentropy(Y,T,Mask=maskTarget,Dataformat="CBT");

% Update gradients.
gradients = dlgradient(loss,parameters);

end

符号化器モデル関数

関数 modelEncoder は、入力データ、モデル パラメーター、学習の正しい出力の判断に使用するオプションのマスクを受け取り、モデル出力と LSTM の隠れ状態を返します。

sequenceLengths が空の場合、この関数が出力をマスクすることはありません。予測に関数 modelEncoder を使用する場合は、sequenceLengths に空の値を指定します。

function [Z,hiddenState] = modelEncoder(parametersEncoder,X,sequenceLengths)

% Embedding.
weights = parametersEncoder.emb.Weights;
Z = embed(X,weights,DataFormat="CBT");

% LSTM 1.
inputWeights = parametersEncoder.lstm1.InputWeights;
recurrentWeights = parametersEncoder.lstm1.RecurrentWeights;
bias = parametersEncoder.lstm1.Bias;

numHiddenUnits = size(recurrentWeights, 2);
initialHiddenState = dlarray(zeros([numHiddenUnits 1]));
initialCellState = dlarray(zeros([numHiddenUnits 1]));

Z = lstm(Z,initialHiddenState,initialCellState,inputWeights, ...
    recurrentWeights,bias,DataFormat="CBT");

% LSTM 2.
inputWeights = parametersEncoder.lstm2.InputWeights;
recurrentWeights = parametersEncoder.lstm2.RecurrentWeights;
bias = parametersEncoder.lstm2.Bias;

[Z, hiddenState] = lstm(Z,initialHiddenState,initialCellState, ...
    inputWeights,recurrentWeights,bias,DataFormat="CBT");

% Masking for training.
if ~isempty(sequenceLengths)
    miniBatchSize = size(Z,2);
    for n = 1:miniBatchSize
        hiddenState(:,n) = Z(:,n,sequenceLengths(n));
    end
end

end

復号化器モデル関数

関数 modelDecoder は、入力データ、モデル パラメーター、コンテキスト ベクトル、LSTM の初期隠れ状態、符号化器の出力、およびドロップアウトの確率を受け取り、復号化器の出力、更新されたコンテキスト ベクトル、更新された LSTM 状態、およびアテンション スコアを返します。

function [Y, context, hiddenState, attentionScores] = modelDecoder(parametersDecoder,X,context, ...
    hiddenState,Z,dropout)

% Embedding.
weights = parametersDecoder.emb.Weights;
X = embed(X,weights,DataFormat="CBT");

% RNN input.
sequenceLength = size(X,3);
Y = cat(1, X, repmat(context, [1 1 sequenceLength]));

% LSTM 1.
inputWeights = parametersDecoder.lstm1.InputWeights;
recurrentWeights = parametersDecoder.lstm1.RecurrentWeights;
bias = parametersDecoder.lstm1.Bias;

initialCellState = dlarray(zeros(size(hiddenState)));

Y = lstm(Y,hiddenState,initialCellState,inputWeights,recurrentWeights,bias,DataFormat="CBT");

% Dropout.
mask = rand(size(Y),"like",Y) > dropout;
Y = Y.*mask;

% LSTM 2.
inputWeights = parametersDecoder.lstm2.InputWeights;
recurrentWeights = parametersDecoder.lstm2.RecurrentWeights;
bias = parametersDecoder.lstm2.Bias;

[Y, hiddenState] = lstm(Y,hiddenState,initialCellState,inputWeights,recurrentWeights,bias,DataFormat="CBT");

% Attention.
weights = parametersDecoder.attn.Weights;
[attentionScores, context] = attention(hiddenState, Z, weights);

% Concatenate.
Y = cat(1, Y, repmat(context,[1 1 sequenceLength]));

% Fully connect.
weights = parametersDecoder.fc.Weights;
bias = parametersDecoder.fc.Bias;
Y = fullyconnect(Y,weights,bias,DataFormat="CBT");

% Softmax.
Y = softmax(Y,DataFormat="CBT");

end

アテンション関数

関数 attention は、Luong の "一般的な" スコアリングに従ったアテンション スコアと、更新されたコンテキスト ベクトルを返します。各タイム ステップにおけるエネルギーは、隠れ状態および学習可能なアテンションの重みと、符号化器の出力のドット積です。

function [attentionScores, context] = attention(hiddenState, encoderOutputs, weights)

% Initialize attention energies.
[miniBatchSize, sequenceLength] = size(encoderOutputs, 2:3);
attentionEnergies = zeros([sequenceLength miniBatchSize],"like",hiddenState);

% Attention energies.
hWX = hiddenState .* pagemtimes(weights,encoderOutputs);
for tt = 1:sequenceLength
    attentionEnergies(tt, :) = sum(hWX(:, :, tt), 1);
end

% Attention scores.
attentionScores = softmax(attentionEnergies,DataFormat="CB");

% Context.
encoderOutputs = permute(encoderOutputs, [1 3 2]);
attentionScores = permute(attentionScores,[1 3 2]);
context = pagemtimes(encoderOutputs,attentionScores);
context = squeeze(context);

end

復号化器モデル予測関数

関数 decoderModelPredictions は、入力シーケンス、ターゲット シーケンス、隠れ状態、ドロップアウトの確率、教師強制を有効にするためのフラグ、シーケンス長を所与として、予測されたシーケンス Y を返します。

function Y = decoderPredictions(parametersDecoder,Z,T,hiddenState,dropout, ...
    doTeacherForcing,sequenceLength)

% Convert to dlarray.
T = dlarray(T);

% Initialize context.
miniBatchSize = size(T,2);
numHiddenUnits = size(Z,1);
context = zeros([numHiddenUnits miniBatchSize],"like",Z);

if doTeacherForcing
    % Forward through decoder.
    Y = modelDecoder(parametersDecoder,T,context,hiddenState,Z,dropout);
else
    % Get first time step for decoder.
    decoderInput = T(:,:,1);
    
    % Initialize output.
    numClasses = numel(parametersDecoder.fc.Bias);
    Y = zeros([numClasses miniBatchSize sequenceLength],"like",decoderInput);
    
    % Loop over time steps.
    for t = 1:sequenceLength
        % Forward through decoder.
        [Y(:,:,t), context, hiddenState] = modelDecoder(parametersDecoder,decoderInput,context, ...
            hiddenState,Z,dropout);
        
        % Update decoder input.
        [~, decoderInput] = max(Y(:,:,t),[],1);
    end
end

end

テキスト変換関数

関数 translateText は、ミニバッチを反復することによってテキストの配列を変換します。この関数は、モデル パラメーター、入力 string 配列、最大シーケンス長、ミニバッチ サイズ、ソースおよびターゲットの単語符号化オブジェクト、開始トークンおよび停止トークン、出力を集計するための区切り記号を入力として受け取ります。

function strTranslated = translateText(parameters,strSource,maxSequenceLength,miniBatchSize, ...
    encSource,encTarget,startToken,stopToken,delimiter)

% Transform text.
documentsSource = transformText(strSource,startToken,stopToken);
sequencesSource = doc2sequence(encSource,documentsSource, ...
    PaddingDirection="right", ...
    PaddingValue=encSource.NumWords + 1);

% Convert to dlarray.
X = cat(3,sequencesSource{:});
X = permute(X,[1 3 2]);
X = dlarray(X);

% Initialize output.
numObservations = numel(strSource);
strTranslated = strings(numObservations,1);

% Loop over mini-batches.
numIterations = ceil(numObservations / miniBatchSize);
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);
    
    % Encode using model encoder.
    sequenceLengths = [];
    [Z, hiddenState] = modelEncoder(parameters.encoder,X(:,idxMiniBatch,:),sequenceLengths);
        
    % Decoder predictions.
    doTeacherForcing = false;
    dropout = 0;
    decoderInput = repmat(word2ind(encTarget,startToken),[1 miniBatchSize]);
    decoderInput = dlarray(decoderInput);
    Y = decoderPredictions(parameters.decoder,Z,decoderInput,hiddenState,dropout, ...
        doTeacherForcing,maxSequenceLength);
    [~, idxPred] = max(extractdata(Y),[],1);
    
    % Keep translating flag.
    idxStop = word2ind(encTarget,stopToken);
    keepTranslating = idxPred ~= idxStop;
     
    % Loop over time steps.
    t = 1;
    while t <= maxSequenceLength && any(keepTranslating(:,:,t))
    
        % Update output.
        newWords = ind2word(encTarget, idxPred(:,:,t))';
        idxUpdate = idxMiniBatch(keepTranslating(:,:,t));
        strTranslated(idxUpdate) = strTranslated(idxUpdate) + delimiter + newWords(keepTranslating(:,:,t));
        
        t = t + 1;
    end
end

end

参考

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | | | | (Text Analytics Toolbox)

関連するトピック