Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

深層学習を使用した複数ラベルをもつテキストの分類

この例では、複数の独立したラベルをもつテキスト データを分類する方法を説明します。

各観測値に複数の独立したラベル (たとえば、科学論文のタグなど) が存在する可能性のある分類タスクでは、独立した各クラスの確率を予測するように深層学習モデルを学習させることが可能です。ネットワークに複数ラベルの分類ターゲットを学習させるには、バイナリ交差エントロピー損失を使用して、各クラスの損失を個別に最適化できます。

この例では、arXiv API [1] を使用して収集した数学の論文の要旨に基づいて、その主題を分類する深層学習モデルを定義します。モデルは、単語埋め込み、GRU、最大プーリング演算、全結合、およびシグモイド演算で構成されます。

複数ラベル分類の性能を測定するには、ラベル付け F 値 [2] が使用できます。ラベル付け F 値は、部分一致をもつテキスト単位の分類に焦点を当てることによって複数ラベルの分類を評価します。この尺度は、真のラベルと予測ラベルの総数に対して一致ラベルの比率を正規化したものです。

この例では、次のモデルを定義します。

  • 単語のシーケンスを数値ベクトルのシーケンスにマッピングする単語埋め込み。

  • 埋め込みベクトル間の依存関係を学習する GRU 演算。

  • 特徴ベクトルのシーケンスを 1 つの特徴ベクトルに減らす最大ポーリング演算。

  • 特徴をバイナリ出力にマッピングする全結合層。

  • 出力とターゲット ラベルの間のバイナリ交差エントロピー損失を学習するシグモイド演算。

次の図は、1 つのテキストがモデル アーキテクチャを通って伝播し、確率のベクトルが出力される過程を示しています。確率は独立していて、総和が 1 になる必要はありません。

テキスト データのインポート

arXiv API を使用して、数学論文から概要とカテゴリ ラベルのセットをインポートします。変数 importSize を使用してインポートするレコードの数を指定します。arXiv API では、1 回のクエリで 1,000 個の記事までにレートが制限されており、リクエストの間に待つ必要があることに注意してください。

importSize = 50000;

1 番目のレコード セットをインポートします。

url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
    "&set=math" + ...
    "&metadataPrefix=arXiv";
options = weboptions('Timeout',160);
code = webread(url,options);

返された XML コンテンツを解析して、レコードの情報を含む htmlTree オブジェクトの配列を作成します。

tree = htmlTree(code);
subtrees = findElement(tree,"record");
numel(subtrees)

必要な量に達するか、レコードがなくなるまで、レコードのチャンクのインポートを繰り返します。停止した場所からレコードのインポートを続けるには、前の実行結果の属性 resumptionToken を使用します。arXiv API で指定されているレート制限に従うには、関数 pause を使用して各クエリの前に 20 秒の遅延を追加します。

while numel(subtrees) < importSize
    subtreeResumption = findElement(tree,"resumptionToken");
    
    if isempty(subtreeResumption)
        break
    end
    
    resumptionToken = extractHTMLText(subtreeResumption);
    
    url = "https://export.arxiv.org/oai2?verb=ListRecords" + ...
        "&resumptionToken=" + resumptionToken;
    
    pause(20)
    code = webread(url,options);
    
    tree = htmlTree(code);
    
    subtrees = [subtrees; findElement(tree,"record")];
end

テキスト データの抽出と前処理

解析した HTML ツリーから要旨とラベルを抽出します。

関数 findElement を使用して "<abstract>" 要素と "<categories>" 要素を検索します。

subtreeAbstract = htmlTree("");
subtreeCategory = htmlTree("");

for i = 1:numel(subtrees)
    subtreeAbstract(i) = findElement(subtrees(i),"abstract");
    subtreeCategory(i) = findElement(subtrees(i),"categories");
end

関数 extractHTMLText を使用して、要旨を含むサブツリーからテキスト データを抽出します。

textData = extractHTMLText(subtreeAbstract);

この例の最後にリストされている関数 preprocessText を使用して、テキスト データのトークン化と前処理を行います。

