メインコンテンツ

カスタム学習ループを使用したテキスト データの分類

この例では、カスタム学習ループのある深層学習の双方向長短期記憶 (BiLSTM) ネットワークを使用してテキスト データを分類する方法を説明します。

trainnet 関数を使用して深層学習ネットワークに学習させるときに、trainingOptions で必要なオプション (カスタム ソルバーなど) が提供されない場合は、自動微分を使用して独自のカスタム学習ループを定義できます。関数 trainnet を使用してテキスト データを分類する方法を示す例については、深層学習を使用したテキスト データの分類 (Deep Learning Toolbox)を参照してください。

この例では、確率的勾配降下アルゴリズム (モーメンタムなし) を使用してテキスト データを分類するためにネットワークに学習させます。

データのインポート

工場レポートのデータをインポートします。このデータには、出荷時のイベントを説明するラベル付きテキストが含まれています。テキスト データを string としてインポートするために、テキスト タイプを "string" に指定します。

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

この例の目的は、Category 列のラベルによって事象を分類することです。データをクラスに分割するために、これらのラベルを categorical に変換します。

data.Category = categorical(data.Category);

ヒストグラムを使用してデータのクラスの分布を表示します。

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

次の手順は、これを学習セットと検証セットに分割することです。データを学習区画と、検証およびテスト用のホールドアウト区画に分割します。ホールドアウトの割合を 20% に指定します。

cvp = cvpartition(data.Category,Holdout=0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

分割した table からテキスト データとラベルを抽出します。

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
TTrain = dataTrain.Category;
TValidation = dataValidation.Category;

データが正しくインポートされたことを確認するために、ワード クラウドを使用して学習テキスト データを可視化します。

figure
wordcloud(textDataTrain);
title("Training Data")

クラス数を表示します。

classes = categories(TTrain);
numClasses = numel(classes)
numClasses = 
4

テキスト データの前処理

テキスト データをトークン化および前処理する関数を作成します。例の最後にリストされている関数 preprocessText は以下のステップを実行します。

  1. tokenizedDocument を使用してテキストをトークン化する。

  2. lower を使用してテキストを小文字に変換する。

  3. erasePunctuation を使用して句読点を消去する。

関数 preprocessText を使用して学習データと検証データを前処理する。

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

前処理した学習ドキュメントを最初の数個表示します。

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: there are cuts to the power when starting the plant
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses
     8 tokens: things continue to tumble off of the belt

arrayDatastore オブジェクトを作成してから、combine 関数を使用してそれらを結合することによって、学習ドキュメントとラベルの両方を含む単一のデータストアを作成します。

dsDocumentsTrain = arrayDatastore(documentsTrain,OutputType="cell");
dsTTrain = arrayDatastore(TTrain,OutputType="cell");
dsTrain = combine(dsDocumentsTrain,dsTTrain);

同じ手順を使用して、検証データ用のデータストアを作成します。

dsDocumentsValidation = arrayDatastore(documentsValidation,OutputType="cell");
dsTValidation = arrayDatastore(TValidation,OutputType="cell");
dsValidation = combine(dsDocumentsValidation,dsTValidation);

単語符号化の作成

ドキュメントを BiLSTM ネットワークに入力するために、単語符号化を使用してドキュメントを数値インデックスのシーケンスに変換します。

単語符号化を作成するには、関数 wordEncoding を使用します。

enc = wordEncoding(documentsTrain)
enc = 
  wordEncoding with properties:

      NumWords: 425
    Vocabulary: ["items"    "are"    "occasionally"    "getting"    "stuck"    "in"    "the"    "scanner"    "spools"    "there"    "cuts"    "to"    "power"    "when"    "starting"    "plant"    "fried"    "capacitors"    …    ] (1×425 string)

ネットワークの定義

BiLSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを 1 に設定します。次に、25 次元の単語埋め込み層と、同じ数の単語の単語符号化を含めます。次に、BiLSTM 層を含め、隠れユニット数を 40 に設定します。sequence-to-label 分類問題に BiLSTM 層を使用するには、出力モードを "last" に設定します。最後に、クラスの数と同じサイズの全結合層と、ソフトマックス層を追加します。

inputSize = 1;
embeddingDimension = 25;
numHiddenUnits = 40;

numWords = enc.NumWords;

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]
layers = 
  5×1 Layer array with layers:

     1   ''   Sequence Input         Sequence input with 1 dimensions
     2   ''   Word Embedding Layer   Word embedding layer with 25 dimensions and 425 unique words
     3   ''   BiLSTM                 BiLSTM with 40 hidden units
     4   ''   Fully Connected        4 fully connected layer
     5   ''   Softmax                softmax

