Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用した複数ラベルをもつテキストの分類

この例では、複数の独立したラベルをもつテキスト データを分類する方法を説明します。

各観測値に複数の独立したラベル (たとえば、科学論文のタグなど) が存在する可能性のある分類タスクでは、独立した各クラスの確率を予測するように深層学習モデルを学習させることが可能です。ネットワークに複数ラベルの分類ターゲットを学習させるには、バイナリ交差エントロピー損失を使用して、各クラスの損失を個別に最適化できます。

この例では、arXiv API [1] を使用して収集した数学の論文の要旨に基づいて、その主題を分類する深層学習モデルを定義します。モデルは、単語埋め込み、GRU、最大プーリング演算、全結合、およびシグモイド演算で構成されます。

複数ラベル分類の性能を測定するには、ラベル付け F 値 [2] が使用できます。ラベル付け F 値は、部分一致をもつテキスト単位の分類に焦点を当てることによって複数ラベルの分類を評価します。この尺度は、真のラベルと予測ラベルの総数に対して一致ラベルの比率を正規化したものです。

この例では、次のモデルを定義します。

  • 単語のシーケンスを数値ベクトルのシーケンスにマッピングする単語埋め込み。

  • 埋め込みベクトル間の依存関係を学習する GRU 演算。

  • 特徴ベクトルのシーケンスを 1 つの特徴ベクトルに減らす最大ポーリング演算。

  • 特徴をバイナリ出力にマッピングする全結合層。

  • 出力とターゲット ラベルの間のバイナリ交差エントロピー損失を学習するシグモイド演算。

次の図は、1 つのテキストがモデル アーキテクチャを通って伝播し、確率のベクトルが出力される過程を示しています。確率は独立していて、総和が 1 になる必要はありません。

テキスト データのインポート

arXiv API を使用して、数学論文から概要とカテゴリ ラベルのセットをインポートします。変数 importSize を使用してインポートするレコードの数を指定します。

importSize = 50000;

セット "math" とメタデータ接頭辞 "arXiv" でレコードをクエリする URL を作成します。

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";

この例にサポート ファイルとして添付されている関数 parseArXivRecords を使用し、クエリ URL によって返される概要テキスト、カテゴリ ラベル、および再開トークンを抽出します。このファイルにアクセスするには、この例をライブ スクリプトとして開きます。arXiv API ではレートが制限されており、複数のリクエストの間に待つ必要があることに注意してください。

[textData,labelsAll,resumptionToken] = parseArXivRecords(url);

必要な量に達するか、レコードがなくなるまで、レコードのチャンクのインポートを繰り返します。停止した場所からレコードのインポートを続けるには、クエリ URL で前の実行結果の再開トークンを使用します。arXiv API で指定されているレート制限に従うには、関数 pause を使用して各クエリの前に 20 秒の遅延を追加します。

while numel(textData) < importSize
    
    if resumptionToken == ""
        break
    end
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    [textDataNew,labelsNew,resumptionToken] = parseArXivRecords(url);
    
    textData = [textData; textDataNew];
    labelsAll = [labelsAll; labelsNew];
end

テキスト データの前処理

この例の最後にリストされている関数 preprocessText を使用して、テキスト データのトークン化と前処理を行います。

documentsAll = preprocessText(textData);
documentsAll(1:5)
ans = 
  5×1 tokenizedDocument:

    72 tokens: describe new algorithm $(k,\ell)$ pebble game color obtain characterization family $(k,\ell)$ sparse graph algorithmic solution family problem concerning tree decomposition graph special instance sparse graph appear rigidity theory receive increase attention recent year particular colored pebble generalize strengthen previous result lee streinu give new proof tuttenashwilliams characterization arboricity present new decomposition certify sparsity base $(k,\ell)$ pebble game color work expose connection pebble game algorithm previous sparse graph algorithm gabow gabow westermann hendrickson
    22 tokens: show determinant stirling cycle number count unlabeled acyclic singlesource automaton proof involve bijection automaton certain marked lattice path signreversing involution evaluate determinant
    18 tokens: paper show compute $\lambda_{\alpha}$ norm alpha dyadic grid result consequence description hardy space $h^p(r^n)$ term dyadic special atom
    62 tokens: partial cube isometric subgraphs hypercubes structure graph define mean semicubes djokovi winklers relation play important role theory partial cube structure employ paper characterize bipartite graph partial cube arbitrary dimension new characterization establish new proof know result give operation cartesian product paste expansion contraction process utilize paper construct new partial cube old particular isometric lattice dimension finite partial cube obtain mean operation calculate
    29 tokens: paper present algorithm compute hecke eigensystems hilbertsiegel cusp form real quadratic field narrow class number give illustrative example quadratic field $\q(\sqrt{5})$ example identify hilbertsiegel eigenforms possible lift hilbert eigenforms

