このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
深層学習を使用したテキスト データの分類
この例では、深層学習の長短期記憶 (LSTM) ネットワークを使用してテキスト データを分類する方法を示します。
テキスト データでは必然的にデータが順に並んでいます。一部のテキストは単語のシーケンスであり、それらの単語間には依存関係がある可能性があります。長期的な依存関係を学習してシーケンス データの分類に使用するために、LSTM ニューラル ネットワークを使用します。LSTM ネットワークは、再帰型ニューラル ネットワーク (RNN) の一種で、シーケンス データのタイム ステップ間の長期的な依存関係を学習できます。
テキストを LSTM ネットワークに入力するには、まず、テキスト データを数値シーケンスに変換します。文書を数値インデックスのシーケンスにマッピングする単語符号化を使用して、これを実現できます。また、より正確な結果を得るため、単語埋め込み層をネットワークに含めます。単語埋め込みは、ボキャブラリ内の単語をスカラー インデックスではなく数値ベクトルにマッピングします。これらの埋め込みでは、意味の似ている単語が類似のベクトルをもつように、単語のセマンティックな詳細を取得します。また、ベクトル演算を使用して単語間の関係をモデル化します。たとえば、"Rome is to Italy as Paris is to France" (イタリアに対してのローマは、フランスに対してのパリに同じ) という関係は、方程式 Italy – Rome + Paris = France で記述されます。
この例では、4 つの手順で LSTM ネットワークに学習させてそれを使用します。
データをインポートして前処理します。
単語符号化を使用して単語を数値シーケンスに変換します。
単語埋め込み層のある LSTM ネットワークを作成し、このネットワークに学習させます。
学習済み LSTM ネットワークを使用して新しいテキスト データを分類します。
データのインポート
工場レポートのデータをインポートします。このデータには、出荷時のイベントを説明するラベル付きテキストが含まれています。テキスト データを string としてインポートするために、テキスト タイプを 'string'
に指定します。
filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data)
ans=8×5 table
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
この例の目的は、Category
列のラベルによってイベントを分類することです。データをクラスに分割するために、これらのラベルを categorical に変換します。
data.Category = categorical(data.Category);
ヒストグラムを使用してデータ内のクラスの分布を表示します。
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")
次の手順は、これを学習セットと検証セットに分割することです。データを学習区画と、検証およびテスト用のホールドアウト区画に分割します。ホールドアウトの割合を 20% に指定します。
cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
分割した table からテキスト データとラベルを抽出します。
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;
データが正しくインポートされたことを確認するために、ワード クラウドを使用して学習テキスト データを可視化します。
figure
wordcloud(textDataTrain);
title("Training Data")
テキスト データの前処理
テキスト データをトークン化および前処理する関数を作成します。例の最後にリストされている関数 preprocessText
は、以下の手順を実行します。
tokenizedDocument
を使用してテキストをトークン化します。lower
を使用してテキストを小文字に変換します。erasePunctuation
を使用して句読点を消去します。
関数 preprocessText
を使用して学習データと検証データを前処理します。
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
前処理した学習文書を最初の数個表示します。
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 10 tokens: there are cuts to the power when starting the plant 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses
シーケンスへの文書の変換
文書を LSTM ネットワークに入力するために、単語符号化を使用して文書を数値インデックスのシーケンスに変換します。
単語符号化を作成するには、関数 wordEncoding
を使用します。
enc = wordEncoding(documentsTrain);
次の変換ステップは、すべての文書が同じ長さになるようにパディングと切り捨てを行うことです。関数 trainingOptions
には、入力シーケンスのパディングと切り捨てを自動的に行うオプションが用意されています。ただし、これらのオプションは、単語ベクトルのシーケンスにはあまり適していません。代わりに、シーケンスのパディングと切り捨てを手動で行います。単語ベクトルのシーケンスを "左パディング" し、切り捨てを行うことで、学習が改善される可能性があります。
文書のパディングと切り捨てを行うには、まず、ターゲットの長さを選択し、それより長い文書を切り捨て、それより短い文書を左パディングします。最良の結果を得るには、大量のデータを破棄することなくターゲットの長さを短くする必要があります。適切なターゲットの長さを求めるために、学習文書の長さのヒストグラムを表示します。
documentLengths = doclength(documentsTrain); figure histogram(documentLengths) title("Document Lengths") xlabel("Length") ylabel("Number of Documents")
学習文書のほとんどは 10 トークン未満です。これを切り捨てとパディングのターゲットの長さとして使用します。
doc2sequence
を使用して、文書を数値インデックスのシーケンスに変換します。シーケンスの長さが 10 になるように切り捨てと左パディングを行うために、'Length'
オプションを 10 に設定します。
sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
{1×10 double}
{1×10 double}
{1×10 double}
{1×10 double}
{1×10 double}
同じオプションを使用して検証文書をシーケンスに変換します。
XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);
LSTM ネットワークの作成と学習
LSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを 1 に設定します。次に、次元が 50 の単語埋め込み層と、単語符号化と同じ数の単語を含めます。次に、LSTM 層を含め、隠れユニット数を 80 に設定します。sequence-to-label 分類問題に LSTM 層を使用するには、出力モードを 'last'
に設定します。最後に、クラスの数と同じサイズの全結合層、ソフトマックス層、および分類層を追加します。
inputSize = 1; embeddingDimension = 50; numHiddenUnits = 80; numWords = enc.NumWords; numClasses = numel(categories(YTrain)); layers = [ ... sequenceInputLayer(inputSize) wordEmbeddingLayer(embeddingDimension,numWords) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer]
layers = 6x1 Layer array with layers: 1 '' Sequence Input Sequence input with 1 dimensions 2 '' Word Embedding Layer Word embedding layer with 50 dimensions and 423 unique words 3 '' LSTM LSTM with 80 hidden units 4 '' Fully Connected 4 fully connected layer 5 '' Softmax softmax 6 '' Classification Output crossentropyex
学習オプションの指定
学習オプションを指定します。
Adam ソルバーを使用して学習させます。
ミニバッチ サイズを 16 に指定します。
すべてのエポックでデータをシャッフルします。
'Plots'
オプションを'training-progress'
に設定して、学習の進行状況を監視します。'ValidationData'
オプションを使用して、検証データを指定します。'Verbose'
オプションをfalse
に設定して、詳細出力を非表示にします。
既定では、使用可能な GPU がある場合、trainNetwork
は GPU を使用します。そうでない場合は CPU が使用されます。実行環境を手動で指定するには、trainingOptions
の名前と値のペアの引数 'ExecutionEnvironment'
を使用します。CPU での学習は、GPU での学習よりも大幅に時間がかかる場合があります。GPU を使用した学習には、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
options = trainingOptions('adam', ... 'MiniBatchSize',16, ... 'GradientThreshold',2, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'Plots','training-progress', ... 'Verbose',false);
関数 trainNetwork
を使用して LSTM ネットワークに学習させます。
net = trainNetwork(XTrain,YTrain,layers,options);
新しいデータを使用した予測
3 つの新しいレポートのイベント タイプを分類します。新しいレポートを格納する string 配列を作成します。
reportsNew = [ ... "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
学習文書と同じ前処理手順を使用してテキスト データを前処理します。
documentsNew = preprocessText(reportsNew);
学習シーケンスの作成時と同じオプションで doc2sequence
を使用して、テキスト データをシーケンスに変換します。
XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);
学習済みの LSTM ネットワークを使用して新しいシーケンスを分類します。
labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
前処理関数
関数 preprocessText
は、以下の手順を実行します。
tokenizedDocument
を使用してテキストをトークン化します。lower
を使用してテキストを小文字に変換します。erasePunctuation
を使用して句読点を消去します。
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
参考
fastTextWordEmbedding
| wordEmbeddingLayer
| tokenizedDocument
| lstmLayer
(Deep Learning Toolbox) | trainNetwork
(Deep Learning Toolbox) | trainingOptions
(Deep Learning Toolbox) | doc2sequence
| sequenceInputLayer
(Deep Learning Toolbox) | wordcloud
関連するトピック
- Classify Text Data Using Convolutional Neural Network
- Classify Out-of-Memory Text Data Using Deep Learning
- 深層学習を使用したテキストの生成 (Deep Learning Toolbox)
- 深層学習を使用した単語単位のテキスト生成
- 分類用の単純なテキスト モデルの作成
- トピック モデルを使用したテキスト データの解析
- マルチワード フレーズを使用したテキスト データの解析
- センチメント分類器の学習
- 深層学習を使用したシーケンスの分類 (Deep Learning Toolbox)
- MATLAB による深層学習 (Deep Learning Toolbox)