ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

深層学習を使用したテキスト データの分類

この例では、深層学習長短期記憶 (LSTM) ネットワークを使用して天気予報の説明文を分類する方法を説明します。

テキスト データでは必然的にデータが順に並んでいます。1 つのテキストは単語のシーケンスであり、それらの単語間には依存関係がある可能性があります。長期的な依存関係を学習して、シーケンス データの分類に使用するために、LSTM ニューラル ネットワークを使用します。LSTM ネットワークは、再帰型ニューラル ネットワーク (RNN) の一種で、シーケンス データのタイム ステップ間の長期的な依存関係を学習できます。

テキストを LSTM ネットワークに入力するには、まず、テキスト データを数値シーケンスに変換します。ドキュメントを数値インデックスのシーケンスにマッピングする単語符号化を使用して、これを実現できます。また、より正確な結果を得るため、単語埋め込み層をネットワークに含めます。単語埋め込みは、ボキャブラリ内の単語をスカラー インデックスではなく数値ベクトルにマッピングします。これらの埋め込みでは、意味の似ている単語が類似のベクトルを持つように、単語のセマンティックな詳細を取得します。また、ベクトル演算を介して単語間の関係をモデル化します。たとえば、"king is to queen as man is to woman" という関係は方程式 king man + woman = queen で記述されます。

この例では、4 つのステップで LSTM ネットワークに学習させてそれを使用します。

  • データをインポートして前処理します。

  • 単語符号化を使用して単語を数値シーケンスに変換します。

  • 単語埋め込み層のある LSTM ネットワークを作成し、このネットワークに学習させます。

  • 学習済み LSTM ネットワークを使用して新しいテキスト データを分類します。

データのインポート

天気予報データをインポートします。このデータには、天気事象を説明するラベル付きテキストが含まれています。テキスト データを string としてインポートするために、テキスト タイプを 'string' に指定します。

filename = "weatherReports.csv";
data = readtable(filename,'TextType','string');
head(data)
ans=8×16 table
            Time             event_id          state              event_type         damage_property    damage_crops    begin_lat    begin_lon    end_lat    end_lon                                                                                             event_narrative                                                                                             storm_duration    begin_day    end_day    year       end_timestamp    
    ____________________    __________    ________________    ___________________    _______________    ____________    _________    _________    _______    _______    _________________________________________________________________________________________________________________________________________________________________________________________________    ______________    _________    _______    ____    ____________________

    22-Jul-2016 16:10:00    6.4433e+05    "MISSISSIPPI"       "Thunderstorm Wind"       ""                "0.00K"         34.14        -88.63     34.122     -88.626    "Large tree down between Plantersville and Nettleton."                                                                                                                                                  00:05:00          22          22       2016    22-Jul-0016 16:15:00
    15-Jul-2016 17:15:00    6.5182e+05    "SOUTH CAROLINA"    "Heavy Rain"              "2.00K"           "0.00K"         34.94        -81.03      34.94      -81.03    "One to two feet of deep standing water developed on a street on the Winthrop University campus after more than an inch of rain fell in less than an hour. One vehicle was stalled in the water."       00:00:00          15          15       2016    15-Jul-0016 17:15:00
    15-Jul-2016 17:25:00    6.5183e+05    "SOUTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.01        -80.93      35.01      -80.93    "NWS Columbia relayed a report of trees blown down along Tom Hall St."                                                                                                                                  00:00:00          15          15       2016    15-Jul-0016 17:25:00
    16-Jul-2016 12:46:00    6.5183e+05    "NORTH CAROLINA"    "Thunderstorm Wind"       "0.00K"           "0.00K"         35.64        -82.14      35.64      -82.14    "Media reported two trees blown down along I-40 in the Old Fort area."                                                                                                                                  00:00:00          16          16       2016    16-Jul-0016 12:46:00
    15-Jul-2016 14:28:00    6.4332e+05    "MISSOURI"          "Hail"                    ""                ""              36.45        -89.97      36.45      -89.97    ""                                                                                                                                                                                                      00:07:00          15          15       2016    15-Jul-0016 14:35:00
    15-Jul-2016 16:31:00    6.4332e+05    "ARKANSAS"          "Thunderstorm Wind"       ""                "0.00K"         35.85         -90.1     35.838     -90.087    "A few tree limbs greater than 6 inches down on HWY 18 in Roseland."                                                                                                                                    00:09:00          15          15       2016    15-Jul-0016 16:40:00
    15-Jul-2016 16:03:00    6.4343e+05    "TENNESSEE"         "Thunderstorm Wind"       "20.00K"          "0.00K"        35.056       -89.937      35.05     -89.904    "Awning blown off a building on Lamar Avenue. Multiple trees down near the intersection of Winchester and Perkins."                                                                                     00:07:00          15          15       2016    15-Jul-0016 16:10:00
    15-Jul-2016 17:27:00    6.4344e+05    "TENNESSEE"         "Hail"                    ""                ""             35.385        -89.78     35.385      -89.78    "Quarter size hail near Rosemark."                                                                                                                                                                      00:05:00          15          15       2016    15-Jul-0016 17:32:00

