Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

自己符号化器を使用したテキストの生成

この例では、自己符号化器を使用してテキスト データを生成する方法を示します。

自己符号化器は、入力を複製するように学習させられた深層学習ネットワークの一種です。自己符号化器は、小規模な 2 つのネットワーク (符号化器および復号化器) で構成されます。符号化器は、入力データを何らかの潜在空間における特徴ベクトルにマッピングします。復号化器は、この潜在空間におけるベクトルを使用してデータを再構成します。

学習プロセスは教師なし学習となります。つまり、このモデルはラベル付けされたデータを必要としません。テキストを生成するために、復号化器を使用し、任意の入力からテキストを再構成することができます。

この例では、テキストを生成するように自己符号化器の学習を行います。符号化器は、単語埋め込みと LSTM 演算を使用し、入力テキストを潜在ベクトルにマッピングします。復号化器は、LSTM 演算および同じ埋め込みを使用し、潜在ベクトルからテキストを再構成します。

データの読み込み

ファイル sonnets.txt には、シェイクスピアのソネット全集が 1 つのテキスト ファイルとして格納されています。

ファイル "sonnets.txt" からシェイクスピアのソネットのデータを読み取ります。

filename = "sonnets.txt";
textData = fileread(filename);

ソネットは、2 つの空白文字でインデントされています。replace を使用してインデントを削除し、関数 split を使用してテキストを個別の行に分割します。最初の 9 つの要素と短いソネット タイトルからヘッダーを削除します。

textData = replace(textData,"  ","");
textData = split(textData,newline);
textData(1:9) = [];
textData(strlength(textData)<5) = [];

データの準備

テキスト データをトークン化および前処理する関数を作成します。例の最後にリストされている関数 preprocessText は以下のステップを実行します。

  1. 各入力ストリングの前後に、指定された開始トークンと停止トークンをそれぞれ付加する。

  2. tokenizedDocument を使用してテキストをトークン化する。

テキスト データを前処理し、開始トークン "<start>" と停止トークン "<stop>" をそれぞれ指定します。

startToken = "<start>";
stopToken = "<stop>";
documents = preprocessText(textData,startToken,stopToken);

トークン化されたドキュメントから単語符号化オブジェクトを作成します。

enc = wordEncoding(documents);

深層学習モデルの学習を行う際には、入力データは固定長のシーケンスを含む数値配列でなければなりません。ドキュメントの長さは異なるため、短いシーケンスはパディング値でパディングしなければなりません。

パディング トークンを含み、そのトークンのインデックスも決定するように、単語符号化を再作成します。

paddingToken = "<pad>";
newVocabulary = [enc.Vocabulary paddingToken];
enc = wordEncoding(newVocabulary);
paddingIdx = word2ind(enc,paddingToken)
paddingIdx = 3595

モデル パラメーターの初期化

次のモデルのパラメーターを初期化します。

ここで、T はシーケンス長、x1,,xT は単語インデックスの入力シーケンス、y1,,yT は再構成されたシーケンスです。

符号化器は、単語インデックスのシーケンスを潜在ベクトルにマッピングします。このマッピングは、埋め込みによって入力を単語ベクトルのシーケンスに変換し、単語ベクトル シーケンスを LSTM 演算に入力し、全結合演算を LSTM 出力の最後のタイム ステップに適用することによって行われます。復号化器は、符号化器の出力を初期化した LSTM を使用して入力を再構成します。タイム ステップごとに、復号化器は次のタイム ステップを予測し、次のタイム ステップ予測の出力を使用します。符号化器と復号化器の両方で同じ埋め込みを使用します。

パラメーターの次元を指定します。

embeddingDimension = 100;
numHiddenUnits = 150;
latentDimension = 75;
vocabularySize = enc.NumWords;

パラメーターの struct を作成します。

parameters = struct;

関数 initializeGaussian を使用し、ガウスで埋め込みの重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。平均値を 0、標準偏差を 0.01 に指定します。詳細については、ガウスによる初期化を参照してください。

