Main Content

カスタム学習ループを使用したテキスト データの分類

この例では、カスタム学習ループのある深層学習の双方向長短期記憶 (BiLSTM) ネットワークを使用してテキスト データを分類する方法を説明します。

関数 trainNetwork を使用して深層学習ネットワークの学習を行う場合、必要なオプション (たとえば、カスタム学習率スケジュール) が関数 trainingOptions に用意されていなければ、自動微分を使用して独自のカスタム学習ループを定義できます。関数 trainNetwork を使用してテキスト データを分類する方法を示す例については、深層学習を使用したテキスト データの分類を参照してください。

この例では、"時間ベースの減衰" 学習率スケジュールでテキスト データを分類するようにネットワークに学習させます。各反復で、ソルバーは ρt=ρ01+kt によって与えられる学習率を使用します。ここで、t は反復回数、ρ0 は初期学習率、k は減衰です。

データのインポート

工場レポートのデータをインポートします。このデータには、出荷時のイベントを説明するラベル付きテキストが含まれています。テキスト データを string としてインポートするために、テキスト タイプを "string" に指定します。

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
ans=8×5 table
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

この例の目的は、Category 列のラベルによって事象を分類することです。データをクラスに分割するために、これらのラベルを categorical に変換します。

data.Category = categorical(data.Category);

ヒストグラムを使用してデータのクラスの分布を表示します。

figure
histogram(data.Category);
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

次の手順は、これを学習セットと検証セットに分割することです。データを学習区画と、検証およびテスト用のホールドアウト区画に分割します。ホールドアウトの割合を 20% に指定します。

