classify
説明
例
CSV ファイル factoryReports から学習データを読み取ります。このファイルには、各レポートの説明テキストとカテゴリカル ラベルを含む工場レポートが格納されています。
filename = "factoryReports.csv"; data = readtable(filename,TextType="string"); head(data)
Description Category Urgency Resolution Cost
_____________________________________________________________________ ____________________ ________ ____________________ _____
"Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45
"Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35
"There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200
"Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352
"Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55
"Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371
"A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441
"Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38
table の Category 列のラベルを categorical 値に変換します。
data.Category = categorical(data.Category);
データを学習セットとテスト セットに分割します。ホールドアウトの割合を 10% に指定します。
cvp = cvpartition(data.Category,Holdout=0.1); dataTrain = data(cvp.training,:); dataTest = data(cvp.test,:);
table からテキスト データとラベルを抽出します。
textDataTrain = dataTrain.Description; textDataTest = dataTest.Description; TTrain = dataTrain.Category; TTest = dataTest.Category;
bertDocumentClassifier 関数を使用して、事前学習済みの BERT-Base 文書分類器を読み込みます。
classNames = categories(data.Category); mdl = bertDocumentClassifier(ClassNames=classNames)
mdl =
bertDocumentClassifier with properties:
Network: [1×1 dlnetwork]
Tokenizer: [1×1 bertTokenizer]
ClassNames: ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"]
学習オプションを指定します。学習オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、実験マネージャー (Deep Learning Toolbox)アプリを使用できます。
Adam オプティマイザーを使用して学習させます。
学習を 8 エポック行います。
微調整を行うため、学習率を下げます。学習率 0.0001 を使用して学習させます。
すべてのエポックでデータをシャッフルします。
学習の進行状況をプロットで監視し、精度メトリクスを監視します。
詳細出力を無効にします。
options = trainingOptions("adam", ... MaxEpochs=8, ... InitialLearnRate=1e-4, ... Shuffle="every-epoch", ... Plots="training-progress", ... Metrics="accuracy", ... Verbose=false);
trainBERTDocumentClassifier 関数を使用してニューラル ネットワークに学習させます。既定では、trainBERTDocumentClassifier 関数は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainBERTDocumentClassifier は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。
mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

テスト データを使用して予測を行います。
YTest = classify(mdl,textDataTest);
テスト予測の分類精度を計算します。
accuracy = mean(TTest == YTest)
accuracy = 0.9375
入力引数
BERT 文書分類器のモデル。bertDocumentClassifier オブジェクトとして指定します。
入力文書。string 配列、文字ベクトルの cell 配列、または tokenizedDocument 配列として指定します。
名前と値の引数
オプションの引数のペアを Name1=Value1,...,NameN=ValueN として指定します。ここで、Name は引数名で、Value は対応する値です。名前と値の引数は他の引数の後に指定しなければなりませんが、ペアの順序は重要ではありません。
例: classify(mdl,document,MiniBatchSize=64) は、サイズ 64 のミニバッチを使用して、指定された文書を分類します。
予測に使用するミニバッチのサイズ。正の整数として指定します。ミニバッチのサイズが大きくなるとより多くのメモリが必要になりますが、予測時間が短縮される可能性があります。
データ型: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
パフォーマンスの最適化。次のいずれかの値として指定します。
"auto"— 入力ネットワークとハードウェア リソースに適した最適化の回数を自動的に適用します。"mex"— MEX 関数をコンパイルして実行します。このオプションは GPU の使用時にのみ利用できます。GPU を使用するには Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。"none"— すべての高速化を無効にします。
"auto" オプションまたは "mex" オプションを使用した場合、ソフトウェアはパフォーマンス上のメリットを提供しますが、初期実行時間が長くなります。関数のそれ以降の呼び出しでは、通常、より高速になります。異なる入力データを使用して関数を複数回呼び出す場合は、パフォーマンスの最適化を使用してください。
Acceleration が "mex" の場合、ソフトウェアは、関数の呼び出しで指定したモデルとパラメーターに基づいて MEX 関数を生成し、実行します。1 つのモデルに一度に複数の MEX 関数を関連付けることができます。モデルの変数をクリアすると、そのモデルに関連付けられている MEX 関数もクリアされます。
Acceleration が "auto" の場合、ソフトウェアは MEX 関数を生成しません。
"mex" オプションは GPU の使用時にのみ利用できます。C/C++ コンパイラがインストールされ、GPU Coder™ Interface for Deep Learning サポート パッケージがなければなりません。MATLAB® でアドオン エクスプローラーを使用してサポート パッケージをインストールします。設定手順については、コンパイラの設定 (GPU Coder)を参照してください。GPU Coder は不要です。
"mex" オプションを使用する場合、MATLAB Compiler™ ソフトウェアはモデルのコンパイルをサポートしません。
ハードウェア リソース。次のいずれかの値として指定します。
"auto"— 利用可能な場合、GPU を使用します。そうでない場合、CPU を使用します。"gpu"— GPU を使用します。GPU を使用するには Parallel Computing Toolbox ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。Parallel Computing Toolbox または適切な GPU が利用できない場合、エラーが返されます。"cpu"— CPU を使用します。
出力引数
予測されたクラス。categorical 配列として返されます。
バージョン履歴
R2023b で導入
参考
trainBERTDocumentClassifier | bertDocumentClassifier | bert | dlnetwork (Deep Learning Toolbox) | bertTokenizer
トピック
- BERT 文書分類器の学習
- 深層学習を使用したテキスト データの分類
- 分類用の単純なテキスト モデルの作成
- トピック モデルを使用したテキスト データの解析
- マルチワード フレーズを使用したテキスト データの解析
- 深層学習を使用したシーケンスの分類 (Deep Learning Toolbox)
- MATLAB による深層学習 (Deep Learning Toolbox)
MATLAB Command
You clicked a link that corresponds to this MATLAB command:
Run the command by entering it in the MATLAB Command Window. Web browsers do not support MATLAB commands.
Web サイトの選択
Web サイトを選択すると、翻訳されたコンテンツにアクセスし、地域のイベントやサービスを確認できます。現在の位置情報に基づき、次のサイトの選択を推奨します:
また、以下のリストから Web サイトを選択することもできます。
最適なサイトパフォーマンスの取得方法
中国のサイト (中国語または英語) を選択することで、最適なサイトパフォーマンスが得られます。その他の国の MathWorks のサイトは、お客様の地域からのアクセスが最適化されていません。
南北アメリカ
- América Latina (Español)
- Canada (English)
- United States (English)
ヨーロッパ
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)