sz = [embeddingDimension vocabularySize];
mu = 0;
sigma = 0.01;
parameters.emb.Weights = initializeGaussian(sz,mu,sigma);

符号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

  • 関数 initializeGlorot を使用し、Glorot 初期化子で入力の重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、Glorot の初期化を参照してください。

  • 関数 initializeOrthogonal を使用し、直交初期化子で再帰重みを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、直交初期化を参照してください。

  • 関数 initializeUnitForgetGate を使用し、ユニット忘却ゲート初期化子でバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ユニット忘却ゲートによる初期化を参照してください。

sz = [4*numHiddenUnits embeddingDimension];
numOut = 4*numHiddenUnits;
numIn = embeddingDimension;

parameters.lstmEncoder.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.lstmEncoder.RecurrentWeights = initializeOrthogonal([4*numHiddenUnits numHiddenUnits]);
parameters.lstmEncoder.Bias = initializeUnitForgetGate(numHiddenUnits);

符号化器の全結合演算に関する学習可能なパラメーターを初期化します。

  • Glorot 初期化子を使用して重みを初期化します。

  • 関数 initializeZeros を使用し、ゼロでバイアスを初期化します。この関数は、この例にサポート ファイルとして添付されています。詳細については、ゼロでの初期化を参照してください。

sz = [latentDimension numHiddenUnits];
numOut = latentDimension;
numIn = numHiddenUnits;

parameters.fcEncoder.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fcEncoder.Bias = initializeZeros([latentDimension 1]);

復号化器の LSTM 演算に関する学習可能なパラメーターを初期化します。

  • Glorot 初期化子を使用して入力の重みを初期化します。

  • 直交初期化子を使用して再帰重みを初期化します。

  • ユニット忘却ゲート初期化子でバイアスを初期化します。

sz = [4*latentDimension embeddingDimension];
numOut = 4*latentDimension;
numIn = embeddingDimension;

parameters.lstmDecoder.InputWeights = initializeGlorot(sz,numOut,numIn);
parameters.lstmDecoder.RecurrentWeights = initializeOrthogonal([4*latentDimension latentDimension]);
parameters.lstmDecoder.Bias = initializeZeros([4*latentDimension 1]);

復号化器の全結合演算に関する学習可能なパラメーターを初期化します。

  • Glorot 初期化子を使用して重みを初期化します。

  • ゼロでバイアスを初期化します。

sz = [vocabularySize latentDimension];
numOut = vocabularySize;
numIn = latentDimension;

parameters.fcDecoder.Weights = initializeGlorot(sz,numOut,numIn);
parameters.fcDecoder.Bias = initializeZeros([vocabularySize 1]);

重みの初期化の詳細については、モデル関数の学習可能パラメーターの初期化を参照してください。

モデル符号化器関数の定義

この例の符号化器モデル関数の節にリストされている関数 modelEncoder を作成し、符号化器モデルの出力を計算します。関数 modelEncoder は、モデル パラメーターおよびシーケンス長を単語インデックスの入力シーケンスとして受け取り、対応する潜在特徴ベクトルを返します。モデル符号化器関数の定義の詳細については、テキスト符号化器モデル関数の定義を参照してください。

モデル復号化器関数の定義

この例の復号化器モデル関数の節にリストされている関数 modelDecoder を作成し、復号化器モデルの出力を計算します。関数 modelDecoder は、モデル パラメーターおよびシーケンス長を単語インデックスの入力シーケンスとして受け取り、対応する潜在特徴ベクトルを返します。モデル復号化器関数の定義の詳細については、テキスト復号化器モデル関数の定義を参照してください。

モデル損失関数の定義

この例のモデル損失関数の節にリストされている関数 modelLoss は、モデルの学習可能なパラメーター、入力データ、マスク用のシーケンス長のベクトルを入力として受け取り、学習可能なパラメーターについての損失、およびその損失の勾配を返します。モデル損失関数の定義の詳細については、カスタム学習ループのモデル損失関数の定義を参照してください。

学習オプションの指定

学習用のオプションを指定します。

ミニバッチ サイズを 128 として 100 エポック学習させます。

