『Pride and Prejudice』と MATLAB
この例では、深層学習 LSTM ネットワークに学習させ、文字の埋め込みを使用してテキストを生成する方法を説明します。
テキスト生成用の深層学習ネットワークに学習させるには、文字のシーケンスにおける次の文字を予測する sequence-to-sequence LSTM ネットワークに学習させます。次の文字を予測するネットワークに学習させるには、1 タイム ステップ分シフトした入力シーケンスになるように応答を指定します。
文字の埋め込みを使用するために、各学習観測値を整数のシーケンスに変換します。それらの整数は文字のボキャブラリに対してインデックス付けされています。文字の埋め込みを学習して整数をベクトルにマッピングする単語埋め込み層をネットワークに含めます。
学習データの読み込み
Project Gutenberg で公開されている Jane Austen の『Pride and Prejudice』の電子書籍から HTML コードを読み取り、webreadとhtmlTree (Text Analytics Toolbox)を使用して解析します。
url = "https://www.gutenberg.org/files/1342/1342-h/1342-h.htm";
code = webread(url);
tree = htmlTree(code);p 要素を見つけて、段落を抽出します。目次に関連する情報を削除するには、CSS セレクター ':not(.toc)' を使用して、クラス "toc" を含む段落要素を無視します。
paragraphs = findElement(tree,'p:not(.toc)');extractHTMLText (Text Analytics Toolbox)関数を使用して、段落からテキスト データを抽出します。空の string を削除します。
textData = extractHTMLText(paragraphs);
textData(textData == "") = [];20 文字未満の string を削除します。
idx = strlength(textData) < 20; textData(idx) = [];
テキスト データをワード クラウドで可視化します。
figure
wordcloud(textData);
title("Pride and Prejudice")
テキスト データのシーケンスへの変換
テキスト データを予測子の文字インデックスのシーケンスと応答のカテゴリカル シーケンスに変換します。
categorical関数は、改行と空白のエントリを未定義として扱います。これらの文字の categorical 要素を作成するには、それぞれ特殊文字 "¶" (段落記号、"\x00B6") および "·" (中黒、"\x00B7") に置き換えます。あいまいさを避けるために、テキストに現れない特殊文字を選択します。
newlineCharacter = compose("\x00B6"); whitespaceCharacter = compose("\x00B7"); textData = replace(textData,[newline " "],[newlineCharacter whitespaceCharacter]);
テキスト データについてループし、各観測値の文字を表す文字インデックスのシーケンスと応答用の文字のカテゴリカル シーケンスを作成します。各観測値の終わりを表すために、特殊文字 "␃" (テキスト終結、"\x2403") を含めます。categorical 配列のカテゴリを、テキスト データに現れるすべての文字となるように指定します。
endOfTextCharacter = compose("\x2403"); numDocuments = numel(textData); uniqueCharacters = unique([textData{:}]); for i = 1:numDocuments characters = textData{i}; X = double(characters); % Create vector of categorical responses with end of text character. charactersShifted = [cellstr(characters(2:end)')' endOfTextCharacter]; Y = categorical(charactersShifted, [string(uniqueCharacters'); endOfTextCharacter]); XTrain{i} = X; YTrain{i} = Y; end
既定では、学習中に、学習データはミニバッチに分割され、パディングによってシーケンスの長さが揃えられます。過度なパディングは、ネットワーク性能に悪影響を与える可能性があります。
学習プロセスでの過度のパディングを防ぐため、学習データをシーケンス長に並べ替えることができます。そうすることで、どのミニバッチ内のシーケンスも同じような長さになります。
numObservations = numel(XTrain); for i=1:numObservations sequence = XTrain{i}; sequenceLengths(i) = size(sequence,2); end [~,idx] = sort(sequenceLengths); XTrain = XTrain(idx); YTrain = YTrain(idx);
LSTM ネットワークの作成と学習
LSTM アーキテクチャを定義します。隠れユニットが 400 個の sequence-to-sequence LSTM 分類ネットワークを指定します。入力サイズを学習データの特徴次元になるように設定します。文字インデックスのシーケンスでは、特徴次元は 1 です。次元が 200 の単語埋め込み層を指定し、単語 (文字に対応) の数を入力データの最大文字値になるように指定します。全結合層の出力サイズが応答のカテゴリ数になるように設定します。過適合を防止するために、LSTM 層の後にドロップアウト層を含めます。
単語埋め込み層は、文字の埋め込みを学習して各文字を 200 次元のベクトルにマッピングします。
inputSize = size(XTrain{1},1);
classes = categories([YTrain{:}]);
numClasses = numel(classes);
numCharacters = max([textData{:}]);
layers = [
sequenceInputLayer(inputSize)
wordEmbeddingLayer(200,numCharacters)
lstmLayer(400,OutputMode="sequence")
dropoutLayer(0.2);
fullyConnectedLayer(numClasses)
softmaxLayer];学習オプションを指定します。ミニバッチ サイズ 32 と初期学習率 0.01 で学習するように指定します。勾配の発散を防ぐために、勾配しきい値を 1 に設定します。データを並べ替えられた状態に保つには、Shuffle を "never" に設定します。学習の進行状況を監視するには、Plots オプションを "training-progress" に設定します。詳細出力を表示しないようにするには、Verbose を false に設定します。学習データには、チャネルとタイム ステップにそれぞれ対応する行と列を含むシーケンスがあるため、入力とターゲットのデータ形式 "CTB" (チャネル、時間、バッチ) を指定します。
options = trainingOptions("adam", ... InputDataFormats = "CTB", ... TargetDataFormats = "CTB", ... Metrics = "accuracy", ... MiniBatchSize = 32,... InitialLearnRate = 0.01, ... GradientThreshold = 0.1, ... Shuffle = "never", ... Plots = "training-progress", ... Verbose = false, ... ExecutionEnvironment = "auto");
関数trainnetを使用してネットワークに学習させます。
net = trainnet(XTrain,YTrain,layers,"crossentropy",options);
新しいテキストの生成
学習データに含まれるテキストの最初の文字が従う確率分布から文字をサンプリングしてテキストの最初の文字を生成します。学習済み LSTM ネットワークを使用して、生成されたテキストの現在のシーケンスに基づいて次のシーケンスを予測し、残りの文字を生成します。ネットワークが "テキスト終結" 文字を予測するまで、文字を 1 つずつ生成し続けます。
学習データに含まれる最初の文字の分布に従って、最初の文字をサンプリングします。
initialCharacters = extractBefore(textData,2); firstCharacter = datasample(initialCharacters,1); generatedText = firstCharacter;
最初の文字を数値インデックスに変換します。
X = double(char(firstCharacter));
残りの予測については、ネットワークの予測スコアに従って、次の文字をサンプリングします。予測スコアは次の文字の確率分布を表します。ネットワークの出力層のクラス名で与えられた文字のボキャブラリから文字をサンプリングします。学習データからボキャブラリを取得します。
vocabulary = string(classes);
predictを使用して 1 文字ずつ予測します。予測ごとに、前の文字のインデックスを入力します。ネットワークがテキスト終結文字を予測するか、生成されたテキストの長さが 500 文字になったら予測を停止します。
maxLength = 500; while strlength(generatedText) < maxLength % Predict the next character scores and output the network state. [characterScores,state] = predict(net,X); % Update the state. net.State = state; % Sample the next character. newCharacter = datasample(vocabulary,1,Weights=gather(characterScores)); % Stop predicting at the end of text. if newCharacter == endOfTextCharacter break end % Add the character to the generated text. generatedText = generatedText + newCharacter; % Get the numeric index of the character. X = double(char(newCharacter)); end
特殊文字を対応する空白文字や改行文字に置き換えて、生成されたテキストを再構成します。
generatedText = replace(generatedText,[newlineCharacter whitespaceCharacter],[newline " "])generatedText = "“You dread, in must at Mr. Darcy’s more she had coldured with, to pubid since mistaken part of their return out general futual limentable or town Mact last ledge of the whire of these attended more longer impostible to looke able. They need marely bedre. Be all looking enough to this vortimily before, shook back beauty, she could inte; and moresning to it parton, that which she had out also on Jane, though to Loss arain Miss depenting, as mightmon that chuel to Mr. Darcy’s much man’s po rithment"
複数のテキストを生成するには、resetStateを使用して生成間のネットワークの状態をリセットします。
net = resetState(net);
参考
wordEmbeddingLayer (Text Analytics Toolbox) | doc2sequence (Text Analytics Toolbox) | tokenizedDocument (Text Analytics Toolbox) | lstmLayer | trainnet | trainingOptions | dlnetwork | sequenceInputLayer | wordcloud (Text Analytics Toolbox) | extractHTMLText (Text Analytics Toolbox) | findElement (Text Analytics Toolbox) | htmlTree (Text Analytics Toolbox)
トピック
- 深層学習を使用したテキストの生成
- 深層学習を使用した単語単位のテキスト生成 (Text Analytics Toolbox)
- 分類用の単純なテキスト モデルの作成 (Text Analytics Toolbox)
- トピック モデルを使用したテキスト データの解析 (Text Analytics Toolbox)
- マルチワード フレーズを使用したテキスト データの解析 (Text Analytics Toolbox)
- センチメント分類器の学習 (Text Analytics Toolbox)
- 深層学習を使用したシーケンスの分類
- MATLAB による深層学習