cvp = cvpartition(data.Category,Holdout=0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

分割した table からテキスト データとラベルを抽出します。

textDataTrain = dataTrain.Description;
textDataValidation = dataValidation.Description;
TTrain = dataTrain.Category;
TValidation = dataValidation.Category;

データが正しくインポートされたことを確認するために、ワード クラウドを使用して学習テキスト データを可視化します。

figure
wordcloud(textDataTrain);
title("Training Data")

クラス数を表示します。

classes = categories(TTrain);
numClasses = numel(classes)
numClasses = 4

テキスト データの前処理

テキスト データをトークン化および前処理する関数を作成します。例の最後にリストされている関数 preprocessText は以下のステップを実行します。

  1. tokenizedDocument を使用してテキストをトークン化する。

  2. lower を使用してテキストを小文字に変換する。

  3. erasePunctuation を使用して句読点を消去する。

関数 preprocessText を使用して学習データと検証データを前処理する。

documentsTrain = preprocessText(textDataTrain);
documentsValidation = preprocessText(textDataValidation);

前処理した学習ドキュメントを最初の数個表示します。

documentsTrain(1:5)
ans = 
  5×1 tokenizedDocument:

     9 tokens: items are occasionally getting stuck in the scanner spools
    10 tokens: loud rattling and banging sounds are coming from assembler pistons
     5 tokens: fried capacitors in the assembler
     4 tokens: mixer tripped the fuses
     9 tokens: burst pipe in the constructing agent is spraying coolant

arrayDatastore オブジェクトを作成してから、関数 combine を使用してそれらを結合することによって、ドキュメントとラベルの両方を含む単のデータストアを作成します。

dsDocumentsTrain = arrayDatastore(documentsTrain,OutputType="cell");
dsTTrain = arrayDatastore(TTrain,OutputType="cell");
dsTrain = combine(dsDocumentsTrain,dsTTrain);

検証ドキュメント用の配列データストアを作成します。

dsDocumentsValidation = arrayDatastore(documentsValidation,OutputType="cell");

単語符号化の作成

ドキュメントを BiLSTM ネットワークに入力するために、単語符号化を使用してドキュメントを数値インデックスのシーケンスに変換します。

単語符号化を作成するには、関数 wordEncoding を使用します。

enc = wordEncoding(documentsTrain)
enc = 
  wordEncoding with properties:

      NumWords: 417
    Vocabulary: ["items"    "are"    "occasionally"    "getting"    "stuck"    "in"    "the"    "scanner"    "spools"    "loud"    "rattling"    "and"    "banging"    "sounds"    "coming"    "from"    "assembler"    "pistons"    "fried"    …    ]

ネットワークの定義

BiLSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを 1 に設定します。次に、25 次元の単語埋め込み層と、同じ数の単語の単語符号化を含めます。次に、BiLSTM 層を含め、隠れユニット数を 40 に設定します。sequence-to-label 分類問題に BiLSTM 層を使用するには、出力モードを "last" に設定します。最後に、クラスの数と同じサイズの全結合層と、ソフトマックス層を追加します。

inputSize = 1;
embeddingDimension = 25;
numHiddenUnits = 40;

numWords = enc.NumWords;

layers = [
    sequenceInputLayer(inputSize)
    wordEmbeddingLayer(embeddingDimension,numWords)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer]
layers = 
  5×1 Layer array with layers:

     1   ''   Sequence Input         Sequence input with 1 dimensions
     2   ''   Word Embedding Layer   Word embedding layer with 25 dimensions and 417 unique words
     3   ''   BiLSTM                 BiLSTM with 40 hidden units
     4   ''   Fully Connected        4 fully connected layer
     5   ''   Softmax                softmax

層配列を dlnetwork オブジェクトに変換します。

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [5×1 nnet.cnn.layer.Layer]
    Connections: [4×2 table]
     Learnables: [6×3 table]
          State: [2×3 table]
     InputNames: {'sequenceinput'}
    OutputNames: {'softmax'}
    Initialized: 1

モデル損失関数の定義

例の最後にリストされている関数 modelLoss を作成します。この関数は、dlnetwork オブジェクト、入力データのミニバッチとそれに対応するラベルを受け取り、ネットワーク内の学習可能パラメーターについての損失および損失の勾配を返します。

学習オプションの指定

ミニバッチ サイズを 16 として 30 エポック学習させます。

numEpochs = 30;
miniBatchSize = 16;

Adam 最適化のオプションを指定します。初期学習率を 0.001、減衰を 0.01、勾配の減衰係数を 0.9、2 乗勾配の減衰係数を 0.999 に指定します。

initialLearnRate = 0.001;
decay = 0.01;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;

モデルの学習

データのミニバッチを処理および管理する minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用し、ドキュメントをシーケンスに変換して、ラベルを one-hot 符号化します。単語符号化をミニバッチに渡すには、2 つの入力を受け取る無名関数を作成します。

  • 次元ラベル "BTC" (batch、time、channel) を使用して予測子を書式設定します。minibatchqueue オブジェクトは、既定では、基となる型が singledlarray オブジェクトにデータを変換します。

  • GPU が利用できる場合、GPU で学習を行います。minibatchqueue オブジェクトは、既定では、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

mbq = minibatchqueue(dsTrain, ...
    MiniBatchSize=miniBatchSize,...
    MiniBatchFcn=@(X,T) preprocessMiniBatch(X,T,enc), ...
    MiniBatchFormat=["BTC" ""]);

検証ドキュメント用の minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatchPredictors (この例の最後に定義) を使用し、ドキュメントをシーケンスに変換します。この前処理関数はラベル データを必要としません。単語符号化をミニバッチに渡すには、入力を 1 つだけ受け取る無名関数を作成します。

  • 次元ラベル "BTC" (batch、time、channel) を使用して予測子を書式設定します。minibatchqueue オブジェクトは、既定では、基となる型が singledlarray オブジェクトにデータを変換します。

  • すべての観測値に対して予測を行うには、部分的なミニバッチを返します。

mbqValidation = minibatchqueue(dsDocumentsValidation, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(X) preprocessMiniBatchPredictors(X,enc), ...
    MiniBatchFormat="BTC", ...
    PartialMiniBatch="return");

検証損失の計算を簡単にするため、検証ラベルを one-hot 符号化されたベクトルに変換し、符号化されたラベルを転置し、ネットワークの出力形式に適合させます。

TValidation = onehotencode(TValidation,2);
TValidation = TValidation';

学習の進行状況プロットを初期化します。

figure
C = colororder;
lineLossTrain = animatedline(Color=C(2,:));

lineLossValidation = animatedline( ...
    LineStyle="--", ...
    Marker="o", ...
    MarkerFaceColor="black");

ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Adam のパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

ネットワークに学習をさせます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各エポックの最後に、検証データを使用してネットワークを検証します。

各ミニバッチで次を行います。

  • ドキュメントを整数のシーケンスに変換し、ラベルを one-hot 符号化します。

  • 基となる型が single の dlarray オブジェクトにデータを変換し、次元ラベル "BTC" (batch、time、channel) を指定。

  • GPU で学習する場合、gpuArray オブジェクトに変換。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価。

  • 時間ベースの減衰学習率スケジュールの学習率を決定。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新します。

  • 学習プロットを更新。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs

    % Shuffle data.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq)
        iteration = iteration + 1;

        % Read mini-batch of data.
        [X,T] = next(mbq);

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Determine learning rate for time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);

        % Update the network parameters using the Adam optimizer.
        [net,trailingAvg,trailingAvgSq] = adamupdate(net, gradients, ...
            trailingAvg, trailingAvgSq, iteration, learnRate, ...
            gradientDecayFactor, squaredGradientDecayFactor);

        % Display the training progress.
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        loss = double(loss);
        addpoints(lineLossTrain,iteration,loss)
        title("Epoch: " + epoch + ", Elapsed: " + string(D))
        drawnow

        % Validate network.
        if iteration == 1 || ~hasdata(mbq)
            [~,scoresValidation] = modelPredictions(net,mbqValidation,classes);
            lossValidation = crossentropy(scoresValidation,TValidation);

            % Update plot.
            lossValidation = double(lossValidation);
            addpoints(lineLossValidation,iteration,lossValidation)
            drawnow
        end
    end
