メインコンテンツ

BERT 文書分類器の学習

R2023b 以降

この例では、文書分類用の BERT ニューラル ネットワークに学習させる方法を説明します。

Bidirectional Encoder Representations from Transformer (BERT) モデルは、文書分類やセンチメント分析などの自然言語処理タスクに合わせて微調整できるトランスフォーマー ニューラル ネットワークです。このネットワークは、注意層を使用してコンテキスト内のテキストを解析し、単語間の長距離依存関係を取得します。

この例では、事前学習済みの BERT-Base ニューラル ネットワークを微調整し、テキストの説明を使用して工場レポートのカテゴリを予測します。

学習データの読み込み

CSV ファイル factoryReports から学習データを読み取ります。このファイルには、各レポートの説明テキストとカテゴリカル ラベルを含む工場レポートが格納されています。

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)
                                 Description                                       Category          Urgency          Resolution         Cost 
    _____________________________________________________________________    ____________________    ________    ____________________    _____

    "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
    "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
    "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
    "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
    "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
    "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
    "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
    "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

table の Category 列のラベルを categorical 値に変換し、ヒストグラムを使用してデータ内のクラスの分布を表示します。

data.Category = categorical(data.Category);
figure
histogram(data.Category)
xlabel("Class")
ylabel("Frequency")
title("Class Distribution")

クラス数を表示します。

classNames = categories(data.Category);
numClasses = numel(classNames)
numClasses = 4

データを学習セットとテスト セットに分割します。ホールドアウトの割合を 10% に指定します。

cvp = cvpartition(data.Category,Holdout=0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);

table からテキスト データとラベルを抽出します。

textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
TTrain = dataTrain.Category;
TTest = dataTest.Category;

事前学習済みの BERT 文書分類器の読み込み

bertDocumentClassifier 関数を使用して、事前学習済みの BERT-Base 文書分類器を読み込みます。Text Analytics Toolbox™ Model for BERT-Base Network サポート パッケージがインストールされていない場合、この関数は、必要なサポート パッケージへのリンクをアドオン エクスプローラーに表示します。サポート パッケージをインストールするには、リンクをクリックして、[インストール] をクリックします。

mdl = bertDocumentClassifier(ClassNames=classNames)
mdl = 
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]

学習オプションの指定

学習オプションを指定します。学習オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、実験マネージャー (Deep Learning Toolbox)アプリを使用できます。

  • Adam オプティマイザーを使用して学習させます。

  • 学習を 8 エポック行います。

  • 微調整を行うため、学習率を下げます。学習率 0.0001 を使用して学習させます。

  • すべてのエポックでデータをシャッフルします。

  • 学習の進行状況をプロットで監視し、精度メトリクスを監視します。

  • 詳細出力を無効にします。

options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...  
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

ニューラル ネットワークの学習

trainBERTDocumentClassifier 関数を使用してニューラル ネットワークに学習させます。既定では、trainBERTDocumentClassifier 関数は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainBERTDocumentClassifier は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

ニューラル ネットワークのテスト

テスト データを使用して予測を行います。

YTest = classify(mdl,textDataTest);

混同行列で予測を可視化します。

figure
confusionchart(TTest,YTest)

テスト予測の分類精度を計算します。

accuracy = mean(TTest == YTest)
accuracy = 0.9375

新しいデータを使用した予測

新しい工場レポートのイベント タイプを分類します。新しい工場レポートを格納する string 配列を作成します。

strNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];
labelsNew = classify(mdl,strNew)
labelsNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

参考

| | (Deep Learning Toolbox) | (Deep Learning Toolbox) |

トピック