documentsAll = preprocessText(textData);
documentsAll(1:5)
ans = 
  5×1 tokenizedDocument:

    72 tokens: describe new algorithm $(k,\ell)$ pebble game color obtain characterization family $(k,\ell)$ sparse graph algorithmic solution family problem concern tree decomposition graph special instance sparse graph appear rigidity theory receive increase attention recent year particular colored pebble generalize strengthen previous result lee streinu give new proof tuttenashwilliams characterization arboricity present new decomposition certify sparsity base $(k,\ell)$ pebble game color work expose connection pebble game algorithm previous sparse graph algorithm gabow gabow westermann hendrickson
    22 tokens: show determinant stirling cycle number count unlabeled acyclic singlesource automaton proof involve bijection automaton certain marked lattice path signreversing involution evaluate determinant
    18 tokens: "paper" "show" "compute" "$\lambda_{\alpha}$" "norm" "$\alpha\ge 0$" "dyadic" "grid" "result" "consequence" "description" "hardy" "space" "$h^p(r^n)$" "term" "dyadic" "special" "atom"
    62 tokens: partial cube isometric subgraphs hypercubes structure graph define mean semicubes djokovi winklers relation play important role theory partial cube structure employ paper characterize bipartite graph partial cube arbitrary dimension new characterization establish new proof know result give operation cartesian product pasting expansion contraction process utilized paper construct new partial cube old particular isometric lattice dimension finite partial cube obtain mean operation calculate
    29 tokens: paper present algorithm compute hecke eigensystems hilbertsiegel cusp form real quadratic field narrow class number give illustrative example quadratic field $\q(\sqrt{5})$ example identify hilbertsiegel eigenforms possible lift hilbert eigenforms

ラベルを含むサブツリーからラベルを抽出します。

strLabels = extractHTMLText(subtreeCategory);
labelsAll = arrayfun(@split,strLabels,'UniformOutput',false);

"math" セットに属さないラベルを削除します。

for i = 1:numel(labelsAll)
    labelsAll{i} = labelsAll{i}(startsWith(labelsAll{i},"math."));
end

クラスの一部をワード クラウドで可視化します。以下に対応するドキュメントを検索します。

  • "Combinatorics" でタグ付けされていて、"Statistics Theory" でタグ付けされていない要旨

  • "Statistics Theory" でタグ付けされていて、"Combinatorics" でタグ付けされていない要旨

  • "Combinatorics""Statistics Theory" の両方でタグ付けされている要旨

関数 ismember を使用して、各グループのドキュメント インデックスを検索します。

idxCO = cellfun(@(lbls) ismember("math.CO",lbls) && ~ismember("math.ST",lbls),labelsAll);
idxST = cellfun(@(lbls) ismember("math.ST",lbls) && ~ismember("math.CO",lbls),labelsAll);
idxCOST = cellfun(@(lbls) ismember("math.CO",lbls) && ismember("math.ST",lbls),labelsAll);

各グループのドキュメントをワード クラウドで可視化します。

figure
subplot(1,3,1)
wordcloud(documentsAll(idxCO));
title("Combinatorics")

subplot(1,3,2)
wordcloud(documentsAll(idxST));
title("Statistics Theory")

subplot(1,3,3)
wordcloud(documentsAll(idxCOST));
title("Both")

クラス数を表示します。

classNames = unique(cat(1,labelsAll{:}));
numClasses = numel(classNames)
numClasses = 32

ヒストグラムを使用してドキュメント単位のラベルの数を可視化します。

labelCounts = cellfun(@numel, labelsAll);
figure
histogram(labelCounts)
xlabel("Number of Labels")
ylabel("Frequency")
title("Label Counts")

深層学習用のテキスト データの準備

関数 cvpartition を使用して、データを学習用と検証用の区画に分割します。'HoldOut' オプションを 0.1 に設定し、データの 10% を検証用にホールド アウトします。

cvp = cvpartition(numel(documentsAll),'HoldOut',0.1);
documentsTrain = documentsAll(training(cvp));
documentsValidation = documentsAll(test(cvp));