"math" セットに属さないラベルを削除します。

for i = 1:numel(labelsAll)
    labelsAll{i} = labelsAll{i}(startsWith(labelsAll{i},"math."));
end

クラスの一部をワード クラウドで可視化します。以下に対応するドキュメントを検索します。

  • "Combinatorics" でタグ付けされていて、"Statistics Theory" でタグ付けされていない要旨

  • "Statistics Theory" でタグ付けされていて、"Combinatorics" でタグ付けされていない要旨

  • "Combinatorics""Statistics Theory" の両方でタグ付けされている要旨

関数 ismember を使用して、各グループのドキュメント インデックスを検索します。

idxCO = cellfun(@(lbls) ismember("math.CO",lbls) && ~ismember("math.ST",lbls),labelsAll);
idxST = cellfun(@(lbls) ismember("math.ST",lbls) && ~ismember("math.CO",lbls),labelsAll);
idxCOST = cellfun(@(lbls) ismember("math.CO",lbls) && ismember("math.ST",lbls),labelsAll);

各グループのドキュメントをワード クラウドで可視化します。

figure
subplot(1,3,1)
wordcloud(documentsAll(idxCO));
title("Combinatorics")

subplot(1,3,2)
wordcloud(documentsAll(idxST));
title("Statistics Theory")

subplot(1,3,3)
wordcloud(documentsAll(idxCOST));
title("Both")

クラス数を表示します。

classNames = unique(cat(1,labelsAll{:}));
numClasses = numel(classNames)
numClasses = 32

ヒストグラムを使用してドキュメント単位のラベルの数を可視化します。

labelCounts = cellfun(@numel,labelsAll);
figure
histogram(labelCounts)
xlabel("Number of Labels")
ylabel("Frequency")
title("Label Counts")

深層学習用のテキスト データの準備

関数 cvpartition を使用して、データを学習用と検証用の区画に分割します。HoldOut オプションを 0.1 に設定し、データの 10% を検証用にホールドアウトします。

cvp = cvpartition(numel(documentsAll),HoldOut=0.1);
documentsTrain = documentsAll(training(cvp));
documentsValidation = documentsAll(test(cvp));

labelsTrain = labelsAll(training(cvp));
labelsValidation = labelsAll(test(cvp));

学習用のドキュメントを単語インデックスのシーケンスとして符号化する単語符号化オブジェクトを作成します。Order オプションを "frequency" に、MaxNumWords オプションを 5000 に設定し、5000 単語のボキャブラリを指定します。

enc = wordEncoding(documentsTrain,Order="frequency",MaxNumWords=5000)
enc = 
  wordEncoding with properties:

      NumWords: 5000
    Vocabulary: [1×5000 string]

学習の精度を高めるには、次の手法を使用します。

  1. 学習時に、使用されるパディングの量を減らしながら、データを破棄し過ぎない長さにドキュメントを打ち切る。

  2. 長さの昇順に並べ替えられたドキュメントで 1 エポック学習させ、その後はエポックごとにデータをシャッフルする。この手法は SortaGrad として知られています。

切り捨てるシーケンスの長さを選択するには、ドキュメントの長さをヒストグラムで可視化して、データの大部分を捉える値を選択します。

documentLengths = doclength(documentsTrain);

figure
histogram(documentLengths)
xlabel("Document Length")
ylabel("Frequency")
title("Document Lengths")

学習ドキュメントのほとんどは 175 トークン未満です。切り捨てとパディングの長さのターゲットとして 175 のトークンを使用します。

maxSequenceLength = 175;

SortaGrad 手法を使用するには、ドキュメントを長さの昇順に並べ替えます。

