カスタム学習ループを使用したテキスト データの分類
この例では、カスタム学習ループのある深層学習の双方向長短期記憶 (BiLSTM) ネットワークを使用してテキスト データを分類する方法を説明します。
関数 trainNetwork
を使用して深層学習ネットワークの学習を行う場合、必要なオプション (たとえば、カスタム学習率スケジュール) が関数 trainingOptions
に用意されていなければ、自動微分を使用して独自のカスタム学習ループを定義できます。関数 trainNetwork
を使用してテキスト データを分類する方法を示す例については、深層学習を使用したテキスト データの分類を参照してください。
この例では、"時間ベースの減衰" 学習率スケジュールでテキスト データを分類するようにネットワークに学習させます。各反復で、ソルバーは によって与えられる学習率を使用します。ここで、t は反復回数、 は初期学習率、k は減衰です。
データのインポート
工場レポートのデータをインポートします。このデータには、出荷時のイベントを説明するラベル付きテキストが含まれています。テキスト データを string としてインポートするために、テキスト タイプを "string"
に指定します。
filename = "factoryReports.csv"; data = readtable(filename,TextType="string"); head(data)
Description Category Urgency Resolution Cost _____________________________________________________________________ ____________________ ________ ____________________ _____ "Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45 "Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35 "There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200 "Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352 "Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55 "Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371 "A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441 "Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
この例の目的は、Category
列のラベルによって事象を分類することです。データをクラスに分割するために、これらのラベルを categorical に変換します。
data.Category = categorical(data.Category);
ヒストグラムを使用してデータのクラスの分布を表示します。
figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")
次の手順は、これを学習セットと検証セットに分割することです。データを学習区画と、検証およびテスト用のホールドアウト区画に分割します。ホールドアウトの割合を 20% に指定します。
cvp = cvpartition(data.Category,Holdout=0.2); dataTrain = data(training(cvp),:); dataValidation = data(test(cvp),:);
分割した table からテキスト データとラベルを抽出します。
textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; TTrain = dataTrain.Category; TValidation = dataValidation.Category;
データが正しくインポートされたことを確認するために、ワード クラウドを使用して学習テキスト データを可視化します。
figure
wordcloud(textDataTrain);
title("Training Data")
クラス数を表示します。
classes = categories(TTrain); numClasses = numel(classes)
numClasses = 4
テキスト データの前処理
テキスト データをトークン化および前処理する関数を作成します。例の最後にリストされている関数 preprocessText
は以下のステップを実行します。
tokenizedDocument
を使用してテキストをトークン化する。lower
を使用してテキストを小文字に変換する。erasePunctuation
を使用して句読点を消去する。
関数 preprocessText
を使用して学習データと検証データを前処理する。
documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);
前処理した学習ドキュメントを最初の数個表示します。
documentsTrain(1:5)
ans = 5×1 tokenizedDocument: 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 10 tokens: there are cuts to the power when starting the plant 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses
arrayDatastore
オブジェクトを作成してから、関数 combine
を使用してそれらを結合することによって、ドキュメントとラベルの両方を含む単のデータストアを作成します。
dsDocumentsTrain = arrayDatastore(documentsTrain,OutputType="cell"); dsTTrain = arrayDatastore(TTrain,OutputType="cell"); dsTrain = combine(dsDocumentsTrain,dsTTrain);
検証ドキュメント用の配列データストアを作成します。
dsDocumentsValidation = arrayDatastore(documentsValidation,OutputType="cell");
単語符号化の作成
ドキュメントを BiLSTM ネットワークに入力するために、単語符号化を使用してドキュメントを数値インデックスのシーケンスに変換します。
単語符号化を作成するには、関数 wordEncoding
を使用します。
enc = wordEncoding(documentsTrain)
enc = wordEncoding with properties: NumWords: 424 Vocabulary: ["items" "are" "occasionally" "getting" "stuck" "in" "the" "scanner" "spools" "loud" "rattling" "and" "banging" "sounds" "coming" "from" "assembler" "pistons" "there" … ]
ネットワークの定義
BiLSTM ネットワーク アーキテクチャを定義します。シーケンス データをネットワークに入力するために、シーケンス入力層を含め、入力サイズを 1 に設定します。次に、25 次元の単語埋め込み層と、同じ数の単語の単語符号化を含めます。次に、BiLSTM 層を含め、隠れユニット数を 40 に設定します。sequence-to-label 分類問題に BiLSTM 層を使用するには、出力モードを "last"
に設定します。最後に、クラスの数と同じサイズの全結合層と、ソフトマックス層を追加します。
inputSize = 1;
embeddingDimension = 25;
numHiddenUnits = 40;
numWords = enc.NumWords;
layers = [
sequenceInputLayer(inputSize)
wordEmbeddingLayer(embeddingDimension,numWords)
bilstmLayer(numHiddenUnits,OutputMode="last")
fullyConnectedLayer(numClasses)
softmaxLayer]
layers = 5×1 Layer array with layers: 1 '' Sequence Input Sequence input with 1 dimensions 2 '' Word Embedding Layer Word embedding layer with 25 dimensions and 424 unique words 3 '' BiLSTM BiLSTM with 40 hidden units 4 '' Fully Connected 4 fully connected layer 5 '' Softmax softmax
層配列を dlnetwork
オブジェクトに変換します。
net = dlnetwork(layers)
net = dlnetwork with properties: Layers: [5×1 nnet.cnn.layer.Layer] Connections: [4×2 table] Learnables: [6×3 table] State: [2×3 table] InputNames: {'sequenceinput'} OutputNames: {'softmax'} Initialized: 1 View summary with summary.
モデル損失関数の定義
例の最後にリストされている関数 modelLoss
を作成します。この関数は、dlnetwork
オブジェクト、入力データのミニバッチとそれに対応するラベルを受け取り、ネットワーク内の学習可能パラメーターについての損失および損失の勾配を返します。
学習オプションの指定
ミニバッチ サイズを 16 として 30 エポック学習させます。
numEpochs = 30; miniBatchSize = 16;
Adam 最適化のオプションを指定します。初期学習率を 0.001、減衰を 0.01、勾配の減衰係数を 0.9、2 乗勾配の減衰係数を 0.999 に指定します。
initialLearnRate = 0.001; decay = 0.01; gradientDecayFactor = 0.9; squaredGradientDecayFactor = 0.999;
モデルの学習
データのミニバッチを処理および管理する minibatchqueue
オブジェクトを作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用し、ドキュメントをシーケンスに変換して、ラベルを one-hot 符号化します。単語符号化をミニバッチに渡すには、2 つの入力を受け取る無名関数を作成します。次元ラベル
"BTC"
(batch、time、channel) を使用して予測子を書式設定します。minibatchqueue
オブジェクトは、既定では、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。GPU が利用できる場合、GPU で学習を行います。
minibatchqueue
オブジェクトは、既定では、GPU が利用可能な場合、各出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
mbq = minibatchqueue(dsTrain, ... MiniBatchSize=miniBatchSize,... MiniBatchFcn=@(X,T) preprocessMiniBatch(X,T,enc), ... MiniBatchFormat=["BTC" ""]);
検証ドキュメント用の minibatchqueue
オブジェクトを作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatchPredictors
(この例の最後に定義) を使用し、ドキュメントをシーケンスに変換します。この前処理関数はラベル データを必要としません。単語符号化をミニバッチに渡すには、入力を 1 つだけ受け取る無名関数を作成します。次元ラベル
"BTC"
(batch、time、channel) を使用して予測子を書式設定します。minibatchqueue
オブジェクトは、既定では、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。すべての観測値に対して予測を行うには、部分的なミニバッチを返します。
mbqValidation = minibatchqueue(dsDocumentsValidation, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=@(X) preprocessMiniBatchPredictors(X,enc), ... MiniBatchFormat="BTC", ... PartialMiniBatch="return");
検証損失の計算を簡単にするため、検証ラベルを one-hot 符号化されたベクトルに変換し、符号化されたラベルを転置し、ネットワークの出力形式に適合させます。
TValidation = onehotencode(TValidation,2); TValidation = TValidation';
Adam のパラメーターを初期化します。
trailingAvg = []; trailingAvgSq = [];
学習の進行状況モニター用に合計反復回数を計算します。
numObservationsTrain = numel(documentsTrain); numIterationsPerEpoch = ceil(numObservationsTrain / miniBatchSize); numIterations = numEpochs * numIterationsPerEpoch;
TrainingProgressMonitor
オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor( ... Metrics=["TrainingLoss","ValidationLoss"], ... Info=["Epoch","LearnRate"], ... XLabel="Iteration"); groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"])
ネットワークに学習をさせます。各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。反復が終了するたびに、学習の進行状況を表示します。各エポックの最後に、検証データを使用してネットワークを検証します。
各ミニバッチで次を行います。
ドキュメントを整数のシーケンスに変換し、ラベルを one-hot 符号化します。
基となる型が single の
dlarray
オブジェクトにデータを変換し、次元ラベル"BTC"
(batch、time、channel) を指定。GPU で学習する場合、
gpuArray
オブジェクトに変換。関数
dlfeval
およびmodelLoss
を使用してモデルの損失と勾配を評価。時間ベースの減衰学習率スケジュールの学習率を決定。
関数
adamupdate
を使用してネットワーク パラメーターを更新。学習プロットを更新。
Stop プロパティが true の場合は停止。[停止] ボタンをクリックすると、
TrainingProgressMonitor
オブジェクトの Stop プロパティ値が true に変わります。
epoch = 0; iteration = 0; % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) && ~monitor.Stop iteration = iteration + 1; % Read mini-batch of data. [X,T] = next(mbq); % Evaluate the model loss and gradients using dlfeval and the % modelLoss function. [loss,gradients] = dlfeval(@modelLoss,net,X,T); % Determine learning rate for time-based decay learning rate schedule. learnRate = initialLearnRate/(1 + decay*iteration); % Update the network parameters using the Adam optimizer. [net,trailingAvg,trailingAvgSq] = adamupdate(net, gradients, ... trailingAvg, trailingAvgSq, iteration, learnRate, ... gradientDecayFactor, squaredGradientDecayFactor); % Display the training progress. recordMetrics(monitor,iteration,TrainingLoss=loss); updateInfo(monitor,LearnRate=learnRate,Epoch=(epoch+" of "+numEpochs)); % Validate network. if iteration == 1 || ~hasdata(mbq) [~,scoresValidation] = modelPredictions(net,mbqValidation,classes); lossValidation = crossentropy(scoresValidation,TValidation); % Update plot. recordMetrics(monitor,iteration,ValidationLoss=lossValidation); end monitor.Progress = 100*iteration/numIterations; end end
モデルのテスト
真のラベルをもつ検証セットで予測を比較し、モデルの分類精度をテストします。
例の最後にリストされている関数 modelPredictions
を使用し、検証データを分類します。
YNew = modelPredictions(net,mbqValidation,classes);
検証精度の計算を簡単にするため、one-hot 符号化された検証ラベルを categorical に変換して転置します。
TValidation = onehotdecode(TValidation,classes,1)';
分類精度を評価します。
accuracy = mean(YNew == TValidation)
accuracy = 0.8646
新しいデータを使用した予測
3 つの新しいレポートの事象タイプを分類します。新しいレポートを含む string 配列を作成します。
reportsNew = [ "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];
学習ドキュメントと同じ前処理手順を使用してテキスト データを前処理します。
documentsNew = preprocessText(reportsNew);
dsNew = arrayDatastore(documentsNew,OutputType="cell");
データのミニバッチを処理および管理する minibatchqueue
オブジェクトを作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatchPredictors
(この例の最後に定義) を使用し、ドキュメントをシーケンスに変換します。この前処理関数はラベル データを必要としません。単語符号化をミニバッチに渡すには、入力を 1 つだけ受け取る無名関数を作成します。次元ラベル
"BTC"
(batch、time、channel) を使用して予測子を書式設定します。minibatchqueue
オブジェクトは、既定では、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。すべての観測値に対して予測を行うには、部分的なミニバッチを返します。
mbqNew = minibatchqueue(dsNew, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=@(X) preprocessMiniBatchPredictors(X,enc), ... MiniBatchFormat="BTC", ... PartialMiniBatch="return");
例の最後にリストされている関数 modelPredictions
を使用してテキスト データを分類し、スコアが最も高いクラスを見つけます。
YNew = modelPredictions(net,mbqNew,classes)
YNew = 3×1 categorical
Leak
Electronic Failure
Mechanical Failure
サポート関数
テキスト前処理関数
関数 preprocessText
は以下のステップを実行します。
tokenizedDocument
を使用してテキストをトークン化する。lower
を使用してテキストを小文字に変換する。erasePunctuation
を使用して句読点を消去する。
function documents = preprocessText(textData) % Tokenize the text. documents = tokenizedDocument(textData); % Convert to lowercase. documents = lower(documents); % Erase punctuation. documents = erasePunctuation(documents); end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、ドキュメントのミニバッチを整数のシーケンスに変換し、ラベル データを one-hot 符号化します。
function [X,T] = preprocessMiniBatch(dataX,dataT,enc) % Preprocess predictors. X = preprocessMiniBatchPredictors(dataX,enc); % Extract labels from cell and concatenate. T = cat(1,dataT{1:end}); % One-hot encode labels. T = onehotencode(T,2); % Transpose the encoded labels to match the network output. T = T'; end
ミニバッチ予測子前処理関数
関数 preprocessMiniBatchPredictors
は、ドキュメントのミニバッチを整数のシーケンスに変換します。
function X = preprocessMiniBatchPredictors(dataX,enc) % Extract documents from cell and concatenate. documents = cat(4,dataX{1:end}); % Convert documents to sequences of integers. X = doc2sequence(enc,documents); X = cat(1,X{:}); end
モデル損失関数
関数 modelLoss
は、dlnetwork
オブジェクト net
、入力データ X
のミニバッチとそれに対応するターゲット ラベル T
を受け取り、net
内の学習可能パラメーターについての損失の勾配、および損失を返します。勾配を自動的に計算するには、関数 dlgradient
を使用します。
function [loss,gradients] = modelLoss(net,X,T) Y = forward(net,X); loss = crossentropy(Y,T); gradients = dlgradient(loss,net.Learnables); end
モデル予測関数
関数 modelPredictions
は、dlnetwork
オブジェクト net
およびミニバッチ キューを受け取り、キュー内のミニバッチを反復することによりモデルの予測とスコアを出力します。
function [predictions,scores] = modelPredictions(net,mbq,classes) % Initialize predictions. predictions = []; scores = []; % Reset mini-batch queue. reset(mbq); % Loop over mini-batches. while hasdata(mbq) % Make predictions. X = next(mbq); Y = predict(net,X); scores = [scores Y]; Y = onehotdecode(Y,classes,1)'; predictions = [predictions; Y]; end end
参考
wordEmbeddingLayer
(Text Analytics Toolbox) | tokenizedDocument
(Text Analytics Toolbox) | lstmLayer
| doc2sequence
(Text Analytics Toolbox) | sequenceInputLayer
| wordcloud
(Text Analytics Toolbox) | dlfeval
| dlgradient
| dlarray
関連するトピック
- カスタム学習ループ、損失関数、およびネットワークの定義
- 深層学習を使用したテキスト データの分類
- 分類用の単純なテキスト モデルの作成 (Text Analytics Toolbox)
- トピック モデルを使用したテキスト データの解析 (Text Analytics Toolbox)
- マルチワード フレーズを使用したテキスト データの解析 (Text Analytics Toolbox)
- センチメント分類器の学習 (Text Analytics Toolbox)
- 深層学習を使用したシーケンスの分類
- MATLAB による深層学習