層配列を dlnetwork オブジェクトに変換します。

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [5×1 nnet.cnn.layer.Layer]
    Connections: [4×2 table]
     Learnables: [6×3 table]
          State: [2×3 table]
     InputNames: {'sequenceinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

モデル損失関数の定義

モデル損失関数を作成します。関数 modelLoss は、dlnetwork オブジェクト net、入力データ X のミニバッチとそれに対応するターゲット ラベル T を受け取り、net 内の学習可能なパラメーターについての損失の勾配、および損失を返します。勾配を自動的に計算するには、関数 dlgradient を使用します。

function [loss,gradients] = modelLoss(net,X,T)

Y = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

SGD 関数の定義

パラメーターと、パラメーターについての損失の勾配を受け取り、確率的勾配降下アルゴリズムを使用して更新されたパラメーターを返す関数 sgdStep を作成します。このアルゴリズムは θt+1=θt-ρL として表現されます。ここで、t は反復回数、ρ は学習率、L は勾配 (学習可能なパラメーターについての損失の微分) を表します。

function parameters = sgdStep(parameters,gradients,learnRate)

parameters = parameters - learnRate .* gradients;

end

カスタム更新関数の定義は、カスタム学習ループに必要な手順ではありません。あるいは、sgdmupdate (Deep Learning Toolbox)adamupdate (Deep Learning Toolbox)、およびrmspropupdate (Deep Learning Toolbox)などの組み込み更新関数を使用することもできます。

学習オプションの指定

ミニバッチ サイズを 64、学習率を 0.1 として、300 エポック学習させます。

numEpochs = 300;
miniBatchSize = 64;
learnRate = 0.1;

モデルの学習

データのミニバッチを処理および管理する minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用し、ドキュメントをシーケンスに変換して、ラベルを one-hot 符号化します。単語符号化をミニバッチに渡すには、2 つの入力を受け取る無名関数を作成します。

  • 予測子とターゲットを、それぞれ次元ラベル "BTC" (バッチ、時間、チャネル) と次元ラベル "CB" (チャネル、バッチ) で形式を整えます。minibatchqueue オブジェクトは、既定では、基となる型が singledlarray オブジェクトにデータを変換します。

  • GPU が利用できる場合、GPU で学習を行います。minibatchqueue オブジェクトは、既定では、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@(X,T) preprocessMiniBatch(X,T,enc), ...
    MiniBatchFormat=["BTC" "CB"]);

同じ手順を使用して、検証ドキュメント用の minibatchqueue オブジェクトを作成し、部分的なミニバッチも返すように指定します。

mbqValidation = minibatchqueue(dsValidation, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(X,T) preprocessMiniBatch(X,T,enc), ...
    MiniBatchFormat=["BTC" "CB"], ...
    PartialMiniBatch="return");

学習の進行状況モニター用に合計反復回数を計算します。

numObservationsTrain = numel(documentsTrain);
numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize);
numIterations = numEpochs * numIterationsPerEpoch;

TrainingProgressMonitor オブジェクトを初期化します。監視オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics=["TrainingLoss","ValidationLoss"], ...
    Info="Epoch", ...
    XLabel="Iteration");

groupSubPlot(monitor,"Loss",["TrainingLoss" "ValidationLoss"])

ネットワークに学習をさせます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各エポックの最後に、検証データを使用してネットワークを検証します。