labelsTrain = labelsAll(training(cvp));
labelsValidation = labelsAll(test(cvp));

学習用のドキュメントを単語インデックスのシーケンスとして符号化する単語符号化オブジェクトを作成します。'Order' オプションを 'frequency' に、'MaxNumWords' オプションを 5000 に設定し、5000 単語のボキャブラリを指定します。

enc = wordEncoding(documentsTrain,'Order','frequency','MaxNumWords',5000)
enc = 
  wordEncoding with properties:

      NumWords: 5000
    Vocabulary: [1×5000 string]

学習の精度を高めるには、次の手法を使用します。

  1. 学習時に、使用されるパディングの量を減らしながら、データを破棄し過ぎない長さにドキュメントを打ち切る。

  2. 長さの昇順に並べ替えられたドキュメントで 1 エポック学習させ、その後はエポックごとにデータをシャッフルする。この手法は SortaGrad として知られています。

切り捨てるシーケンスの長さを選択するには、ドキュメントの長さをヒストグラムで可視化して、データの大部分を捉える値を選択します。

documentLengths = doclength(documentsTrain);

figure
histogram(documentLengths)
xlabel("Document Length")
ylabel("Frequency")
title("Document Lengths")

学習ドキュメントのほとんどは 175 トークン未満です。切り捨てとパディングの長さのターゲットとして 175 のトークンを使用します。

maxSequenceLength = 175;

SortaGrad 手法を使用するには、ドキュメントを長さの昇順に並べ替えます。

[~,idx] = sort(documentLengths);
documentsTrain = documentsTrain(idx);
labelsTrain = labelsTrain(idx);

モデル パラメーターの定義と初期化

各演算のパラメーターを定義して構造体に含めます。parameters.OperationName.ParameterName の形式で使用します。ここで、parameters は構造体、OperationName は演算名 ("fc" など)、ParameterName はパラメーター名 ("Weights" など) です。

モデル パラメーターを含む構造体 parameters を作成します。ゼロでバイアスを初期化します。演算には次の重み初期化子を使用します。

  • 埋め込みについては、ランダムな正常値で重みを初期化します。

  • GRU 演算については、この例の最後にリストされている関数 initializeGlorot を使用して重みを初期化します。

  • 全結合演算については、この例の最後にリストされている関数 initializeGaussian を使用して重みを初期化します。

embeddingDimension = 300;
numHiddenUnits = 250;
inputSize = enc.NumWords + 1;

parameters = struct;
parameters.emb.Weights = dlarray(randn([embeddingDimension inputSize]));

parameters.gru.InputWeights = dlarray(initializeGlorot(3*numHiddenUnits,embeddingDimension));
parameters.gru.RecurrentWeights = dlarray(initializeGlorot(3*numHiddenUnits,numHiddenUnits));
parameters.gru.Bias = dlarray(zeros(3*numHiddenUnits,1,'single'));

parameters.fc.Weights = dlarray(initializeGaussian([numClasses,numHiddenUnits]));
parameters.fc.Bias = dlarray(zeros(numClasses,1,'single'));

parameters 構造体を表示します。

parameters
parameters = struct with fields:
    emb: [1×1 struct]
    gru: [1×1 struct]
     fc: [1×1 struct]

GRU 演算のパラメーターを表示します。

parameters.gru
ans = struct with fields:
        InputWeights: [750×300 dlarray]
    RecurrentWeights: [750×250 dlarray]
                Bias: [750×1 dlarray]

モデルの関数の定義

この例の最後にリストされている関数 model を作成します。この関数は前に説明した深層学習モデルの出力を計算します。関数 model は、入力データ dlX とモデル パラメーター parameters を入力として受け取ります。ネットワークはラベルの予測を出力します。

モデル勾配関数の定義

この例の最後にリストされている関数 modelGradients を作成します。この関数は入力データ dlX のミニバッチと、ラベルを含む対応するターゲット T を入力として受け取り、学習可能なパラメーターについての損失の勾配、対応する損失、およびネットワーク出力を返します。

学習オプションの指定

ミニバッチ サイズを 256 として 5 エポック学習させます。

