Main Content

『Pride and Prejudice』と MATLAB

この例では、深層学習 LSTM ネットワークに学習させ、文字の埋め込みを使用してテキストを生成する方法を説明します。

テキスト生成用の深層学習ネットワークに学習させるには、文字のシーケンスにおける次の文字を予測する sequence-to-sequence LSTM ネットワークに学習させます。次の文字を予測するネットワークに学習させるには、1 タイム ステップ分シフトした入力シーケンスになるように応答を指定します。

文字の埋め込みを使用するために、各学習観測値を整数のシーケンスに変換します。それらの整数は文字のボキャブラリに対してインデックス付けされています。文字の埋め込みを学習して整数をベクトルにマッピングする単語埋め込み層をネットワークに含めます。

学習データの読み込み

Project Gutenberg で公開されている Jane Austen の『Pride and Prejudice』の電子書籍から HTML コードを読み取り、webreadhtmlTree を使用して解析します。

url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);

p 要素を見つけて、段落を抽出します。CSS セレクター ':not(.toc)' を使用してクラス "toc" を含む段落要素を無視するよう指定します。

paragraphs = findElement(tree,'p:not(.toc)');

extractHTMLText を使用して段落からテキスト データを抽出し、空の string を削除します。

textData = extractHTMLText(paragraphs);
textData(textData == "") = [];

20 文字未満の string を削除します。

idx = strlength(textData) < 20;
textData(idx) = [];

テキスト データをワード クラウドで可視化します。

figure
wordcloud(textData);
title("Pride and Prejudice")

テキスト データのシーケンスへの変換

テキスト データを予測子の文字インデックスのシーケンスと応答のカテゴリカル シーケンスに変換します。

関数 categorical は、改行と空白のエントリを未定義として扱います。これらの文字の categorical 要素を作成するには、それぞれ特殊文字 "" (段落記号、"\x00B6") および "·" (中黒、"\x00B7") に置き換えます。あいまいさを避けるために、テキストに現れない特殊文字を選択しなければなりません。これらの文字は学習データに現れないため、この目的に使用できます。

newlineCharacter = compose("\x00B6");
whitespaceCharacter = compose("\x00B7");
textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);

テキスト データについてループし、各観測値の文字を表す文字インデックスのシーケンスと応答用の文字のカテゴリカル シーケンスを作成します。各観測値の終わりを表すために、特殊文字 "␃" (テキスト終結、"\x2403") を含めます。

endOfTextCharacter = compose("\x2403");
numDocuments = numel(textData);
for i = 1:numDocuments
    characters = textData{i};
    X = double(characters);
    
    % Create vector of categorical responses with end of text character.
    charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter];
    Y = categorical(charactersShifted);
    
    XTrain{i} = X;
    YTrain{i} = Y;
end

既定では、学習中に、学習データはミニバッチに分割され、パディングによってシーケンスの長さが揃えられます。過度なパディングは、ネットワーク性能に悪影響を与える可能性があります。

学習プロセスでの過度のパディングを防ぐため、シーケンス長で学習データを並べ替えて、ミニバッチ内のシーケンスが似たような長さになるようにミニバッチのサイズを選択できます。

各観測値のシーケンス長を取得します。

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end

シーケンス長でデータを並べ替えます。

[~,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);

LSTM ネットワークの作成と学習

LSTM アーキテクチャを定義します。隠れユニットが 400 個の sequence-to-sequence LSTM 分類ネットワークを指定します。入力サイズを学習データの特徴次元になるように設定します。文字インデックスのシーケンスでは、特徴次元は 1 です。次元が 200 の単語埋め込み層を指定し、単語 (文字に対応) の数を入力データの最大文字値になるように指定します。全結合層の出力サイズが応答のカテゴリ数になるように設定します。過適合を防止するために、LSTM 層の後にドロップアウト層を含めます。

単語埋め込み層は、文字の埋め込みを学習して各文字を 200 次元のベクトルにマッピングします。