各ミニバッチで次を行います。

  • ドキュメントを整数のシーケンスに変換し、ラベルを one-hot 符号化します。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価します。

  • カスタム更新関数と dlupdate 関数を使用して、ネットワーク パラメーターを更新します。

  • 学習プロットを更新します。

  • モニターの Stop プロパティが true の場合は停止します。[停止] ボタンをクリックすると、TrainingProgressMonitor オブジェクトの Stop プロパティ値が true に変わります。

epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(mbq);

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using SGD.
        updateFcn = @(parameters,gradients) sgdStep(parameters,gradients,learnRate);
        net = dlupdate(updateFcn,net,gradients);

        % Display the training progress.
        recordMetrics(monitor,iteration,TrainingLoss=loss);
        updateInfo(monitor,Epoch=(epoch+" of "+numEpochs));

        % Validate network.
        if iteration == 1 || ~hasdata(mbq)
            lossValidation = testnet(net,mbqValidation,"crossentropy");

            % Update plot.
            recordMetrics(monitor,iteration,ValidationLoss=lossValidation);
        end

        monitor.Progress = 100*iteration/numIterations;
    end
end

モデルのテスト

testnet (Deep Learning Toolbox)関数を使用してニューラル ネットワークをテストします。単一ラベルの分類では、精度を評価します。精度は、正しい予測の割合です。既定では、testnet 関数は利用可能な GPU がある場合にそれを使用します。そうでない場合、関数は CPU を使用します。実行環境を手動で選択するには、testnet 関数の ExecutionEnvironment 引数を使用します。

accuracy = testnet(net,mbqValidation,"accuracy")
accuracy = 
92.7083

新しいデータを使用した予測

3 つの新しいレポートの事象タイプを分類します。新しいレポートを含む string 配列を作成します。

reportsNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

学習ドキュメントと同じ前処理手順を使用してテキスト データを前処理します。

documentsNew = preprocessText(reportsNew);
dsNew = arrayDatastore(documentsNew,OutputType="cell");

データのミニバッチを処理および管理する minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatchPredictors (この例の最後に定義) を使用し、ドキュメントをシーケンスに変換します。この前処理関数はラベル データを必要としません。単語符号化をミニバッチに渡すには、入力を 1 つだけ受け取る無名関数を作成します。

  • 次元ラベル "BTC" (batch、time、channel) を使用して予測子の形式を整えます。minibatchqueue オブジェクトは、既定では、基となる型が singledlarray オブジェクトにデータを変換します。

  • すべての観測値に対して予測を行うには、部分的なミニバッチを返します。

mbqNew = minibatchqueue(dsNew, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(X) preprocessMiniBatchPredictors(X,enc), ...
    MiniBatchFormat="BTC", ...
    PartialMiniBatch="return");

minibatchpredict (Deep Learning Toolbox)関数を使用して予測を行い、scores2label (Deep Learning Toolbox)関数を使用して分類スコアをラベルに変換します。既定では、minibatchpredict 関数は利用可能な GPU がある場合にそれを使用します。そうでない場合、関数は CPU を使用します。実行環境を手動で選択するには、minibatchpredict 関数の ExecutionEnvironment 引数を使用します。

scores = minibatchpredict(net,mbqNew);
YNew = scores2label(scores,classes)'
YNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

サポート関数

テキスト前処理関数

関数 preprocessText は以下のステップを実行します。

  1. tokenizedDocument を使用してテキストをトークン化する。

  2. lower を使用してテキストを小文字に変換する。

  3. erasePunctuation を使用して句読点を消去する。

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、ドキュメントのミニバッチを整数のシーケンスに変換し、ラベル データを one-hot 符号化します。

function [X,T] = preprocessMiniBatch(dataX,dataT,enc)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX,enc);

% Extract labels from cell and concatenate.
T = cat(1,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,2);

% Transpose the encoded labels to match the network output.
T = T';

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、ドキュメントのミニバッチを整数のシーケンスに変換します。

function X = preprocessMiniBatchPredictors(dataX,enc)

% Extract documents from cell and concatenate.
documents = cat(4,dataX{1:end});

% Convert documents to sequences of integers.
X = doc2sequence(enc,documents);
X = cat(1,X{:});

end

参考

| | (Deep Learning Toolbox) | | (Deep Learning Toolbox) | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) |

トピック