numEpochs = 5;
miniBatchSize = 256;

Adam オプティマイザーを使用して学習させます。学習率は 0.01 に、勾配の減衰係数と 2 乗勾配の減衰係数はそれぞれ 0.5 と 0.999 に設定します。

learnRate = 0.01;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

L2 ノルムの勾配クリップを使用して、しきい値 1 で勾配をクリップします。

gradientThreshold = 1;

学習の進行状況をプロットに可視化します。

plots = "training-progress";

確率のベクトルをラベルに変換するには、指定したしきい値より高い確率のラベルを使用します。ラベルのしきい値に 0.5 を指定します。

labelThreshold = 0.5;

エポックごとにネットワークを検証します。

numObservationsTrain = numel(documentsTrain);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
validationFrequency = numIterationsPerEpoch;

GPU が利用できる場合、GPU で学習を行います。これには、Parallel Computing Toolbox™ が必要です。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

executionEnvironment = "auto";

モデルの学習

カスタム学習ループを使用してモデルに学習させます。

各エポックで、データのミニバッチに対してループします。各エポックの最後にデータをシャッフルします。各反復の最後に、学習の進行状況プロットを更新します。

各ミニバッチで次を行います。

  • ドキュメントを単語インデックスのシーケンスに、ラベルをダミー変数に変換。

  • 基となる型が single の dlarray オブジェクトにシーケンスを変換し、次元ラベル 'BCT' (batch、channel、time) を指定。

  • GPU で学習する場合、gpuArray オブジェクトに変換。

  • 関数 dlfeval および modelGradients を使用してモデルの勾配と損失を評価。

  • 勾配をクリップ。

  • 関数 adamupdate を使用してネットワーク パラメーターを更新。

  • 必要に応じて、この例の最後にリストされている関数 modelPredictions を使用してネットワークを検証。

  • 学習プロットを更新。

学習の進行状況プロットを初期化します。

if plots == "training-progress"
    figure
    
    % Labeling F-Score.
    subplot(2,1,1)
    lineFScoreTrain = animatedline('Color',[0 0.447 0.741]);
    lineFScoreValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 1])
    xlabel("Iteration")
    ylabel("Labeling F-Score")
    grid on
    
    % Loss.
    subplot(2,1,2)
    lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
    lineLossValidation = animatedline( ...
        'LineStyle','--', ...
        'Marker','o', ...
        'MarkerFaceColor','black');
    ylim([0 inf])
    xlabel("Iteration")
    ylabel("Loss")
    grid on
end

Adam オプティマイザーのパラメーターを初期化します。

trailingAvg = [];
trailingAvgSq = [];

検証データを準備します。非ゼロのエントリが各観測値のラベルに対応する、one-hot 符号化された行列を作成します。

numObservationsValidation = numel(documentsValidation);
TValidation = zeros(numClasses, numObservationsValidation, 'single');
for i = 1:numObservationsValidation
    [~,idx] = ismember(labelsValidation{i},classNames);
    TValidation(idx,i) = 1;
end

モデルに学習させます。

iteration = 0;
start = tic;

