ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

カスタム ミニバッチ データストアを使用したメモリ外のテキスト データの分類

この例では、カスタム ミニバッチ データストアを使用して深層学習ネットワークでメモリ外のテキスト データを分類する方法を説明します。

"ミニバッチ データストア" とは、バッチ単位でのデータの読み取りをサポートするデータストアの実装です。ミニバッチ データストアは、深層学習アプリケーションの学習データセット、検証データセット、テスト データセット、および予測データセットのソースとして使用できます。ミニバッチ データストアを使用して、メモリ外のデータを読み取るか、データのバッチを読み取る際に特定の前処理演算を実行します。

ネットワークの学習時に、入力データのパディング、切り捨て、または分割が行われ、同じ長さのシーケンスのミニバッチが作成されます。関数 trainingOptions には、入力シーケンスのパディングと切り捨てを自動的に行うオプションが用意されていますが、これらのオプションは、単語ベクトルのシーケンスにはあまり適していません。さらに、この関数はカスタム データストアにあるデータのパディングをサポートしていません。代わりに、シーケンスのパディングと切り捨てを手動で行わなければなりません。単語ベクトルのシーケンスの "左パディング" と切り捨てを行うと、学習が改善される可能性があります。

Classify Text Data Using Deep Learning (Text Analytics Toolbox)では、すべてのドキュメントが同じ長さになるように切り捨てとパディングを行っています。このプロセスでは、非常に短いドキュメントに多数のパディングが追加され、非常に長いドキュメントから多数のデータが破棄されます。

別の方法として、大量のパディングの追加や大量のデータの破棄を避けるため、ミニバッチをネットワークに入力するカスタム ミニバッチ データストアを作成します。カスタム ミニバッチ データストア textDatastore.m は、ドキュメントのミニバッチをシーケンスまたは単語インデックスに変換し、ミニバッチ内にある最長のドキュメントの長さになるように各ミニバッチの左パディングを行います。並べ替えられたデータの場合、ドキュメントが固定長にパディングされないので、このデータストアは、データに追加されるパディングの量を減らすのに役立ちます。同様に、このデータストアはドキュメントからデータをまったく破棄しません。

この例では、カスタム ミニバッチ データストア textDatastore.m を使用します。関数をカスタマイズして、このデータストアをデータに適応させることができます。独自のカスタム ミニバッチ データストアを作成する方法を示す例については、カスタム ミニバッチ データストアの開発を参照してください。

事前学習済みの単語埋め込みの読み込み

データストア textDatastore では、ドキュメントをベクトルのシーケンスに変換するために単語埋め込みが必要です。fastTextWordEmbedding を使用して、事前学習済みの単語埋め込みを読み込みます。この関数には、Text Analytics Toolbox™ Model for fastText English 16 Billion Token Word Embedding サポート パッケージが必要です。このサポート パッケージがインストールされていない場合、関数によってダウンロード用リンクが表示されます。

emb = fastTextWordEmbedding;

ドキュメントのミニバッチ データストアの作成

学習用のデータを含むデータストアを作成します。カスタム ミニバッチ データストア textDatastore は CSV ファイルから予測子とラベルを読み取ります。予測子については、データストアはドキュメントを単語インデックスのシーケンスに変換し、応答については、データストアは各ドキュメントの categorical ラベルを返します。

このデータストアを作成するには、まず、カスタム ミニバッチ データストア textDatastore.m をパスに保存します。カスタム ミニバッチ データストアの作成の詳細は、カスタム ミニバッチ データストアの開発を参照してください。

学習データについては、CSV ファイル "weatherReportsTrain.csv" を指定し、テキストとラベルがそれぞれ列 "event_narrative""event_type" にあることを指定します。

filenameTrain = "weatherReportsTrain.csv";
textName = "event_narrative";
labelName = "event_type";
dsTrain = textDatastore(filenameTrain,textName,labelName,emb)
dsTrain = 
  textDatastore with properties:

          Datastore: [1×1 matlab.io.datastore.TabularTextDatastore]
           TextName: "event_narrative"
          LabelName: "event_type"
            Classes: [1×39 string]
         NumClasses: 39
          Embedding: [1×1 wordEmbedding]
      MiniBatchSize: 128
    NumObservations: 19683

同じステップを使用して CSV ファイル "weatherReportsValidation.csv" から検証データを含むデータストアを作成します。

filenameValidation = "weatherReportsValidation.csv";
dsValidation = textDatastore(filenameValidation,textName,labelName,emb)
dsValidation = 
  textDatastore with properties:

          Datastore: [1×1 matlab.io.datastore.TabularTextDatastore]
           TextName: "event_narrative"
          LabelName: "event_type"
            Classes: [1×39 string]
         NumClasses: 39
          Embedding: [1×1 wordEmbedding]
      MiniBatchSize: 128
    NumObservations: 4218

LSTM ネットワークの作成と学習

LSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを埋め込み次元に設定します。次に、LSTM 層を含め、隠れサイズを 180 に指定します。sequence-to-label 分類問題用の LSTM 層を使用するために、出力モードを 'last' に設定します。最後に、出力サイズがクラスの数に等しい全結合層や、ソフトマックス層と分類層を追加します。

inputSize = dsTrain.Embedding.Dimension;
hiddenSize = 180;
numClasses = dsTrain.NumClasses;

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(hiddenSize,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

学習オプションを指定します。ソルバーを 'adam' に設定し、勾配しきい値を 1 に設定します。初期学習率を 0.01 に設定します。データストア textDatastore.m はシャッフルをサポートしていないため、'Shuffle''never' に設定します (シャッフルをサポートするデータストアを実装する方法を示す例については、カスタム ミニバッチ データストアの開発を参照)。検証データを指定するには、'ValidationData' オプションを使用します。学習の進行状況を監視するには、'Plots' オプションを 'training-progress' に設定します。詳細出力を表示しないようにするには、'Verbose'false に設定します。

既定では、利用可能な GPU がある場合、trainNetwork は GPU を使用します (Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 GPU が必要)。そうでない場合は CPU が使用されます。実行環境を手動で指定するには、trainingOptions の名前と値のペアの引数 'ExecutionEnvironment' を使用します。CPU での学習にかかる時間は、GPU での学習よりも大幅に長くなる可能性があります。

options = trainingOptions('adam', ...
    'GradientThreshold',1, ...
    'InitialLearnRate',0.01, ...
    'Shuffle','never', ...
    'Plots','training-progress', ...
    'Verbose',false, ...
    'ValidationData',dsValidation);

関数 trainNetwork を使用して LSTM ネットワークに学習させます。

net = trainNetwork(dsTrain,layers,options);

LSTM ネットワークのテスト

ドキュメントとラベルを含むデータストアを作成します。

filenameTest = "weatherReportsTest.csv";
dsTest = textDatastore(filenameTest,textName,labelName,emb)
dsTest = 
  textDatastore with properties:

          Datastore: [1×1 matlab.io.datastore.TabularTextDatastore]
           TextName: "event_narrative"
          LabelName: "event_type"
            Classes: [1×39 string]
         NumClasses: 39
          Embedding: [1×1 wordEmbedding]
      MiniBatchSize: 128
    NumObservations: 4217

学習済みの LSTM ネットワークを使用してテスト ドキュメントを分類します。

YPred = classify(net,dsTest);

分類精度を計算します。この精度は、ネットワークによって予測が正しく行われるラベルの割合です。

YTest = readLabels(dsTest);
accuracy = sum(YPred == YTest)/numel(YPred)
accuracy = 0.8328

参考

| | | | | | | | | |

関連するトピック