miniBatchSize = 128;
numEpochs = 100;

学習率を 0.01 にして学習を行います。

learnRate = 0.01;

ネットワークの学習

カスタム学習ループを使用してネットワークに学習させます。

Adam オプティマイザーのパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

学習の進行状況プロットを初期化します。対応する反復に対する損失をプロットする、アニメーションの線を作成します。

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));
xlabel("Iteration")
ylabel("Loss")
ylim([0 inf])
grid on

モデルに学習させます。最初のエポックについて、データをシャッフルしてデータのミニバッチをループ処理します。

各ミニバッチで次を行います。

  • テキスト データを単語インデックスのシーケンスに変換します。

  • データを dlarray に変換します。

  • GPU での学習用に、データを gpuArray オブジェクトに変換します。

  • 損失と勾配を計算します。

  • 関数 adamupdate を使用して学習可能なパラメーターを更新します。

  • 学習の進行状況プロットを更新します。

GPU が利用できる場合、GPU で学習を行います。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

学習を行うのに時間がかかる場合があります。

numObservations = numel(documents);
numIterationsPerEpoch = floor(numObservations / miniBatchSize);

iteration = 0;
start = tic;

for epoch = 1:numEpochs

    % Shuffle.
    idx = randperm(numObservations);
    documents = documents(idx);

    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Read mini-batch.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        documentsBatch = documents(idx);

        % Convert to sequences.
        X = doc2sequence(enc,documentsBatch, ...
            PaddingDirection="right", ...
            PaddingValue=paddingIdx);

        X = cat(1,X{:});

        % Convert to dlarray.
        X = dlarray(X,"BTC");

        % If training on a GPU, then convert data to gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end

        % Calculate sequence lengths.
        sequenceLengths = doclength(documentsBatch);

        % Evaluate model loss and gradients.
        [loss,gradients] = dlfeval(@modelLoss, parameters, X, sequenceLengths);

        % Update learnable parameters.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration,learnRate);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(loss);
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))

        drawnow
    end
end

テキストの生成

ランダムな異なる状態で復号化器を初期化することにより、閉ループ生成を使用してテキストを生成します。閉ループ生成では、モデルが一度に 1 タイム ステップずつデータを生成し、前の予測を次の予測の入力として使用します。

長さ 16 のシーケンスを 3 つ生成するように指定します。

numGenerations = 3;
sequenceLength = 16;

乱数値の配列を作成して復号化器の状態を初期化します。

Z = dlarray(randn(latentDimension,numGenerations),"CB");

GPU で予測する場合、データを gpuArray に変換します。

if canUseGPU
    Z = gpuArray(Z);
end

例の最後にリストされている関数 modelPredictions を使用して予測を行います。関数 modelPredictions は、モデル パラメーター、復号化器の初期状態、最大シーケンス長、単語符号化、開始トークン、ミニバッチ サイズを所与として、復号化器の出力スコアを返します。

Y = modelDecoderPredictions(parameters,Z,sequenceLength,enc,startToken,miniBatchSize);

最高のスコアをもつ単語インデックスを見つけます。

[~,idx] = max(Y,[],1);
idx = squeeze(idx);

数値インデックスを単語に変換し、関数 join を使用してそれらを連結します。

strGenerated = join(enc.Vocabulary(idx));

関数 extractBefore を使用し、最初の停止トークン前のテキストを抽出します。停止トークンがない場合に関数が欠損値を返さないように、各シーケンスの終わりに停止トークンを追加します。

strGenerated = extractBefore(strGenerated+stopToken,stopToken);

パディング トークンを削除します。

strGenerated = erase(strGenerated,paddingToken);

生成過程は各予測の間に空白文字を追加するため、一部の句読点文字が前後に不要な空白を伴って出現することになります。該当する句読点文字の前後の空白を削除し、生成されたテキストを再構成します。

指定された句読点文字の前に出現する空白を削除します。

punctuationCharacters = ["." "," "’" ")" ":" ";" "?" "!"];
strGenerated = replace(strGenerated," " + punctuationCharacters,punctuationCharacters);