% Loop over epochs.
for epoch = 1:numEpochs
    
    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        
        % Read mini-batch of data and convert the labels to dummy
        % variables.
        documents = documentsTrain(idx);
        labels = labelsTrain(idx);
        
        % Convert documents to sequences.
        len = min(maxSequenceLength,max(doclength(documents)));
        X = doc2sequence(enc,documents, ...
            'PaddingValue',inputSize, ...
            'Length',len);
        X = cat(1,X{:});
        
        % Dummify labels.
        T = zeros(numClasses, miniBatchSize, 'single');
        for j = 1:miniBatchSize
            [~,idx2] = ismember(labels{j},classNames);
            T(idx2,j) = 1;
        end
        
        % Convert mini-batch of data to dlarray.
        dlX = dlarray(X,'BTC');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients, state, and loss using dlfeval and the
        % modelGradients function.
        [gradients,loss,dlYPred] = dlfeval(@modelGradients, dlX, T, parameters);
        
        % Gradient clipping.
        gradients = dlupdate(@(g) thresholdL2Norm(g, gradientThreshold),gradients);
        
        % Update the network parameters using the Adam optimizer.
        [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
            trailingAvg,trailingAvgSq,iteration,learnRate,gradientDecayFactor,squaredGradientDecayFactor);
        
        % Display the training progress.
        if plots == "training-progress"
            subplot(2,1,1)
            D = duration(0,0,toc(start),'Format','hh:mm:ss');
            title("Epoch: " + epoch + ", Elapsed: " + string(D))
            
            % Loss.
            addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
            
            % Labeling F-score.
            YPred = extractdata(dlYPred) > labelThreshold;
            score = labelingFScore(YPred,T);
            addpoints(lineFScoreTrain,iteration,double(gather(score)))
            
            drawnow
            
            % Display validation metrics.
            if iteration == 1 || mod(iteration,validationFrequency) == 0
                dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);
                
                % Loss.
                lossValidation = crossentropy(dlYPredValidation,TValidation, ...
                    'TargetCategories','independent', ...
                    'DataFormat','CB');
                addpoints(lineLossValidation,iteration,double(gather(extractdata(lossValidation))))
                
                % Labeling F-score.
                YPredValidation = extractdata(dlYPredValidation) > labelThreshold;
                score = labelingFScore(YPredValidation,TValidation);
                addpoints(lineFScoreValidation,iteration,double(gather(score)))
                
                drawnow
            end
        end
    end
    
    % Shuffle data.
    idx = randperm(numObservationsTrain);
    documentsTrain = documentsTrain(idx);
    labelsTrain = labelsTrain(idx);
end

モデルのテスト

一連の新しいデータの予測を行うには、この例の最後にリストされている関数 modelPredictions を使用します。関数 modelPredictions は、モデル パラメーター、単語符号化、およびトークン化されたドキュメントの配列を入力として受け取り、指定されたミニバッチ サイズと最大シーケンス長に対応するモデルの予測を出力します。

dlYPredValidation = modelPredictions(parameters,enc,documentsValidation,miniBatchSize,maxSequenceLength);

ネットワーク出力をラベルの配列に変換するには、指定されたラベルのしきい値よりも高いスコアのラベルを検索します。

YPredValidation = extractdata(dlYPredValidation) > labelThreshold;

性能を評価するには、この例の最後にリストされている関数 labelingFScore を使用してラベル付け F 値を計算します。ラベル付け F 値は、部分一致をもつテキスト単位の分類に焦点を当てることによって複数ラベルの分類を評価します。

score = labelingFScore(YPredValidation,TValidation)
score = single
    0.5852

しきい値にさまざまな値を試して結果を比較することにより、ラベル付けのしきい値がラベル付け F 値に与える影響を表示します。

thr = linspace(0,1,10);
score = zeros(size(thr));
for i = 1:numel(thr)
    YPredValidationThr = extractdata(dlYPredValidation) >= thr(i);
    score(i) = labelingFScore(YPredValidationThr,TValidation);
end

figure
plot(thr,score)
xline(labelThreshold,'r--');
xlabel("Threshold")
ylabel("Labeling F-Score")
title("Effect of Labeling Threshold")

予測の可視化

分類の正しい予測を可視化するには、真陽性の数を計算します。真陽性は、観測値の特定クラスを正しく予測する分類器のインスタンスです。

Y = YPredValidation;
T = TValidation;

numTruePositives = sum(T & Y,2);

numObservationsPerClass = sum(T,2);
truePositiveRates = numTruePositives ./ numObservationsPerClass;

各クラスの真陽性の数をヒストグラムで可視化します。

figure
[~,idx] = sort(truePositiveRates,'descend');
histogram('Categories',classNames(idx),'BinCounts',truePositiveRates(idx))
xlabel("Category")
ylabel("True Positive Rate")
title("True Positive Rates")

真陽性、偽陽性、および偽陰性の分布を表示して、分類器が誤って予測するインスタンスを可視化します。偽陽性は、観測値に間違った特定クラスを割り当てる分類器のインスタンスを指します。偽陰性は、観測値に正しい特定クラスを割り当てることに失敗する分類器のインスタンスを指します。