inputSize = size(XTrain{1},1);
numClasses = numel(categories([YTrain{:}]));
numCharacters = max([textData{:}]);

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(200,numCharacters)
    lstmLayer(400,'OutputMode','sequence')
    dropoutLayer(0.2);
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

学習オプションを指定します。ミニバッチ サイズ 32 と初期学習率 0.01 で学習するように指定します。勾配の発散を防ぐために、勾配しきい値を 1 に設定します。データを並べ替えられた状態に保つには、'Shuffle''never' に設定します。学習の進行状況を監視するには、'Plots' オプションを 'training-progress' に設定します。詳細出力を表示しないようにするには、'Verbose'false に設定します。

options = trainingOptions('adam', ...
    'MiniBatchSize',32,...
    'InitialLearnRate',0.01, ...
    'GradientThreshold',1, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false);

ネットワークに学習をさせます。

net = trainNetwork(XTrain,YTrain,layers,options);

新しいテキストの生成

学習データに含まれるテキストの最初の文字が従う確率分布から文字をサンプリングしてテキストの最初の文字を生成します。学習済み LSTM ネットワークを使用して、生成されたテキストの現在のシーケンスに基づいて次のシーケンスを予測し、残りの文字を生成します。ネットワークが "テキスト終結" 文字を予測するまで、文字を 1 つずつ生成し続けます。

学習データに含まれる最初の文字の分布に従って、最初の文字をサンプリングします。

initialCharacters = extractBefore(textData,2);
firstCharacter = datasample(initialCharacters,1);
generatedText = firstCharacter;

最初の文字を数値インデックスに変換します。

X = double(char(firstCharacter));

残りの予測については、ネットワークの予測スコアに従って、次の文字をサンプリングします。予測スコアは次の文字の確率分布を表します。ネットワークの出力層のクラス名で与えられた文字のボキャブラリから文字をサンプリングします。ネットワークの分類層からボキャブラリを取得します。

vocabulary = string(net.Layers(end).ClassNames);

predictAndUpdateState を使用して 1 文字ずつ予測します。予測ごとに、前の文字のインデックスを入力します。ネットワークがテキスト終結文字を予測するか、生成されたテキストの長さが 500 文字になったら予測を停止します。データの大規模なコレクション、長いシーケンス、または大規模ネットワークの場合は、通常、GPU での予測の方が CPU での予測より計算時間が短縮されます。そうでない場合、通常、CPU での予測の計算の方が高速です。1 タイム ステップの予測には、CPU を使用します。予測に CPU を使用するには、predictAndUpdateState'ExecutionEnvironment' オプションを 'cpu' に設定します。

maxLength = 500;
while strlength(generatedText) < maxLength
    % Predict the next character scores.
    [net,characterScores] = predictAndUpdateState(net,X,'ExecutionEnvironment','cpu');
    
    % Sample the next character.
    newCharacter = datasample(vocabulary,1,'Weights',characterScores);
    
    % Stop predicting at the end of text.
    if newCharacter == endOfTextCharacter
        break
    end
    
    % Add the character to the generated text.
    generatedText = generatedText + newCharacter;
    
    % Get the numeric index of the character.
    X = double(char(newCharacter));
end

特殊文字を対応する空白文字や改行文字に置き換えて、生成されたテキストを再構成します。

generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])
generatedText = 
"“I wish Mr. Darcy, upon latter of my sort sincerely fixed in the regard to relanth. We were to join on the Lucases. They are married with him way Sir Wickham, for the possibility which this two od since to know him one to do now thing, and the opportunity terms as they, and when I read; nor Lizzy, who thoughts of the scent; for a look for times, I never went to the advantage of the case; had forcibling himself. They pility and lively believe she was to treat off in situation because, I am exceal"

複数のテキストを生成するには、resetState を使用して生成間のネットワークの状態をリセットします。

net = resetState(net);

参考

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

関連するトピック