[~,idx] = sort(documentLengths);
documentsTrain = documentsTrain(idx);
labelsTrain = labelsTrain(idx);

モデル パラメーターの定義と初期化

各演算のパラメーターを定義して struct に含めます。parameters.OperationName.ParameterName の形式で使用します。ここで、parameters は struct、OperationName は演算名 ("fc" など)、ParameterName はパラメーター名 ("Weights" など) です。

モデル パラメーターを含む struct parameters を作成します。ゼロでバイアスを初期化します。演算には次の重み初期化子を使用します。

  • 埋め込みでは、関数 initializeGaussian を使用して重みを初期化します。

  • GRU 演算では、関数 initializeGlorot と関数 initializeZeros を使用して、それぞれ重みとバイアスを初期化します。

  • 全結合演算では、関数 initializeGaussian と関数 initializeZeros を使用して、それぞれ重みとバイアスを初期化します。

初期化関数 initializeGlorotinitializeGaussian、および initializeZeros は、この例にサポート ファイルとして添付されています。これらの関数にアクセスするには、例をライブ スクリプトとして開きます。

埋め込みに関する学習可能パラメーターを初期化します。

embeddingDimension = 300;
numHiddenUnits = 250;
inputSize = enc.NumWords + 1;

parameters = struct;

sz = [embeddingDimension inputSize];
mu = 0;
sigma = 0.01;
parameters.emb.Weights = initializeGaussian(sz,mu,sigma);

GRU 演算の使用に関する学習可能パラメーターを初期化します。

sz = [3*numHiddenUnits embeddingDimension];
numOut = 3*numHiddenUnits;
numIn = embeddingDimension;
parameters.gru.InputWeights = initializeGlorot(sz,numOut,numIn);

sz = [3*numHiddenUnits numHiddenUnits];
numOut = 3*numHiddenUnits;
numIn = numHiddenUnits;
parameters.gru.RecurrentWeights = initializeGlorot(sz,numOut,numIn);

sz = [3*numHiddenUnits 1];
parameters.gru.Bias = initializeZeros(sz);

全結合演算に関する学習可能パラメーターを初期化します。

sz = [numClasses numHiddenUnits];
mu = 0;
sigma = 0.01;
parameters.fc.Weights = initializeGaussian(sz,mu,sigma);

sz = [numClasses 1];
parameters.fc.Bias = initializeZeros(sz);

parameters struct を表示します。

parameters
parameters = struct with fields:
    emb: [1×1 struct]
    gru: [1×1 struct]
     fc: [1×1 struct]

GRU 演算のパラメーターを表示します。

parameters.gru
ans = struct with fields:
        InputWeights: [750×300 dlarray]
    RecurrentWeights: [750×250 dlarray]
                Bias: [750×1 dlarray]

モデルの関数の定義

この例の最後にリストされている関数 model を作成します。この関数は前に説明した深層学習モデルの出力を計算します。関数 model は、入力データとモデル パラメーターを入力として受け取ります。ネットワークはラベルの予測を出力します。

モデル損失関数の定義

この例の最後にリストされている関数 modelLoss を作成します。この関数は入力データのミニバッチと、対応するターゲットを入力として受け取り、損失、学習可能パラメーターについての損失の勾配、およびネットワーク出力を返します。

学習オプションの指定

ミニバッチ サイズを 256 として 5 エポック学習させます。

numEpochs = 5;
miniBatchSize = 256;

Adam オプティマイザーを使用して学習させます。学習率は 0.01 に、勾配の減衰係数と 2 乗勾配の減衰係数はそれぞれ 0.5 と 0.999 に設定します。

learnRate = 0.01;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

L2 ノルムの勾配クリップを使用して、しきい値 1 で勾配をクリップします。

gradientThreshold = 1;

確率のベクトルをラベルに変換するには、指定したしきい値より高い確率のラベルを使用します。ラベルのしきい値に 0.5 を指定します。

labelThreshold = 0.5;

エポックごとにネットワークを検証します。

numObservationsTrain = numel(documentsTrain);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
validationFrequency = numIterationsPerEpoch;

モデルの学習

学習の進行状況プロットを初期化します。F 値と損失に関するアニメーションの線を作成します。

figure
C = colororder;

