メインコンテンツ

classify

BERT 文書分類器を使用して文書を分類する

R2023b 以降

    説明

    Y = classify(mdl,documents) は、BERT 文書分類器 mdl を使用して、指定された文書を分類します。

    Y = classify(mdl,documents,Name=Value) は、1 つ以上の名前と値の引数を使用して追加オプションを指定します。

    すべて折りたたむ

    CSV ファイル factoryReports から学習データを読み取ります。このファイルには、各レポートの説明テキストとカテゴリカル ラベルを含む工場レポートが格納されています。

    filename = "factoryReports.csv";
    data = readtable(filename,TextType="string");
    head(data)
                                     Description                                       Category          Urgency          Resolution         Cost 
        _____________________________________________________________________    ____________________    ________    ____________________    _____
    
        "Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
        "Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
        "There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
        "Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
        "Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
        "Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
        "A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
        "Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38
    

    table の Category 列のラベルを categorical 値に変換します。

    data.Category = categorical(data.Category);

    データを学習セットとテスト セットに分割します。ホールドアウトの割合を 10% に指定します。

    cvp = cvpartition(data.Category,Holdout=0.1);
    dataTrain = data(cvp.training,:);
    dataTest = data(cvp.test,:);

    table からテキスト データとラベルを抽出します。

    textDataTrain = dataTrain.Description;
    textDataTest = dataTest.Description;
    TTrain = dataTrain.Category;
    TTest = dataTest.Category;

    bertDocumentClassifier 関数を使用して、事前学習済みの BERT-Base 文書分類器を読み込みます。

    classNames = categories(data.Category);
    mdl = bertDocumentClassifier(ClassNames=classNames)
    mdl = 
      bertDocumentClassifier with properties:
    
           Network: [1×1 dlnetwork]
         Tokenizer: [1×1 bertTokenizer]
        ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]
    
    

    学習オプションを指定します。学習オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、実験マネージャー (Deep Learning Toolbox)アプリを使用できます。

    • Adam オプティマイザーを使用して学習させます。

    • 学習を 8 エポック行います。

    • 微調整を行うため、学習率を下げます。学習率 0.0001 を使用して学習させます。

    • すべてのエポックでデータをシャッフルします。

    • 学習の進行状況をプロットで監視し、精度メトリクスを監視します。

    • 詳細出力を無効にします。

    options = trainingOptions("adam", ...
        MaxEpochs=8, ...
        InitialLearnRate=1e-4, ...
        Shuffle="every-epoch", ...  
        Plots="training-progress", ...
        Metrics="accuracy", ...
        Verbose=false);

    trainBERTDocumentClassifier 関数を使用してニューラル ネットワークに学習させます。既定では、trainBERTDocumentClassifier 関数は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainBERTDocumentClassifier は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

    mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

    テスト データを使用して予測を行います。

    YTest = classify(mdl,textDataTest);

    テスト予測の分類精度を計算します。

    accuracy = mean(TTest == YTest)
    accuracy = 0.9375
    

    入力引数

    すべて折りたたむ

    BERT 文書分類器のモデル。bertDocumentClassifier オブジェクトとして指定します。

    入力文書。string 配列、文字ベクトルの cell 配列、または tokenizedDocument 配列として指定します。

    名前と値の引数

    すべて折りたたむ

    オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。

    例: classify(mdl,document,MiniBatchSize=64) は、サイズ 64 のミニバッチを使用して、指定された文書を分類します。

    予測に使用するミニバッチのサイズ。正の整数として指定します。ミニバッチのサイズが大きくなるとより多くのメモリが必要になりますが、予測時間が短縮される可能性があります。

    データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64

    パフォーマンスの最適化。次のいずれかの値として指定します。

    • "auto" — 入力ネットワークとハードウェア リソースに適した最適化の回数を自動的に適用します。

    • "mex" — MEX 関数をコンパイルして実行します。このオプションは GPU の使用時にのみ利用できます。GPU を使用するには Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。

    • "none" — すべての高速化を無効にします。

    "auto" オプションまたは "mex" オプションを使用した場合、ソフトウェアはパフォーマンス上のメリットを提供しますが、初期実行時間が長くなります。関数のそれ以降の呼び出しでは、通常、より高速になります。異なる入力データを使用して関数を複数回呼び出す場合は、パフォーマンスの最適化を使用してください。

    Acceleration"mex" の場合、ソフトウェアは、関数の呼び出しで指定したモデルとパラメーターに基づいて MEX 関数を生成し、実行します。1 つのモデルに一度に複数の MEX 関数を関連付けることができます。モデルの変数をクリアすると、そのモデルに関連付けられている MEX 関数もクリアされます。

    Acceleration"auto" の場合、ソフトウェアは MEX 関数を生成しません。

    "mex" オプションは GPU の使用時にのみ利用できます。C/C++ コンパイラがインストールされ、GPU Coder™ Interface for Deep Learning サポート パッケージがなければなりません。MATLAB® でアドオン エクスプローラーを使用してサポート パッケージをインストールします。設定手順については、コンパイラの設定 (GPU Coder)を参照してください。GPU Coder は不要です。

    "mex" オプションを使用する場合、MATLAB Compiler™ ソフトウェアはモデルのコンパイルをサポートしません。

    ハードウェア リソース。次のいずれかの値として指定します。

    • "auto" — 利用可能な場合、GPU を使用します。そうでない場合、CPU を使用します。

    • "gpu" — GPU を使用します。GPU を使用するには Parallel Computing Toolbox ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。

    • "cpu" — CPU を使用します。

    出力引数

    すべて折りたたむ

    予測されたクラス。categorical 配列として返されます。

    バージョン履歴

    R2023b で導入