メインコンテンツ

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

Grad-CAM を使用した深層学習時系列分類の解釈

この例では、勾配加重クラス活性化マッピング (Grad-CAM) 手法を使用して、時系列データで学習させた 1 次元畳み込みニューラル ネットワークによる分類判定を理解する方法を説明します。

Grad-CAM [1] では、ネットワークにより決定する畳み込みの特徴についての分類スコアの勾配を使用して、データのどの部分が分類に最も重要であるかを理解します。時系列データの場合、Grad-CAM はネットワークの分類判定にとって最も重要なタイム ステップを計算します。

このイメージは、Grad-CAM 重要度カラーマップを使用したシーケンスの例を示しています。マップでは、ネットワークが分類判定を行うために使用する領域が強調表示されています。

この例では、ラベル付きデータに対する教師あり学習を使用して、時系列データを "正常" または "センサー障害" として分類します。自己符号化器ネットワークを使用して、ラベルなしデータに対する時系列異常検出を実行することもできます。詳細については、深層学習を使用した時系列異常検出を参照してください。

波形データの読み込み

WaveformData.mat から波形データ セットを読み込みます。このデータ セットには、異なる波長の合成生成波形が格納されています。各波形には 3 つのチャネルがあります。

rng("default")
load WaveformData

numChannels = size(data{1},2);
numObservations = numel(data);

最初のいくつかのシーケンスをプロットに可視化します。

figure
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{i},DisplayLabels="Channel "+(1:numChannels));
    title("Observation "+i)
    xlabel("Time Step")
end

Figure contains objects of type stackedplot. The chart of type stackedplot has title Observation 1. The chart of type stackedplot has title Observation 2. The chart of type stackedplot has title Observation 3. The chart of type stackedplot has title Observation 4.

センサー障害のシミュレーション

いくつかのシーケンスを手動で編集して、センサーの障害をシミュレートし、新しいデータ セットを作成します。

変更されていないデータのコピーを作成します。

dataUnmodified = data;

変更するシーケンスの 10% をランダムに選択します。

failureFraction = 0.1;

numFailures = round(numObservations*failureFraction);
failureIdx = randperm(numel(data),numFailures);

センサー障害をシミュレートするために、高さ 0.25 ~ 2 の小さな付加的な異常を導入します。各異常はシーケンス内のランダムな位置で発生し、4 ~ 20 のタイム ステップで発生します。

anomalyHeight = [0.25 2];
anomalyPatchSize = [4 20];

anomalyHeightRange = anomalyHeight(2) - anomalyHeight(1);

シーケンスを変更します。

failureLocation = cell(size(data));

for i = 1:numFailures
    X = data{failureIdx(i)};

    % Generate sensor failure location.
    patchLength = randi(anomalyPatchSize,1);
    patchStart = randi(length(X)-patchLength);
    idxPatch = patchStart:(patchStart+patchLength);

    % Generate anomaly height. 
    patchExtraHeight = anomalyHeight(1) + anomalyHeightRange*rand;
    X(idxPatch,:) = X(idxPatch,:) + patchExtraHeight;
    
    % Save modified sequence.
    data{failureIdx(i)} = X;

    % Save failure location.
    failureLocation{failureIdx(i)} = idxPatch;
end

変更されていないシーケンスのクラス ラベルを Normal に設定します。変更されたシーケンスのクラス ラベルを Sensor Failure に設定します。

labels = repmat("Normal",numObservations,1);
labels(failureIdx) = "Sensor Failure";
labels = categorical(labels);

ヒストグラムを使用してクラス ラベルの分布を可視化します。

figure
histogram(labels)

Figure contains an axes object. The axes object contains an object of type categoricalhistogram.

センサー障害の可視化

変更されたシーケンスの一部を元のシーケンスと比較します。破線はセンサー障害の領域を示しています。

numFailuresToShow = 2;

for i=1:numFailuresToShow
    figure
    t = tiledlayout(numChannels,1);
    idx = failureIdx(i);

    modifiedSignal = data{idx};
    originalSignal = dataUnmodified{idx};

    for j = 1:numChannels
        nexttile
       
        plot(modifiedSignal(:,j))
        hold on
        plot(originalSignal(:,j))

        ylabel("Channel " + j)
        xlabel("Time Step")

        xline(failureLocation{idx}(1),":")
        xline(failureLocation{idx}(end),":")
        hold off
    end
    
    title(t,"Observation "+failureIdx(i))
    legend("Modified","Original", ...
        Location="southoutside", ...
        NumColumns=2)
end

Figure contains 3 axes objects. Axes object 1 with xlabel Time Step, ylabel Channel 1 contains 4 objects of type line, constantline. Axes object 2 with xlabel Time Step, ylabel Channel 2 contains 4 objects of type line, constantline. Axes object 3 with xlabel Time Step, ylabel Channel 3 contains 4 objects of type line, constantline. These objects represent Modified, Original.