subplot(2,1,1)
lineFScoreTrain = animatedline(Color=C(1,:));
lineFScoreValidation = animatedline( ...
    LineStyle="--", ...
    Marker="o", ...
    MarkerFaceColor="black");
ylim([0 1])
xlabel("Iteration")
ylabel("Labeling F-Score")
grid on

subplot(2,1,2)
lineLossTrain = animatedline(Color=C(2,:));
lineLossValidation = animatedline( ...
    LineStyle="--", ...
    Marker="o", ...
    MarkerFaceColor="black");
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

Adam オプティマイザーのパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

検証データを準備します。非ゼロのエントリが各観測値のラベルに対応する、one-hot 符号化された行列を作成します。

numObservationsValidation = numel(documentsValidation);
TValidation = zeros(numClasses, numObservationsValidation,"single");
for i = 1:numObservationsValidation
    [~,idx] = ismember(labelsValidation{i},classNames);
    TValidation(idx,i) = 1;
end

カスタム学習ループを使用してモデルに学習させます。

各エポックで、データのミニバッチに対してループします。各エポックの最後にデータをシャッフルします。各反復の最後に、学習の進行状況プロットを更新します。

各ミニバッチで次を行います。

  • ドキュメントを単語インデックスのシーケンスに、ラベルをダミー変数に変換します。

  • 基となる型が single の dlarray オブジェクトにシーケンスを変換し、次元ラベル "BTC" (batch、time、channel) を指定します。

  • GPU が利用できる場合、GPU で学習を行います。これには、Parallel Computing Toolbox™ が必要です。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

  • GPU で学習する場合、gpuArray オブジェクトに変換します。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価します。

  • 勾配をクリップします。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新します。

  • 必要に応じて、この例の最後にリストされている関数 modelPredictions を使用してネットワークを検証します。

  • 学習プロットを更新します。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        documents = documentsTrain(idx);
        labels = labelsTrain(idx);
        
        % Convert documents to sequences.
        len = min(maxSequenceLength,max(doclength(documents)));
        X = doc2sequence(enc,documents, ...
            PaddingValue=inputSize, ...
            Length=len);
        X = cat(1,X{:});
        
        % Dummify labels.
        T = zeros(numClasses,miniBatchSize,"single");
        for j = 1:miniBatchSize
            [~,idx2] = ismember(labels{j},classNames);
            T(idx2,j) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        X = dlarray(X,"BTC");
        
        % If training on a GPU, then convert data to gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end
        
        % Evaluate the model loss, gradients, and predictions using dlfeval and the
        % modelLoss function.
        [loss,gradients,Y] = dlfeval(@modelLoss,X,T,parameters);
        
        % Gradient clipping.
        gradients = dlupdate(@(g) thresholdL2Norm(g,gradientThreshold),gradients);
        
        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration,learnRate, ...
            gradientDecayFactor,squaredGradientDecayFactor);

        % Display the training progress.
        subplot(2,1,1)
        D = duration(0,0,toc(start),Format="hh:mm:ss");
        title("Epoch: " + epoch + ", Elapsed: " + string(D))

        % Loss.
        loss = double(loss);
        addpoints(lineLossTrain,iteration,loss)

        % Labeling F-score.
        Y = Y > labelThreshold;
        score = labelingFScore(Y,T);
        addpoints(lineFScoreTrain,iteration,double(gather(score)))

        drawnow

        % Display validation metrics.
        if iteration == 1 || mod(iteration,validationFrequency) == 0
            YValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);

            % Loss.
            lossValidation = crossentropy(YValidation,TValidation, ...
                ClassificationMode="multilabel", ...
                DataFormat="CB");
            lossValidation = double(lossValidation);
            addpoints(lineLossValidation,iteration,lossValidation)

            % Labeling F-score.
            YValidation = YValidation > labelThreshold;
            score = labelingFScore(YValidation,TValidation);
            score = double(score);
            addpoints(lineFScoreValidation,iteration,score)

            drawnow
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservationsTrain);
    documentsTrain = documentsTrain(idx);
    labelsTrain = labelsTrain(idx);
end

モデルのテスト

一連の新しいデータの予測を行うには、この例の最後にリストされている関数 modelPredictions を使用します。関数 modelPredictions は、モデル パラメーター、単語符号化、およびトークン化されたドキュメントの配列を入力として受け取り、指定されたミニバッチ サイズと最大シーケンス長に対応するモデルの予測を出力します。

YValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);

性能を評価するには、この例の最後にリストされている関数 labelingFScore を使用してラベル付け F 値を計算します。ラベル付け F 値は、部分一致をもつテキスト単位の分類に焦点を当てることによって複数ラベルの分類を評価します。ネットワーク出力をラベルの配列に変換するには、指定されたラベルのしきい値よりも高いスコアのラベルを検索します。

score = labelingFScore(YValidation > labelThreshold,TValidation)
score = single
    0.5663

しきい値にさまざまな値を試して結果を比較することにより、ラベル付けのしきい値がラベル付け F 値に与える影響を表示します。

thr = linspace(0,1,10);
score = zeros(size(thr));
for i = 1:numel(thr)
    YPredValidationThr = YValidation >= thr(i);
    score(i) = labelingFScore(YPredValidationThr,TValidation);
end

figure
plot(thr,score)
xline(labelThreshold,"r--");
xlabel("Threshold")
ylabel("Labeling F-Score")
title("Effect of Labeling Threshold")

予測の可視化

分類の正しい予測を可視化するには、真陽性の数を計算します。真陽性は、観測値の特定クラスを正しく予測する分類器のインスタンスです。

Y = YValidation > labelThreshold;
T = TValidation;

numTruePositives = sum(T & Y,2);

numObservationsPerClass = sum(T,2);
truePositiveRates = numTruePositives ./ numObservationsPerClass;

各クラスの真陽性の数をヒストグラムで可視化します。

figure
truePositiveRates = extractdata(truePositiveRates);
[~,idx] = sort(truePositiveRates,"descend");
histogram(Categories=classNames(idx),BinCounts=truePositiveRates(idx))
xlabel("Category")
ylabel("True Positive Rate")
title("True Positive Rates")

真陽性、偽陽性、および偽陰性の分布を表示して、分類器が誤って予測するインスタンスを可視化します。偽陽性は、観測値に間違った特定クラスを割り当てる分類器のインスタンスを指します。偽陰性は、観測値に正しい特定クラスを割り当てることに失敗する分類器のインスタンスを指します。

真陽性、偽陽性、および偽陰性の数を表す混同行列を作成します。

  • 各クラスについて、真陽性の数を対角線上に表示。

  • クラスの各ペア (ij) について、j が偽陽性でかつ i が偽陰性であるインスタンスの数を表示。

したがって、混同行列の要素は次によって求められます。

TPFNij={numTruePositives(i),if i=jnumFalsePositives(j|i is a false negative),if ijTrue positive, false negative rates

偽陰性と偽陽性を計算します。

falseNegatives = T & ~Y;
falsePositives = ~T & Y;

非対角要素を計算します。

falseNegatives = permute(falseNegatives,[3 2 1]);
numConditionalFalsePositives = sum(falseNegatives & falsePositives, 2);
numConditionalFalsePositives = squeeze(numConditionalFalsePositives);

tpfnMatrix = numConditionalFalsePositives;

対角要素に真陽性の数を設定します。

idxDiagonal = 1:numClasses+1:numClasses^2;
tpfnMatrix(idxDiagonal) = numTruePositives;

関数 confusionchart を使用して混同行列の真陽性と偽陽性の数を可視化し、対角要素が降順になるように行列を並べ替えます。

figure
tpfnMatrix = extractdata(tpfnMatrix);
cm = confusionchart(tpfnMatrix,classNames);
sortClasses(cm,"descending-diagonal");
title("True Positives, False Positives")

行列の詳細を確認するには、この例をライブ スクリプトとして開き、新しいウィンドウで Figure を開きます。

テキスト前処理関数

関数 preprocessText は、入力テキスト データのトークン化と前処理を次の手順で実行します。

  1. 関数 tokenizedDocument を使用してテキストをトークン化します。2 つの "$" 記号の間に現れるテキストを取得する正規表現 "\$.*?\$" を指定し、RegularExpressions オプションを使用して数式を単一のトークンとして抽出します。

  2. 関数 erasePunctuation を使用して句読点を消去します。

  3. 関数 lower を使用してテキストを小文字に変換します。

  4. 関数 removeStopWords を使用してストップ ワードを除去します。

  5. Style オプションを "lemma" に設定して関数 normalizeWords を使用し、テキストをレンマ化します。

function documents = preprocessText(textData)

% Tokenize the text.
regularExpressions = table;
regularExpressions.Pattern = "\$.*?\$";
regularExpressions.Type = "equation";

documents = tokenizedDocument(textData,RegularExpressions=regularExpressions);

% Erase punctuation.
documents = erasePunctuation(documents);

% Convert to lowercase.
documents = lower(documents);

% Lemmatize.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,Style="lemma");