指定された句読文字の後に出現する空白を削除します。

punctuationCharacters = ["(" "‘"];
strGenerated = replace(strGenerated,punctuationCharacters + " ",punctuationCharacters);

関数 strip を使用して先頭と末尾の空白を削除し、生成されたテキストを表示します。

strGenerated = strip(strGenerated)
strGenerated = 3×1 string
    "me whose fool black grounded less waning travels less pine pine sing cool thrive kindness this"
    "perjur'd outward a looks black, here might."
    "birds him antique side his hours age,"

符号化器モデル関数

関数 modelEncoder は、モデル パラメーター、単語インデックスのシーケンス、シーケンス長を入力として受け取り、対応する潜在特徴ベクトルを返します。

入力データには長さの異なるパディング済みシーケンスが含まれるため、パディングによって損失計算に悪影響の及ぶ可能性があります。LSTM 演算について、シーケンスの最後のタイム ステップの出力 (多数のパディング値を処理した後の LSTM 状態に相当する可能性が高い) を返す代わりに、sequenceLengths 入力によって与えられた実際の最後のタイム ステップを決定します。

function Z = modelEncoder(parameters,X,sequenceLengths)

% Embedding.
weights = parameters.emb.Weights;
Z = embed(X,weights);

% LSTM.
inputWeights = parameters.lstmEncoder.InputWeights;
recurrentWeights = parameters.lstmEncoder.RecurrentWeights;
bias = parameters.lstmEncoder.Bias;

numHiddenUnits = size(recurrentWeights,2);
hiddenState = zeros(numHiddenUnits,1,"like",X);
cellState = zeros(numHiddenUnits,1,"like",X);

Z1 = lstm(Z,hiddenState,cellState,inputWeights,recurrentWeights,bias);

% Output mode 'last' with masking.
miniBatchSize = size(Z1,2);
Z = zeros(numHiddenUnits,miniBatchSize,"like",Z1);

for n = 1:miniBatchSize
    t = sequenceLengths(n);
    Z(:,n) = Z1(:,n,t);
end

% Fully connect.
weights = parameters.fcEncoder.Weights;
bias = parameters.fcEncoder.Bias;
Z = fullyconnect(Z,weights,bias,DataFormat="CB");

end

復号化器モデル関数

関数 modelDecoder は、モデル パラメーター、単語インデックスのシーケンス、ネットワークの状態を入力として受け取り、復号化されたシーケンスを返します。

関数 lstm"ステートフル" (時系列が入力として与えられたとき、関数が各タイム ステップ間の状態を伝播および更新する) であり、関数 embed および fullyconnect は既定で時間分散型である (時系列が入力として与えられたとき、関数が各タイム ステップに対して個別に演算を行う) ため、関数 modelDecoder はシーケンスと単一のタイム ステップ入力の両方をサポートしています。

function [Y,state] = modelDecoder(parameters,X,state)

% Embedding.
weights = parameters.emb.Weights;
X = embed(X,weights);

% LSTM.
inputWeights = parameters.lstmDecoder.InputWeights;
recurrentWeights = parameters.lstmDecoder.RecurrentWeights;
bias = parameters.lstmDecoder.Bias;

hiddenState = state.HiddenState;
cellState = state.CellState;

[Y,hiddenState,cellState] = lstm(X,hiddenState,cellState, ...
    inputWeights,recurrentWeights,bias);

state.HiddenState = hiddenState;
state.CellState = cellState;

% Fully connect.
weights = parameters.fcDecoder.Weights;
bias = parameters.fcDecoder.Bias;
Y = fullyconnect(Y,weights,bias);

% Softmax.
Y = softmax(Y);

end

モデル損失関数

関数 modelLoss は、モデルの学習可能なパラメーター、入力データ X、マスク用のシーケンス長のベクトルを入力として受け取り、学習可能なパラメーターについての損失および損失の勾配を返します。

マスクされた損失を計算するために、モデル損失関数は、例の最後にリストされている関数 maskedCrossEntropy を使用します。次のタイム ステップを予測する復号化器に学習させるには、1 タイム ステップ分シフトした入力シーケンスになるようにターゲットを指定します。