end

モデルのテスト

真のラベルをもつ検証セットで予測を比較し、モデルの分類精度をテストします。

例の最後にリストされている関数 modelPredictions を使用し、検証データを分類します。

YNew = modelPredictions(net,mbqValidation,classes);

検証精度の計算を簡単にするため、one-hot 符号化された検証ラベルを categorical に変換して転置します。

TValidation = onehotdecode(TValidation,classes,1)';

分類精度を評価します。

accuracy = mean(YNew == TValidation)
accuracy = 0.8854

新しいデータを使用した予測

3 つの新しいレポートの事象タイプを分類します。新しいレポートを含む string 配列を作成します。

reportsNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

学習ドキュメントと同じ前処理手順を使用してテキスト データを前処理します。

documentsNew = preprocessText(reportsNew);
dsNew = arrayDatastore(documentsNew,OutputType="cell");

データのミニバッチを処理および管理する minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatchPredictors (この例の最後に定義) を使用し、ドキュメントをシーケンスに変換します。この前処理関数はラベル データを必要としません。単語符号化をミニバッチに渡すには、入力を 1 つだけ受け取る無名関数を作成します。

  • 次元ラベル "BTC" (batch、time、channel) を使用して予測子を書式設定します。minibatchqueue オブジェクトは、既定では、基となる型が singledlarray オブジェクトにデータを変換します。

  • すべての観測値に対して予測を行うには、部分的なミニバッチを返します。

mbqNew = minibatchqueue(dsNew, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=@(X) preprocessMiniBatchPredictors(X,enc), ...
    MiniBatchFormat="BTC", ...
    PartialMiniBatch="return");

例の最後にリストされている関数 modelPredictions を使用してテキスト データを分類し、スコアが最も高いクラスを見つけます。

YNew = modelPredictions(net,mbqNew,classes)
YNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

サポート関数

テキスト前処理関数

関数 preprocessText は以下のステップを実行します。

  1. tokenizedDocument を使用してテキストをトークン化する。

  2. lower を使用してテキストを小文字に変換する。

  3. erasePunctuation を使用して句読点を消去する。

function documents = preprocessText(textData)

% Tokenize the text.
documents = tokenizedDocument(textData);

% Convert to lowercase.
documents = lower(documents);

% Erase punctuation.
documents = erasePunctuation(documents);

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、ドキュメントのミニバッチを整数のシーケンスに変換し、ラベル データを one-hot 符号化します。

function [X,T] = preprocessMiniBatch(dataX,dataT,enc)

% Preprocess predictors.
X = preprocessMiniBatchPredictors(dataX,enc);

% Extract labels from cell and concatenate.
T = cat(1,dataT{1:end});

% One-hot encode labels.
T = onehotencode(T,2);

% Transpose the encoded labels to match the network output.
T = T';

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、ドキュメントのミニバッチを整数のシーケンスに変換します。

function X = preprocessMiniBatchPredictors(dataX,enc)

% Extract documents from cell and concatenate.
documents = cat(4,dataX{1:end});

% Convert documents to sequences of integers.
X = doc2sequence(enc,documents);
X = cat(1,X{:});

end

モデル損失関数

関数 modelLoss は、dlnetwork オブジェクト net、入力データ X のミニバッチとそれに対応するターゲット ラベル T を受け取り、net 内の学習可能パラメーターについての損失の勾配、および損失を返します。勾配を自動的に計算するには、関数 dlgradient を使用します。

function [loss,gradients] = modelLoss(net,X,T)

Y = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

モデル予測関数

関数 modelPredictions は、dlnetwork オブジェクト net およびミニバッチ キューを受け取り、キュー内のミニバッチを反復することによりモデルの予測とスコアを出力します。

function [predictions,scores] = modelPredictions(net,mbq,classes)

% Initialize predictions.
predictions = [];
scores = [];

% Reset mini-batch queue.
reset(mbq);

% Loop over mini-batches.
while hasdata(mbq)

    % Make predictions.
    X = next(mbq);
    Y = predict(net,X);

    scores = [scores Y];

    Y = onehotdecode(Y,classes,1)';
    predictions = [predictions; Y];
end

end

参考

(Text Analytics Toolbox) | (Text Analytics Toolbox) | | (Text Analytics Toolbox) | | (Text Analytics Toolbox) | | |

関連するトピック