Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

畳み込みニューラル ネットワークを使用したテキスト データの分類

この例では、畳み込みニューラル ネットワークを使用してテキスト データを分類する方法を説明します。

畳み込みを使用してテキスト データを分類するには、入力の時間次元全体で畳み込みを行う 1 次元畳み込み層を使用します。

この例では、幅が異なる 1 次元畳み込みフィルターをもつネットワークに学習させます。各フィルターの幅は、フィルターが確認可能な単語の数 (n-gram の長さ) に対応します。ネットワークは畳み込み層の複数の分岐をもつため、異なる n-gram の長さを使用できます。

データの読み込み

factoryReports.csv 内のデータから表形式のテキスト データストアを作成し、最初のいくつかについてレポートを表示します。

data = readtable("factoryReports.csv");
head(data)
ans=8×5 table
                                  Description                                         Category            Urgency            Resolution          Cost 
    _______________________________________________________________________    ______________________    __________    ______________________    _____

    {'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       45
    {'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       35
    {'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}    {'High'  }    {'Full Replacement'  }    16200
    {'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}    {'High'  }    {'Replace Components'}      352
    {'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}    {'Low'   }    {'Add to Watch List' }       55
    {'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }    {'High'  }    {'Replace Components'}      371
    {'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}    {'Low'   }    {'Replace Components'}      441
    {'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}    {'Low'   }    {'Readjust Machine'  }       38

データを学習区画と検証区画に分割します。データの 80% を学習に使用し、残りのデータを検証に使用します。

cvp = cvpartition(data.Category,Holdout=0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);

テキスト データの前処理

テーブルの "Description" 列からテキスト データを抽出し、この例のテキスト前処理関数の節にリストされている関数 preprocessText を使用して前処理を行います。

documentsTrain = preprocessText(dataTrain.Description);

"Category" 列からラベルを抽出して categorical に変換します。

TTrain = categorical(dataTrain.Category);

クラス名と観測値の数を表示します。

classNames = unique(TTrain)
classNames = 4×1 categorical
     Electronic Failure 
     Leak 
     Mechanical Failure 
     Software Failure 

numObservations = numel(TTrain)
numObservations = 384

同じ手順で、検証データを抽出して前処理を行います。

documentsValidation = preprocessText(dataValidation.Description);
TValidation = categorical(dataValidation.Category);

ドキュメントからシーケンスへの変換

ドキュメントをニューラル ネットワークに入力するために、単語符号化を使用してドキュメントを数値インデックスのシーケンスに変換します。

ドキュメントから単語符号化を作成します。

enc = wordEncoding(documentsTrain);

単語符号化の語彙サイズを表示します。語彙サイズは、単語符号化における一意の単語の数です。

numWords = enc.NumWords
numWords = 436

関数 doc2sequence を使用して、ドキュメントを整数のシーケンスに変換します。

XTrain = doc2sequence(enc,documentsTrain);

学習データから作成した単語符号化を使用して、検証ドキュメントをシーケンスに変換します。

XValidation = doc2sequence(enc,documentsValidation);

ネットワーク アーキテクチャの定義

分類タスク用のネットワーク アーキテクチャを定義します。

ネットワーク アーキテクチャの手順は以下のとおりです。

  • 入力サイズとして 1 を指定します。これは、整数シーケンス入力のチャネル次元に対応します。

  • 100 次元の単語埋め込みを使用して、入力を埋め込みます。

  • 長さ 2、3、4、5 の n-gram について、畳み込み層、バッチ正規化層、ReLU 層、ドロップアウト層、および最大プーリング層を含む層のブロックを作成。

  • 各ブロックについて、サイズが 1 行 N 列の畳み込みフィルター 200 個と、グローバル最大プーリング層を指定。

  • 入力層を各ブロックに接続し、連結層を使用してブロックの出力を連結。

  • 出力を分類するために、出力サイズ K の全結合層、ソフトマックス層、および分類層を含める。ここで、K はクラスの数。

ネットワークのハイパーパラメーターを指定します。

embeddingDimension = 100;
ngramLengths = [2 3 4 5];
numFilters = 200;

まず、入力層を含む層グラフ、および 100 次元の単語埋め込み層を作成します。単語埋め込み層を畳み込み層に接続できるようにするため、埋め込み層の名前を "emb" に設定します。学習時に畳み込み層による畳み込みによってシーケンスの長さが 0 とならないことをチェックするには、MinLength オプションを学習データ内で最も短いシーケンスの長さに設定します。

minLength = min(doclength(documentsTrain));
layers = [ 
    sequenceInputLayer(1,MinLength=minLength)
    wordEmbeddingLayer(embeddingDimension,numWords,Name="emb")];
lgraph = layerGraph(layers);

n-gram のそれぞれの長さについて、1 次元畳み込み、バッチ正規化、ReLU、ドロップアウト、および 1 次元グローバル最大プーリングの各層のブロックを作成します。各ブロックを単語埋め込み層に接続します。

numBlocks = numel(ngramLengths);
for j = 1:numBlocks
    N = ngramLengths(j);
    
    block = [
        convolution1dLayer(N,numFilters,Name="conv"+N,Padding="same")
        batchNormalizationLayer(Name="bn"+N)
        reluLayer(Name="relu"+N)
        dropoutLayer(0.2,Name="drop"+N)
        globalMaxPooling1dLayer(Name="max"+N)];
    
    lgraph = addLayers(lgraph,block);
    lgraph = connectLayers(lgraph,"emb","conv"+N);
end

連結層、全結合層、ソフトマックス層、および分類層を追加します。

numClasses = numel(classNames);

layers = [
    concatenationLayer(1,numBlocks,Name="cat")
    fullyConnectedLayer(numClasses,Name="fc")
    softmaxLayer(Name="soft")
    classificationLayer(Name="classification")];

lgraph = addLayers(lgraph,layers);

グローバル最大プーリング層を連結層に接続し、ネットワーク アーキテクチャをプロットで表示します。

for j = 1:numBlocks
    N = ngramLengths(j);
    lgraph = connectLayers(lgraph,"max"+N,"cat/in"+j);
end

figure
plot(lgraph)
title("Network Architecture")

ネットワークの学習

学習オプションを指定します。

  • ミニバッチ サイズを 128 として学習させます。

  • 検証データを使用してネットワークを検証します。

  • 検証損失が最小のネットワークを返します。

  • 学習の進行状況のプロットを表示し、詳細出力を非表示にします。

options = trainingOptions("adam", ...
    MiniBatchSize=128, ...
    ValidationData={XValidation,TValidation}, ...
    OutputNetwork="best-validation-loss", ...
    Plots="training-progress", ...
    Verbose=false);

関数 trainNetwork を使用してネットワークに学習させます。

net = trainNetwork(XTrain,TTrain,lgraph,options);

ネットワークのテスト

学習済みネットワークを使用して検証データを分類します。

YValidation = classify(net,XValidation);

混同チャートで予測を可視化します。

figure
confusionchart(TValidation,YValidation)

分類精度を計算します。精度は、正しく予測されたラベルの比率です。

accuracy = mean(TValidation == YValidation)
accuracy = 0.9375

新しいデータを使用した予測

3 つの新しいレポートの事象タイプを分類します。新しいレポートを含む string 配列を作成します。

reportsNew = [ 
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

学習ドキュメントおよび検証ドキュメントと同じ前処理手順を使用してテキスト データを前処理します。

documentsNew = preprocessText(reportsNew);
XNew = doc2sequence(enc,documentsNew);

学習済みネットワークを使用して新しいシーケンスを分類します。

YNew = classify(net,XNew)
YNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

テキスト前処理関数

関数 preprocessTextData は、テキスト データを入力として受け取り、次の手順を実行します。

  1. テキストをトークン化します。

  2. テキストを小文字に変換します。

function documents = preprocessText(textData)

documents = tokenizedDocument(textData);
documents = lower(documents);

end

参考

(Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) |

関連するトピック