モデル損失関数の定義の詳細については、カスタム学習ループのモデル損失関数の定義を参照してください。

function [loss,gradients] = modelLoss(parameters,X,sequenceLengths)

% Model encoder.
Z = modelEncoder(parameters,X,sequenceLengths);

% Initialize LSTM state.
state = struct;
state.HiddenState = Z;
state.CellState = zeros(size(Z),"like",Z);

% Teacher forcing.
Y = modelDecoder(parameters,X,state);

% Loss.
Y = Y(:,:,1:end-1);
T = X(:,:,2:end);
loss = mean(maskedCrossEntropy(Y,T,sequenceLengths));

% Gradients.
gradients = dlgradient(loss,parameters);

% Normalize loss for plotting.
sequenceLength = size(X,3);
loss = loss / sequenceLength;

end

モデル予測関数

関数 modelPredictions は、モデル パラメーター、復号化器の初期状態、最大シーケンス長、単語符号化、開始トークン、ミニバッチ サイズを所与として、復号化器の出力スコアを返します。

function Y = modelDecoderPredictions(parameters,Z,maxLength,enc,startToken,miniBatchSize)

numObservations = size(Z,2);
numIterations = ceil(numObservations / miniBatchSize);

startTokenIdx = word2ind(enc,startToken);
vocabularySize = enc.NumWords;

Y = zeros(vocabularySize,numObservations,maxLength,"like",Z);

% Loop over mini-batches.
for i = 1:numIterations
    idxMiniBatch = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    miniBatchSize = numel(idxMiniBatch);

    % Initialize state.
    state = struct;
    state.HiddenState = Z(:,idxMiniBatch);
    state.CellState = zeros(size(Z(:,idxMiniBatch)),"like",Z);

    % Initialize decoder input.
    decoderInput = dlarray(repmat(startTokenIdx,[1 miniBatchSize]),"CBT");

    % Loop over time steps.
    for t = 1:maxLength
        % Predict next time step.
        [Y(:,idxMiniBatch,t), state] = modelDecoder(parameters,decoderInput,state);

        % Closed loop generation.
        [~,idx] = max(Y(:,idxMiniBatch,t));
        decoderInput = dlarray(idx,"CB");
    end
end

end

マスクされた交差エントロピー損失関数

関数 maskedCrossEntropy はシーケンス長の指定されたベクトルを使用して、指定された入力シーケンスとターゲット シーケンスの間の損失を、パディングを含むタイム ステップを無視して計算します。

function maskedLoss = maskedCrossEntropy(Y,T,sequenceLengths)

numClasses = size(Y,1);
miniBatchSize = size(Y,2);
sequenceLength = size(Y,3);

maskedLoss = zeros(sequenceLength,miniBatchSize,"like",Y);

for t = 1:sequenceLength
    T1 = single(oneHot(T(:,:,t),numClasses));

    mask = (t<=sequenceLengths)';

    maskedLoss(t,:) = mask .* crossentropy(Y(:,:,t),T1);
end

maskedLoss = sum(maskedLoss,1);

end

テキスト前処理関数

関数 preprocessText は以下のステップを実行します。

  1. 各入力ストリングの前後に、指定された開始トークンと停止トークンをそれぞれ付加する。

  2. tokenizedDocument を使用してテキストをトークン化する。

function documents = preprocessText(textData,startToken,stopToken)

% Add start and stop tokens.
textData = startToken + textData + stopToken;

% Tokenize the text.
documents = tokenizedDocument(textData,'CustomTokens',[startToken stopToken]);

end

one-hot 符号化関数

関数 oneHot は、数値インデックスの配列を one-hot 符号化されたベクトルに変換します。

function oh = oneHot(idx, outputSize)

miniBatchSize = numel(idx);
oh = zeros(outputSize,miniBatchSize);

for n = 1:miniBatchSize
    c = idx(n);
    oh(c,n) = 1;
end

end

参考

| |

関連するトピック