真陽性、偽陽性、および偽陰性の数を表す混同行列を作成します。

  • 各クラスについて、真陽性の数を対角線上に表示。

  • クラスの各ペア (ij) について、j が偽陽性でかつ i が偽陰性であるインスタンスの数を表示。

したがって、混同行列の要素は次によって求められます。

TPFNij={numTruePositives(i),if i=jnumFalsePositives(j|i is a false negative),if ijTrue positive, false negative rates

偽陰性と偽陽性を計算します。

falseNegatives = T & ~Y;
falsePositives = ~T & Y;

非対角要素を計算します。

falseNegatives = permute(falseNegatives,[3 2 1]);
numConditionalFalsePositives = sum(falseNegatives & falsePositives, 2);
numConditionalFalsePositives = squeeze(numConditionalFalsePositives);

tpfnMatrix = numConditionalFalsePositives;

対角要素に真陽性の数を設定します。

idxDiagonal = 1:numClasses+1:numClasses^2;
tpfnMatrix(idxDiagonal) = numTruePositives;

関数 confusionchart を使用して混同行列の真陽性と偽陽性の数を可視化し、対角要素が降順になるように行列を並べ替えます。

figure
cm = confusionchart(tpfnMatrix,classNames);
sortClasses(cm,"descending-diagonal");
title("True Positives, False Positives")

行列の詳細を確認するには、この例をライブ スクリプトとして開き、新しいウィンドウで Figure を開きます。

テキスト前処理関数

関数 preprocessText は、入力テキスト データのトークン化と前処理を次の手順で実行します。

  1. 関数 tokenizedDocument を使用してテキストをトークン化します。2 つの "$" 記号の間に現れるテキストを取得する正規表現 "\$.*?\$" を指定し、'RegularExpressions' オプションを使用して数学の方程式を単一のトークンとして抽出します。

  2. 関数 erasePunctuation を使用して句読点を消去します。

  3. 関数 lower を使用してテキストを小文字に変換します。

  4. 関数 removeStopWords を使用してストップ ワードを除去します。

  5. 'Style' オプションを 'lemma' に設定して関数 normalizeWords を使用し、テキストをレンマ化します。

function documents = preprocessText(textData)

% Tokenize the text.
regularExpressions = table;
regularExpressions.Pattern = "\$.*?\$";
regularExpressions.Type = "equation";

documents = tokenizedDocument(textData,'RegularExpressions',regularExpressions);

% Erase punctuation.
documents = erasePunctuation(documents);

% Convert to lowercase.
documents = lower(documents);

% Lemmatize.
documents = addPartOfSpeechDetails(documents);
documents = normalizeWords(documents,'Style','Lemma');

% Remove stop words.
documents = removeStopWords(documents);

% Remove short words.
documents = removeShortWords(documents,2);

end

モデル関数

関数 model は、入力データ dlX とモデル パラメーター parameters を入力として受け取り、ラベルに対する予測を返します。

function dlY = model(dlX,parameters)

% Embedding
weights = parameters.emb.Weights;
dlX = embedding(dlX, weights);

% GRU
inputWeights = parameters.gru.InputWeights;
recurrentWeights = parameters.gru.RecurrentWeights;
bias = parameters.gru.Bias;

numHiddenUnits = size(inputWeights,1)/3;
hiddenState = dlarray(zeros([numHiddenUnits 1]));

dlY = gru(dlX, hiddenState, inputWeights, recurrentWeights, bias,'DataFormat','CBT');

% Max pooling along time dimension
dlY = max(dlY,[],3);

% Fully connect
weights = parameters.fc.Weights;
bias = parameters.fc.Bias;
dlY = fullyconnect(dlY,weights,bias,'DataFormat','CB');

% Sigmoid
dlY = sigmoid(dlY);

end

モデル勾配関数

関数 modelGradients は、入力データ dlX のミニバッチと、ラベルを含む対応するターゲット T を入力として受け取り、学習可能なパラメーターについての損失の勾配、対応する損失、およびネットワーク出力を返します。

function [gradients,loss,dlYPred] = modelGradients(dlX,T,parameters)

dlYPred = model(dlX,parameters);

loss = crossentropy(dlYPred,T,'TargetCategories','independent','DataFormat','CB');

gradients = dlgradient(loss,parameters);

end

モデル予測関数

関数 modelPredictions は、モデル パラメーター、単語符号化、トークン化されたドキュメントの配列、ミニバッチ サイズ、および最大シーケンス長を入力として受け取り、指定サイズのミニバッチを反復することによりモデル予測を返します。

function dlYPred = modelPredictions(parameters,enc,documents,miniBatchSize,maxSequenceLength)

inputSize = enc.NumWords + 1;

numObservations = numel(documents);
numIterations = ceil(numObservations / miniBatchSize);

numFeatures = size(parameters.fc.Weights,1);
dlYPred = zeros(numFeatures,numObservations,'like',parameters.fc.Weights);

for i = 1:numIterations
    
    idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
    
    len = min(maxSequenceLength,max(doclength(documents(idx))));
    X = doc2sequence(enc,documents(idx), ...
        'PaddingValue',inputSize, ...
        'Length',len);
    X = cat(1,X{:});
    
    dlX = dlarray(X,'BTC');
    
    dlYPred(:,idx) = model(dlX,parameters);
end

end

ラベル付け F 値の関数

ラベル付け F スコア関数 [2] は、テキストごとの分類の部分一致に焦点を当てて複数ラベル分類を評価します。測定では、真のラベルと予測ラベルの総数に対して一致するラベルの割合を正規化します。次で求められます。

1Nn=1N(2c=1CYncTncc=1C(Ync+Tnc)),Labeling F-Score

ここで、NC はそれぞれ観測値とクラスの数に対応し、YT はそれぞれ予測とターゲットに対応します。

function score = labelingFScore(Y,T)

numObservations = size(T,2);

scores = (2 * sum(Y .* T)) ./ sum(Y + T);
score = sum(scores) / numObservations;

end

Glorot 重み初期化関数

関数 initializeGlorot は、Glorot の初期化に従って、重みの配列を生成します。

function weights = initializeGlorot(numOut, numIn)

varWeights = sqrt( 6 / (numIn + numOut) );
weights = varWeights * (2 * rand([numOut, numIn], 'single') - 1);

end

ガウス重み初期化関数

関数 initializeGaussian は、平均 0、標準偏差 0.01 のガウス分布から重みをサンプリングします。

function parameter = initializeGaussian(sz)

parameter = randn(sz,'single') .* 0.01;

end

埋め込み関数

関数 embedding は、数値インデックスを、入力の重みで与えられた対応するベクトルにマッピングします。

function Z = embedding(X, weights)
% Reshape inputs into a vector.
[N, T] = size(X, 2:3);
X = reshape(X, N*T, 1);

% Index into embedding matrix.
Z = weights(:, X);

% Reshape outputs by separating batch and sequence dimensions.
Z = reshape(Z, [], N, T);
end

L2 ノルム勾配クリップ関数

関数 thresholdL2Norm は、入力の勾配をスケーリングして、学習可能なパラメーターの勾配の L2 ノルムの値が指定のしきい値より大きい場合に L2 ノルムの値が指定の勾配しきい値に等しくなるようにします。

function gradients = thresholdL2Norm(gradients,gradientThreshold)

gradientNorm = sqrt(sum(gradients(:).^2));
if gradientNorm > gradientThreshold
    gradients = gradients * (gradientThreshold / gradientNorm);
end

end

参考文献

  1. arXiv. "arXiv API." Accessed January 15, 2020. https://arxiv.org/help/api

  2. Sokolova, Marina, and Guy Lapalme. "A Sytematic Analysis of Performance Measures for Classification Tasks." Information Processing & Management 45, no. 4 (2009): 427–437.

参考

(Text Analytics Toolbox) | | | | | | | | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox) | (Text Analytics Toolbox)

関連するトピック