予報内容が空の table の行を削除します。

idxEmpty = strlength(data.event_narrative) == 0;
data(idxEmpty,:) = [];

この例の目的は、event_type 列のラベルによって事象を分類することです。データをクラスに分割するために、これらのラベルを categorical に変換します。

data.event_type = categorical(data.event_type);

ヒストグラムを使用してデータのクラスの分布を表示します。ラベルを読み取りやすくするために、Figure の幅を大きくします。

f = figure;
f.Position(3) = 1.5*f.Position(3);

h = histogram(data.event_type);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

データのクラスは不均衡で、ほとんど観測値が含まれていないクラスが多数あります。クラスがこのように不均衡な場合、あまり正確でないモデルにネットワークが収束する可能性があります。この問題を防ぐために、出現回数が 10 回未満のクラスを削除します。

ヒストグラムからクラスの頻度数とクラス名を取得します。

classCounts = h.BinCounts;
classNames = h.Categories;

観測値が 10 個未満のクラスを見つけます。

idxLowCounts = classCounts < 10;
infrequentClasses = classNames(idxLowCounts)
infrequentClasses = 1×8 cell array
    {'Freezing Fog'}    {'Hurricane'}    {'Lakeshore Flood'}    {'Marine Dense Fog'}    {'Marine Strong Wind'}    {'Marine Tropical Depression'}    {'Seiche'}    {'Sneakerwave'}

頻度の少ないこれらのクラスをデータから削除します。removecats を使用して、categorical データから未使用のカテゴリを削除します。

idxInfrequent = ismember(data.event_type,infrequentClasses);
data(idxInfrequent,:) = [];
data.event_type = removecats(data.event_type);

これで、データは妥当なサイズのクラスに選別されます。次のステップは、これを学習セット、検証セット、およびテスト セットに分割することです。データを学習区画と、検証およびテスト用のホールドアウト区画に分割します。ホールドアウトの割合を 30% に指定します。

cvp = cvpartition(data.event_type,'Holdout',0.3);
dataTrain = data(training(cvp),:);
dataHeldOut = data(test(cvp),:);

検証セットを取得するためにホールドアウト セットをもう一度分割します。ホールドアウトの割合を 50% に指定します。この結果、学習用観測値 70%、検証用観測値 15%、およびテスト用観測値 15% に分割されます。

cvp = cvpartition(dataHeldOut.event_type,'HoldOut',0.5);
dataValidation = dataHeldOut(training(cvp),:);
dataTest = dataHeldOut(test(cvp),:);

分割した table からテキスト データとラベルを抽出します。