% Remove stop words.
documents = removeStopWords(documents);

% Remove short words.
documents = removeShortWords(documents,2);

end

モデル関数

関数 model は、入力データ X とモデル パラメーター parameters を入力として受け取り、ラベルに対する予測を返します。

function Y = model(X,parameters)

% Embedding
weights = parameters.emb.Weights;
X = embed(X,weights);

% GRU
inputWeights = parameters.gru.InputWeights;
recurrentWeights = parameters.gru.RecurrentWeights;
bias = parameters.gru.Bias;

numHiddenUnits = size(inputWeights,1)/3;
hiddenState = dlarray(zeros([numHiddenUnits 1]));

Y = gru(X,hiddenState,inputWeights,recurrentWeights,bias);

% Max pooling along time dimension
Y = max(Y,[],3);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
Y = fullyconnect(Y,weights,bias);

% Sigmoid
Y = sigmoid(Y);

end

モデル損失関数

関数 modelLoss は、入力データ X のミニバッチと、ラベルを含む対応するターゲット T を入力として受け取り、損失、学習可能パラメーターについての損失の勾配、およびネットワーク出力を返します。

function [loss,gradients,Y] = modelLoss(X,T,parameters)

Y = model(X,parameters);

loss = crossentropy(Y,T,ClassificationMode="multilabel");

gradients = dlgradient(loss,parameters);

end

モデル予測関数

関数 modelPredictions は、モデル パラメーター、単語符号化、トークン化されたドキュメントの配列、ミニバッチ サイズ、および最大シーケンス長を入力として受け取り、指定サイズのミニバッチを反復することによりモデル予測を返します。

function Y = modelPredictions(parameters,enc,documents,miniBatchSize,maxSequenceLength)

inputSize = enc.NumWords + 1;

numObservations = numel(documents);
numIterations = ceil(numObservations / miniBatchSize);

numFeatures = size(parameters.fc.Weights,1);
Y = zeros(numFeatures,numObservations,"like",parameters.fc.Weights);

for i = 1:numIterations
    
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    
    len = min(maxSequenceLength,max(doclength(documents(idx))));
    X = doc2sequence(enc,documents(idx), ...
        PaddingValue=inputSize, ...
        Length=len);
    X = cat(1,X{:});
    
    X = dlarray(X,"BTC");
    
    Y(:,idx) = model(X,parameters);
end

end

ラベル付け F 値の関数

ラベル付け F スコア関数 [2] は、テキストごとの分類の部分一致に焦点を当てて複数ラベル分類を評価します。測定では、真のラベルと予測ラベルの総数に対して一致するラベルの割合を正規化します。次で求められます。

1Nn=1N(2c=1CYncTncc=1C(Ync+Tnc)),Labeling F-Score

ここで、NC はそれぞれ観測値とクラスの数に対応し、YT はそれぞれ予測とターゲットに対応します。

function score = labelingFScore(Y,T)

numObservations = size(T,2);

scores = (2 * sum(Y .* T)) ./ sum(Y + T);
score = sum(scores) / numObservations;

end

勾配クリップ関数

関数 thresholdL2Norm は、入力の勾配をスケーリングして、学習可能なパラメーターの勾配の L2 ノルムの値が指定のしきい値より大きい場合に L2 ノルムの値が指定の勾配しきい値に等しくなるようにします。

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end

参考文献

  1. arXiv. "arXiv API." Accessed January 15, 2020. https://arxiv.org/help/api

  2. Sokolova, Marina, and Guy Lapalme. "A Systematic Analysis of Performance Measures for Classification Tasks." Information Processing & Management 45, no. 4 (2009): 427–437.

参考

(Text Analytics Toolbox) | | | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

関連するトピック