Figure contains 3 axes objects. Axes object 1 with xlabel Time Step, ylabel Channel 1 contains 4 objects of type line, constantline. Axes object 2 with xlabel Time Step, ylabel Channel 2 contains 4 objects of type line, constantline. Axes object 3 with xlabel Time Step, ylabel Channel 3 contains 4 objects of type line, constantline. These objects represent Modified, Original.

センサー障害に対応する異常な区画を除いて、変更された信号と元の信号は一致します。

データの準備

データを学習セットと検証セットに分割し、学習用のデータを準備します。データの 90% を学習に使用し、データの 10% を検証に使用する。

trainFraction = 0.9;
idxTrain = 1:floor(trainFraction*numObservations);
idxValidation = (idxTrain(end)+1):numObservations;

XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XValidation = data(idxValidation);
TValidation = labels(idxValidation);
failureLocationValidation = failureLocation(idxValidation);

ネットワーク アーキテクチャの定義

1 次元畳み込みニューラル ネットワーク アーキテクチャを定義します。

  • 入力データのチャネル数と一致する入力サイズのシーケンス入力層を使用します。

  • 畳み込み層のフィルター サイズが 3 である 1 次元畳み込み層、ReLU 層、およびレイヤー正規化層から成るブロックを 2 つ指定します。32 個のフィルターと 64 個のフィルターを最初と 2 番目の畳み込み層にそれぞれ指定します。どちらの畳み込み層に対しても、出力の長さが同じになるように入力を左パディングします (因果的パディング)。

  • 畳み込み層の出力を単一のベクトルに減らすために、1 次元グローバル平均プーリング層を使用します。

  • 出力を確率のベクトルにマッピングするため、クラス数と一致する出力サイズをもつ全結合層を含め、その後にソフトマックス層を含めます。

classes = categories(TTrain);
numClasses = numel(classes);

filterSize = 3;
numFilters = 32;