textDataTrain = dataTrain.event_narrative;
textDataValidation = dataValidation.event_narrative;
textDataTest = dataTest.event_narrative;
YTrain = dataTrain.event_type;
YValidation = dataValidation.event_type;
YTest = dataTest.event_type;

データが正しくインポートされたことを確認するために、ワード クラウドを使用して学習テキスト データを可視化します。

figure
wordcloud(textDataTrain);
title("Training Data")

テキスト データの前処理

学習データを前処理します。テキストを小文字に変換し、トークン化してから、句読点を消去します。単語埋め込みの適合が悪化する原因となるため、単語の語幹処理や削除は行いません。

textDataTrain = lower(textDataTrain);
documentsTrain = tokenizedDocument(textDataTrain);
documentsTrain = erasePunctuation(documentsTrain);

textDataValidation = lower(textDataValidation);
documentsValidation = tokenizedDocument(textDataValidation);
documentsValidation = erasePunctuation(documentsValidation);

前処理した学習ドキュメントを最初の数個表示します。

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

   (1,1)   7 tokens: large tree down between plantersville and nettleton
   (2,1)  37 tokens: one to two feet of deep standing water developed on a stre…
   (3,1)  13 tokens: nws columbia relayed a report of trees blown down along to…
   (4,1)  13 tokens: media reported two trees blown down along i40 in the old f…
   (5,1)  14 tokens: a few tree limbs greater than 6 inches down on hwy 18 in r…

ドキュメントのシーケンスへの変換

ドキュメントを LSTM ネットワークに入力するために、単語符号化を使用してドキュメントを数値インデックスのシーケンスに変換します。

単語符号化を作成するには、関数 wordEncoding を使用します。

enc = wordEncoding(documentsTrain);

次の変換ステップは、すべてのドキュメントが同じ長さになるようにパディングと切り捨てを行うことです。関数 trainingOptions には、入力シーケンスのパディングと切り捨てを自動的に行うオプションが用意されています。ただし、これらのオプションは、単語ベクトルのシーケンスにはあまり適していません。代わりに、シーケンスのパディングと切り捨てを手動で行います。単語ベクトルのシーケンスの "左パディング" と切り捨てを行うと、学習が改善される可能性があります。

ドキュメントのパディングと切り捨てを行うには、まず、ターゲット長さを選択し、それより長いドキュメントを切り捨て、それにより短いドキュメントの左パディングを行います。最良の結果を得るには、大量のデータを破棄せずにターゲット長さを短くする必要があります。適切なターゲット長さを求めるために、学習ドキュメントの長さのヒストグラムを表示します。

documentLengths = doclength(documentsTrain);
figure
histogram(documentLengths)
title("Document Lengths")
xlabel("Length")
ylabel("Number of Documents")

学習ドキュメントのほとんどは 75 トークン未満です。これを切り捨てとパディングのターゲット長さとして使用します。

doc2sequence を使用してドキュメントを数値インデックスのシーケンスに変換します。シーケンスの長さが 75 になるように切り捨てと左パディングを行うために、'Length' オプションを 75 に設定します。

XTrain = doc2sequence(enc,documentsTrain,'Length',75);
XTrain(1:5)
ans = 5×1 cell array
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

同じオプションを使用して検証ドキュメントをシーケンスに変換します。

XValidation = doc2sequence(enc,documentsValidation,'Length',75);

LSTM ネットワークの作成と学習

LSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを 1 に設定します。次に、次元が 100 で単語の数が単語符号化と同じ単語埋め込み層を含めます。さらに、LSTM 層を含め、隠れユニットの数を 180 に設定します。sequence-to-label 分類問題用の LSTM 層を使用するために、出力モードを 'last' に設定します。最後に、クラスの数と同じサイズの全結合層や、ソフトマックス層と分類層を追加します。

inputSize = 1;
embeddingDimension = 100;
numHiddenUnits = enc.NumWords;
hiddenSize = 180;
numClasses = numel(categories(YTrain));

