Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用したテキスト データの分類

この例では、深層学習の長短期記憶 (LSTM) ネットワークを使用してテキスト データを分類する方法を示します。

テキスト データでは必然的にデータが順に並んでいます。一部のテキストは単語のシーケンスであり、それらの単語間には依存関係がある可能性があります。長期的な依存関係を学習してシーケンス データの分類に使用するために、LSTM ニューラル ネットワークを使用します。LSTM ネットワークは、再帰型ニューラル ネットワーク (RNN) の一種で、シーケンス データのタイム ステップ間の長期的な依存関係を学習できます。

テキストを LSTM ネットワークに入力するには、まず、テキスト データを数値シーケンスに変換します。文書を数値インデックスのシーケンスにマッピングする単語符号化を使用して、これを実現できます。また、より正確な結果を得るため、単語埋め込み層をネットワークに含めます。単語埋め込みは、ボキャブラリ内の単語をスカラー インデックスではなく数値ベクトルにマッピングします。これらの埋め込みでは、意味の似ている単語が類似のベクトルをもつように、単語のセマンティックな詳細を取得します。また、ベクトル演算を使用して単語間の関係をモデル化します。たとえば、"Rome is to Italy as Paris is to France" (イタリアに対してのローマは、フランスに対してのパリに同じ) という関係は、方程式 Italy Rome + Paris = France で記述されます。

この例では、4 つの手順で LSTM ネットワークに学習させてそれを使用します。

  • データをインポートして前処理します。

  • 単語符号化を使用して単語を数値シーケンスに変換します。

  • 単語埋め込み層のある LSTM ネットワークを作成し、このネットワークに学習させます。

  • 学習済み LSTM ネットワークを使用して新しいテキスト データを分類します。

データのインポート

工場レポートのデータをインポートします。このデータには、出荷時のイベントを説明するラベル付きテキストが含まれています。テキスト データを string としてインポートするために、テキスト タイプを 'string' に指定します。

filename = "factoryReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

この例の目的は、Category 列のラベルによってイベントを分類することです。データをクラスに分割するために、これらのラベルを categorical に変換します。

data.Category = categorical(data.Category);

ヒストグラムを使用してデータ内のクラスの分布を表示します。

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

次の手順は、これを学習セットと検証セットに分割することです。データを学習区画と、検証およびテスト用のホールドアウト区画に分割します。ホールドアウトの割合を 20% に指定します。

cvp = cvpartition(data.Category,'Holdout',0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

分割した table からテキスト データとラベルを抽出します。

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
YTrain = dataTrain.Category;
YValidation = dataValidation.Category;

データが正しくインポートされたことを確認するために、ワード クラウドを使用して学習テキスト データを可視化します。

figure
wordcloud(textDataTrain);
title("Training Data")

テキスト データの前処理

テキスト データをトークン化および前処理する関数を作成します。例の最後にリストされている関数 preprocessText は、以下の手順を実行します。

  1. tokenizedDocument を使用してテキストをトークン化します。

  2. lower を使用してテキストを小文字に変換します。

  3. erasePunctuation を使用して句読点を消去します。

関数 preprocessText を使用して学習データと検証データを前処理します。

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

前処理した学習文書を最初の数個表示します。

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses

シーケンスへの文書の変換

文書を LSTM ネットワークに入力するために、単語符号化を使用して文書を数値インデックスのシーケンスに変換します。

単語符号化を作成するには、関数 wordEncoding を使用します。

enc = wordEncoding(documentsTrain);

次の変換ステップは、すべての文書が同じ長さになるようにパディングと切り捨てを行うことです。関数 trainingOptions には、入力シーケンスのパディングと切り捨てを自動的に行うオプションが用意されています。ただし、これらのオプションは、単語ベクトルのシーケンスにはあまり適していません。代わりに、シーケンスのパディングと切り捨てを手動で行います。単語ベクトルのシーケンスを "左パディング" し、切り捨てを行うことで、学習が改善される可能性があります。

文書のパディングと切り捨てを行うには、まず、ターゲットの長さを選択し、それより長い文書を切り捨て、それより短い文書を左パディングします。最良の結果を得るには、大量のデータを破棄することなくターゲットの長さを短くする必要があります。適切なターゲットの長さを求めるために、学習文書の長さのヒストグラムを表示します。

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

学習文書のほとんどは 10 トークン未満です。これを切り捨てとパディングのターゲットの長さとして使用します。

doc2sequence を使用して、文書を数値インデックスのシーケンスに変換します。シーケンスの長さが 10 になるように切り捨てと左パディングを行うために、'Length' オプションを 10 に設定します。

sequenceLength = 10;
XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength);
XTrain(1:5)
ans=5×1 cell array
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}
    {1×10 double}

同じオプションを使用して検証文書をシーケンスに変換します。

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);

LSTM ネットワークの作成と学習

LSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを 1 に設定します。次に、次元が 50 の単語埋め込み層と、単語符号化と同じ数の単語を含めます。次に、LSTM 層を含め、隠れユニット数を 80 に設定します。sequence-to-label 分類問題に LSTM 層を使用するには、出力モードを 'last' に設定します。最後に、クラスの数と同じサイズの全結合層、ソフトマックス層、および分類層を追加します。

inputSize = 1;
embeddingDimension = 50;
numHiddenUnits = 80;

numWords = enc.NumWords;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 50 dimensions and 423 unique words
     3   ''   LSTM                    LSTM with 80 hidden units
     4   ''   Fully Connected         4 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

学習オプションの指定

学習オプションを指定します。

  • Adam ソルバーを使用して学習させます。

  • ミニバッチ サイズを 16 に指定します。

  • すべてのエポックでデータをシャッフルします。

  • 'Plots' オプションを 'training-progress' に設定して、学習の進行状況を監視します。

  • 'ValidationData' オプションを使用して、検証データを指定します。

  • 'Verbose' オプションを false に設定して、詳細出力を非表示にします。

既定では、使用可能な GPU がある場合、trainNetwork は GPU を使用します。そうでない場合は CPU が使用されます。実行環境を手動で指定するには、trainingOptions の名前と値のペアの引数 'ExecutionEnvironment' を使用します。CPU での学習は、GPU での学習よりも大幅に時間がかかる場合があります。GPU を使用した学習には、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

options = trainingOptions('adam', ...
    'MiniBatchSize',16, ...
    'GradientThreshold',2, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

関数 trainNetwork を使用して LSTM ネットワークに学習させます。

net = trainNetwork(XTrain,YTrain,layers,options);

新しいデータを使用した予測

3 つの新しいレポートのイベント タイプを分類します。新しいレポートを格納する string 配列を作成します。

reportsNew = [ ...
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

学習文書と同じ前処理手順を使用してテキスト データを前処理します。

documentsNew = preprocessText(reportsNew);

学習シーケンスの作成時と同じオプションで doc2sequence を使用して、テキスト データをシーケンスに変換します。

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

学習済みの LSTM ネットワークを使用して新しいシーケンスを分類します。

labelsNew = classify(net,XNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

前処理関数

関数 preprocessText は、以下の手順を実行します。

  1. tokenizedDocument を使用してテキストをトークン化します。

  2. lower を使用してテキストを小文字に変換します。

  3. erasePunctuation を使用して句読点を消去します。

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

参考

| | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | | (Deep Learning Toolbox) |

関連するトピック