layers = [ ...
    sequenceInputLayer(numChannels)
    convolution1dLayer(filterSize,numFilters,Padding="causal")
    reluLayer
    layerNormalizationLayer(OperationDimension="batch-excluded")
    convolution1dLayer(filterSize,2*numFilters,Padding="causal")
    reluLayer
    layerNormalizationLayer(OperationDimension="batch-excluded")
    globalAveragePooling1dLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

学習オプションの指定

適応モーメント推定 (Adam) を使用してネットワークに学習させます。最大エポック数を 15 に設定し、ミニバッチのサイズ 27 を使用します。ミニバッチ内のすべてのシーケンスが同じ長さになるように左パディングします。検証データを使用して、学習中にネットワークを検証します。学習の進行状況をプロットで表示し、精度を監視します。詳細出力を非表示にします。

miniBatchSize = 27;

options = trainingOptions("adam", ...
    MiniBatchSize=miniBatchSize, ...
    MaxEpochs=15, ...
    SequencePaddingDirection="left", ...
    ValidationData={XValidation,TValidation}, ...
    Metrics="accuracy", ...
    Plots="training-progress", ...
    Verbose=false);

ネットワークの学習

trainnet 関数を使用して、指定されたオプションで畳み込みネットワークに学習させます。

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

ネットワークのテスト

検証イメージを分類します。複数の観測値を使用して予測を行うには、関数 minibatchpredict を使用します。予測スコアをラベルに変換するには、関数 scores2label を使用します。関数 minibatchpredict は利用可能な GPU がある場合に自動的にそれを使用します。GPU を使用するには、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数は CPU を使用します。

scores = minibatchpredict(net,XValidation, ...
    MiniBatchSize=miniBatchSize, ...
    SequencePaddingDirection="left");
YValidation = scores2label(scores,classes);

予測の分類精度を計算します。

accuracy = mean(YValidation == TValidation)
accuracy = 
0.9800

混同行列で予測を可視化します。

figure
confusionchart(TValidation,YValidation)

Figure contains an object of type ConfusionMatrixChart.

Grad-CAM を使用した分類結果の解釈

Grad-CAM を使用して、ネットワークが分類判定に使用するシーケンスの部分を可視化します。

ネットワークが "センサー障害" として正しく分類するシーケンスのサブセットを見つけます。

numFailuresToShow = 2;

isCorrect = TValidation == "Sensor Failure" & YValidation == "Sensor Failure";
idxValidationFailure = find(isCorrect,numFailuresToShow);

観測値ごとに、Grad-CAM マップを計算して可視化します。Grad-CAM 重要度マップを計算するには、gradCAMを使用します。この例の最後に定義されている plotWithColorGradient 補助関数を使用して、Grad-CAM の重要度を表すカラーマップを表示します。センサー障害の実際の位置を示す破線を追加します。

for i = 1:numFailuresToShow
    figure
    t = tiledlayout(numChannels,1);
    idx = idxValidationFailure(i);

    modifiedSignal = XValidation{idx};
    channel = find("Sensor Failure" == classes);
    importance = gradCAM(net,modifiedSignal',channel);

    for j = 1:numChannels
        nexttile
        plotWithColorGradient(modifiedSignal(:,j),importance');

        ylabel("Channel "+j)
        xlabel("Time Steps")

        if ~isempty(failureLocationValidation{idx})
            xline(failureLocationValidation{idx}(1),":")
            xline(failureLocationValidation{idx}(end),":")
        end
    end
    
    title(t,"Grad-CAM: Validation Observation "+idx)

    c = colorbar;
    c.Layout.Tile = "east";
    c.Label.String = "Grad-CAM Importance";
end

Figure contains 3 axes objects. Axes object 1 with xlabel Time Steps, ylabel Channel 1 contains 3 objects of type patch, constantline. Axes object 2 with xlabel Time Steps, ylabel Channel 2 contains 3 objects of type patch, constantline. Axes object 3 with xlabel Time Steps, ylabel Channel 3 contains 3 objects of type patch, constantline.

Figure contains 3 axes objects. Axes object 1 with xlabel Time Steps, ylabel Channel 1 contains 3 objects of type patch, constantline. Axes object 2 with xlabel Time Steps, ylabel Channel 2 contains 3 objects of type patch, constantline. Axes object 3 with xlabel Time Steps, ylabel Channel 3 contains 3 objects of type patch, constantline.

Grad-CAM マップは、ネットワークがシーケンスにおけるセンサー障害の領域を正しく使用して分類判定を行っていることを示しています。正しい領域を使用していることは、ネットワークが正常なデータと障害のあるデータの識別方法を学習していることを示しています。ネットワークは、誤った背景特徴ではなく、故障を判定基準として使用しています。

Grad-CAM を使用した誤分類の調査

Grad-CAM を使用して、誤分類されたシーケンスを調査することもできます。

ネットワークが "正常" と誤分類するセンサー障害シーケンスのサブセットを見つけます。

numFailuresToShow = 2;
isIncorrect = TValidation == "Sensor Failure" & YValidation == "Normal";
idxValidationFailure = find(isIncorrect,numFailuresToShow);

誤分類ごとに、Grad-CAM マップを計算して可視化します。誤分類されたセンサー障害シーケンスについて、Grad-CAM マップは、ネットワークが障害領域を見つけていることを示しています。しかし、正しく分類されたシーケンスとは異なり、ネットワークは分類判定のために障害領域全体を使用していません。

for i = 1:length(idxValidationFailure)
    figure
    t = tiledlayout(numChannels,1);
    idx = idxValidationFailure(i);

    modifiedSignal = XValidation{idx};
    channel = find("Sensor Failure" == classes);
    importance = gradCAM(net,modifiedSignal',channel);

    for j = 1:numChannels
        nexttile
        plotWithColorGradient(modifiedSignal(:,j),importance');

        ylabel("Channel "+j)
        xlabel("Time Steps")

        if ~isempty(failureLocationValidation{idx})
            xline(failureLocationValidation{idx}(1),":")
            xline(failureLocationValidation{idx}(end),":")
        end
    end

    title(t,"Grad-CAM: Validation Observation "+idx)

    c = colorbar;
    c.Layout.Tile = "east";
    c.Label.String = "Grad-CAM Importance";
end

Figure contains 3 axes objects. Axes object 1 with xlabel Time Steps, ylabel Channel 1 contains 3 objects of type patch, constantline. Axes object 2 with xlabel Time Steps, ylabel Channel 2 contains 3 objects of type patch, constantline. Axes object 3 with xlabel Time Steps, ylabel Channel 3 contains 3 objects of type patch, constantline.

Figure contains 3 axes objects. Axes object 1 with xlabel Time Steps, ylabel Channel 1 contains 3 objects of type patch, constantline. Axes object 2 with xlabel Time Steps, ylabel Channel 2 contains 3 objects of type patch, constantline. Axes object 3 with xlabel Time Steps, ylabel Channel 3 contains 3 objects of type patch, constantline.

補助関数

plotWithColorGradient 関数は、単一チャネルをもつシーケンスと、シーケンスと同じ数のタイム ステップをもつ重要度マップを入力として受け取ります。この関数は重要度マップを使用して、シーケンスのセグメントに色を付けます。

その区画が閉じた多角形ではなく線を作成するように、yc の最後のエントリを NaN に設定します。

function plotWithColorGradient(sequence,importance)

x = 1:size(sequence,1) + 1;
y = [sequence; NaN];
c = [importance; NaN];

patch(x,y,c,EdgeColor="interp");
end

[1] Selvaraju, Ramprasaath R., Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization." International Journal of Computer Vision 128, no. 2 (February 2020): 336–59. https://doi.org/10.1007/s11263-019-01228-7.

参考

| | |

トピック