layers = [ ...
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numHiddenUnits)
    lstmLayer(hiddenSize,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]
layers = 
  6x1 Layer array with layers:

     1   ''   Sequence Input          Sequence input with 1 dimensions
     2   ''   Word Embedding Layer    Word embedding layer with 100 dimensions and 16954 unique words
     3   ''   LSTM                    LSTM with 180 hidden units
     4   ''   Fully Connected         39 fully connected layer
     5   ''   Softmax                 softmax
     6   ''   Classification Output   crossentropyex

学習オプションを指定します。ソルバーを 'adam' に設定し、10 エポック学習させ、勾配しきい値を 1 に設定します。初期学習率を 0.01 に設定します。学習の進行状況を監視するには、'Plots' オプションを 'training-progress' に設定します。検証データを指定するには、'ValidationData' オプションを使用します。詳細出力を表示しないようにするには、'Verbose'false に設定します。

既定では、利用可能な GPU がある場合、trainNetwork は GPU を使用します (Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 GPU が必要)。そうでない場合は CPU が使用されます。実行環境を手動で指定するには、trainingOptions の名前と値のペアの引数 'ExecutionEnvironment' を使用します。CPU での学習にかかる時間は、GPU での学習よりも大幅に長くなる可能性があります。

options = trainingOptions('adam', ...
    'MaxEpochs',10, ...    
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'ValidationData',{XValidation,YValidation}, ...
    'Plots','training-progress', ...
    'Verbose',false);

関数 trainNetwork を使用して LSTM ネットワークに学習させます。

net = trainNetwork(XTrain,YTrain,layers,options);

LSTM ネットワークのテスト

LSTM ネットワークをテストするには、まず、学習データと同じようにテスト データを準備します。その後、学習済みの LSTM ネットワーク net を使用して、前処理済みのテスト データについて予測を行います。

学習ドキュメントと同じステップを使用してテスト データを前処理します。

textDataTest = lower(textDataTest);
documentsTest = tokenizedDocument(textDataTest);
documentsTest = erasePunctuation(documentsTest);

学習シーケンスの作成時と同じオプションで doc2sequence を使用して、テスト ドキュメントをシーケンスに変換します。

XTest = doc2sequence(enc,documentsTest,'Length',75);
XTest(1:5)
ans = 5×1 cell array
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}
    {1×75 double}

学習済みの LSTM ネットワークを使用してテスト ドキュメントを分類します。

YPred = classify(net,XTest);

分類精度を計算します。この精度は、ネットワークによって予測が正しく行われるラベルの割合です。

accuracy = sum(YPred == YTest)/numel(YPred)
accuracy = 0.8691

新しいデータを使用した予測

3 つの新しい天気予報の事象タイプを分類します。新しい天気予報を含む string 配列を作成します。

reportsNew = [ ...
    "Lots of water damage to computer equipment inside the office."
    "A large tree is downed and blocking traffic outside Apple Hill."
    "Damage to many car windshields in parking lot."];

学習ドキュメントと同じステップを使用してテキスト データを前処理します。

reportsNew = lower(reportsNew);
documentsNew = tokenizedDocument(reportsNew);
documentsNew = erasePunctuation(documentsNew);

学習シーケンスの作成時と同じオプションで doc2sequence を使用して、テキスト データをシーケンスに変換します。

XNew = doc2sequence(enc,documentsNew,'Length',75);

学習済みの LSTM ネットワークを使用して新しいシーケンスを分類します。

[labelsNew,score] = classify(net,XNew);

予測されたラベルと共に天気予報を表示します。

[reportsNew string(labelsNew)]
ans = 3×2 string array
    "lots of water damage to computer equipment inside the office."      "Flash Flood"      
    "a large tree is downed and blocking traffic outside apple hill."    "Thunderstorm Wind"
    "damage to many car windshields in parking lot."                     "Hail"             

参考

| | | | | | | |

関連するトピック