長短期記憶ネットワークを使用した ECG 信号の分類
この例では、PhysioNet 2017 Challenge からの心拍心電図 (ECG) データを深層学習と信号処理を使用して分類する方法を示します。特に、この例では長短期記憶ネットワークと時間-周波数解析を使用します。
GPU および Parallel Computing Toolbox™ を使用してこのワークフローの再現と高速化を行う例については、長短期記憶ネットワークを GPU 高速化と組み合わせて使用した ECG 信号の分類 (Signal Processing Toolbox)を参照してください。
はじめに
ECG は、人の心臓の電気的活動を一定期間記録します。医師は ECG を使用して、患者の心拍が正常か正常でないかどうかを視覚的に検出します。
心房細動 (AFib) は、心臓の上部室、心房が下部室、心室と連携せずに脈打っているときに発生する不規則な心拍の一種です。
この例では、PhysioNet 2017 Challenge [1]、[2]、[3] からの ECG データを使用します。これは、https://physionet.org/challenge/2017/ で入手できます。データは、300 Hz でサンプリングされ、専門家の手によって次の 4 つの別々のクラスに分けられた、一連の ECG 信号で構成されています。正常 (N)、AFib (A)、その他の律動 (O)、およびノイズを含む録音 (~)。この例では、深層学習を使用して分類プロセスを自動化する方法を説明します。手順として、AFib の兆候を示す信号から正常な ECG 信号を識別できるバイナリ分類器を調査します。
長短期記憶 (LSTM) ネットワークは、シーケンスおよび時系列のデータの学習に適した再帰型ニューラル ネットワーク (RNN) の一種です。LSTM ネットワークは、シーケンスのタイム ステップ間の長期的な依存関係を学習できます。LSTM 層 (lstmLayer
) では順方向の時間系列を確認でき、双方向の LSTM 層 (bilstmLayer
) では順方向と逆方向の両方の時間系列を確認できます。この例では、双方向の LSTM 層を使用します。
この例では、人工知能 (AI) の問題を解く際にデータ中心の手法を使用することの利点を示します。生データを使用して LSTM ネットワークに学習させる最初の試行では、標準以下の結果しか得られません。抽出した特徴を使用して同じモデル アーキテクチャに学習させると、分類性能が大幅に改善されます。
学習プロセスを高速化するには、GPU を使用するマシン上でこの例を実行します。マシンに GPU と Parallel Computing Toolbox™ がある場合は、MATLAB® は学習に GPU を自動で使用します。それ以外では CPU を使用します。
データの読み込みおよび確認
ReadPhysionetData
スクリプトを実行して PhysioNet Web サイトからデータをダウンロードし、適切な形式の ECG 信号を含む MAT ファイル (PhysionetData.mat
) を生成します。データのダウンロードには数分かかる場合があります。PhysionetData.mat
が現在のフォルダーに既に存在していない場合のみスクリプトを実行する条件ステートメントを使用します。
if ~isfile('PhysionetData.mat') ReadPhysionetData end load PhysionetData
読み込み操作ではワークスペースに 2 つの変数、Signals
と Labels
が追加されます。Signals
は ECG 信号を保持する cell 配列です。Labels
は対応する信号のグラウンド トゥルース ラベルを保持する categorical 配列です。
Signals(1:5)
ans=5×1 cell array
{1×9000 double}
{1×9000 double}
{1×18000 double}
{1×9000 double}
{1×18000 double}
Labels(1:5)
ans = 5×1 categorical
N
N
N
A
A
関数 summary
を使用して、データに含まれる AFib 信号と正常な信号の数を確認します。
summary(Labels)
A 738 N 5050
信号長のヒストグラムを生成します。ほとんどの信号は 9000 サンプルの長さです。
L = cellfun(@length,Signals); h = histogram(L); xticks(0:3000:18000); xticklabels(0:3000:18000); title('Signal Lengths') xlabel('Length') ylabel('Count')
各クラスから 1 つの信号のセグメントを可視化します。通常の心拍は規則的に起こりますが、AFib 心拍は不規則な間隔で起こります。さらに AFib 心拍信号は P 波が欠落することがよくあります。P 波は、正常な心拍信号では QRS 群の前に脈を打ちます。正常な信号のプロットでは P 波と QRS 群が示されます。
normal = Signals{1}; aFib = Signals{4}; subplot(2,1,1) plot(normal) title('Normal Rhythm') xlim([4000,5200]) ylabel('Amplitude (mV)') text(4330,150,'P','HorizontalAlignment','center') text(4370,850,'QRS','HorizontalAlignment','center') subplot(2,1,2) plot(aFib) title('Atrial Fibrillation') xlim([4000,5200]) xlabel('Samples') ylabel('Amplitude (mV)')
学習用データの準備
学習中に、関数 trainNetwork
はデータをミニバッチに分割します。関数は、その後、同じミニバッチ内の信号にパディングや切り捨てを行い、すべて同じ長さになるようにします。パディングや切り捨てをしすぎると、ネットワークのパフォーマンスにマイナスの影響が出ることがあります。これは、ネットワークが追加または削除された情報に基づいて信号を誤って解釈することがあるためです。
過度なパディングや切り捨てを防ぐため、関数 segmentSignals
を ECG 信号に適用してすべてが 9000 サンプルの長さになるようにします。関数は 9000 サンプル未満の信号を無視します。信号が 9000 サンプルを超える場合、segmentSignals
はできるだけ多くの 9000 サンプルのセグメントに分割し、残っているサンプルは無視します。たとえば、18500 サンプルの信号は、2 つの 9000 サンプルの信号になり、残りの 500 サンプルは無視されます。
[Signals,Labels] = segmentSignals(Signals,Labels);
配列 Signals
の最初の 5 つの要素を表示し、各エントリが 9000 サンプルの長さになっていることを確認します。
Signals(1:5)
ans=5×1 cell array
{1×9000 double}
{1×9000 double}
{1×9000 double}
{1×9000 double}
{1×9000 double}
1 番目の試行: 生の信号データを使用した分類器の学習
分類器を設計するには、前のセクションで生成した生の信号を使用します。分類器に学習させるための学習セットと新しいデータに対して分類器の精度をテストするためのテスト セットに信号を分割します。
関数 summary
を使用して、AFib 信号と通常の信号の比率が 718:4937 (約 1:7) であることを示します。
summary(Labels)
A 718 N 4937
約 7/8 の信号が正常であるため、分類器は単純にすべての信号を正常として分類することで高い精度が達成できる、と学習することがあります。このバイアスを回避するには、データセットの AFib 信号を複製することで AFib データを拡張して、正常な信号と AFib 信号とを同じ数にします。一般にオーバーサンプリングと呼ばれるこの複製は、深層学習で使用されるデータ拡張の 1 つの形式です。
信号をそれらのクラスに従って分割します。
afibX = Signals(Labels=='A'); afibY = Labels(Labels=='A'); normalX = Signals(Labels=='N'); normalY = Labels(Labels=='N');
次に、dividerand
を使用して、各クラスからターゲットを学習セットとテスト セットにランダムに分割します。
[trainIndA,~,testIndA] = dividerand(718,0.9,0.0,0.1); [trainIndN,~,testIndN] = dividerand(4937,0.9,0.0,0.1); XTrainA = afibX(trainIndA); YTrainA = afibY(trainIndA); XTrainN = normalX(trainIndN); YTrainN = normalY(trainIndN); XTestA = afibX(testIndA); YTestA = afibY(testIndA); XTestN = normalX(testIndN); YTestN = normalY(testIndN);
こうして、646 の AFib 信号と 4443 の正常な信号が学習用となります。各クラスの信号の数を同じにするには、最初の 4438 の正常な信号を使用してから、repmat
を使用して最初の 634 の AFib 信号を 7 回繰り返します。
テスト用には、72 の AFib 信号と 494 の正常な信号があります。最初の 490 の正常な信号を使用してから、repmat
を使用して最初の 70 の AFib 信号を 7 回繰り返します。既定の設定では、隣り合う信号がすべて同じラベルにならないように、ニューラル ネットワークは学習前にデータをランダムにシャッフルします。
XTrain = [repmat(XTrainA(1:634),7,1); XTrainN(1:4438)]; YTrain = [repmat(YTrainA(1:634),7,1); YTrainN(1:4438)]; XTest = [repmat(XTestA(1:70),7,1); XTestN(1:490)]; YTest = [repmat(YTestA(1:70),7,1); YTestN(1:490);];
正常な信号と AFib 信号の分布はこうして、学習セットとテスト セットの両方で等しく釣り合います。
summary(YTrain)
A 4438 N 4438
summary(YTest)
A 490 N 490
LSTM ネットワーク アーキテクチャの定義
LSTM ネットワークは、シーケンス データのタイム ステップ間の長期的な依存関係を学習できます。この例では、双方向の LSTM 層 bilstmLayer
を使用し、シーケンスを順方向および逆方向の両方で確認します。
入力信号がそれぞれ 1 次元であるため、入力サイズがサイズ 1 のシーケンスになるように指定します。出力サイズが 100 の双方向 LSTM 層を指定し、シーケンスの最後の要素を出力します。このコマンドは双方向の LSTM 層に対し、入力時系列の 100 の特徴へのマッピングを指示し、その後全結合層への出力を準備します。最後に、サイズが 2 の全結合層を含めることによって 2 個のクラスを指定し、その後にソフトマックス層と分類層を配置します。
layers = [ ... sequenceInputLayer(1) bilstmLayer(100,'OutputMode','last') fullyConnectedLayer(2) softmaxLayer classificationLayer ]
layers = 5x1 Layer array with layers: 1 '' Sequence Input Sequence input with 1 dimensions 2 '' BiLSTM BiLSTM with 100 hidden units 3 '' Fully Connected 2 fully connected layer 4 '' Softmax softmax 5 '' Classification Output crossentropyex
次に、分類器の学習オプションを指定します。'MaxEpochs'
を 10 に設定してネットワークが学習データから 10 個のパスを作れるようにします。150 の 'MiniBatchSize'
は、ネットワークに一度に 150 の学習信号の確認を指示します。0.01 の 'InitialLearnRate'
は、学習プロセスを加速させます。1000 の 'SequenceLength'
は、一度に確認するデータ量が多すぎてマシンがメモリ不足にならないように信号をより小さな塊に分割します。'GradientThreshold
' を 1 に設定して、勾配が大きくなりすぎないようにして学習プロセスを安定化させます。'Plots'
を 'training-progress'
として指定し、反復回数の増大に応じた学習の進行状況のグラフィックスを示すプロットを生成します。'Verbose'
を false
に設定し、プロットで示されるデータに対応する表出力を非表示にします。この表を表示する場合は、'Verbose'
を true
に設定します。
この例では、適応モーメント推定 (ADAM) ソルバーを使用します。Adam は、LSTM などの RNN では、既定のモーメンタム項付き確率的勾配降下法 (SGDM) ソルバーよりもパフォーマンスにおいて優れています。
options = trainingOptions('adam', ... 'MaxEpochs',10, ... 'MiniBatchSize', 150, ... 'InitialLearnRate', 0.01, ... 'SequenceLength', 1000, ... 'GradientThreshold', 1, ... 'ExecutionEnvironment',"auto",... 'plots','training-progress', ... 'Verbose',false);
LSTM ネットワークの学習
trainNetwork
を使用し、指定した学習オプションと層のアーキテクチャで LSTM ネットワークに学習させます。学習セットが大きいため、学習プロセスには数分かかる場合があります。
net = trainNetwork(XTrain,YTrain,layers,options);
学習の進行状況プロットの一番上のサブプロットには、学習精度、つまり各ミニバッチの分類精度が表示されます。学習が正常に進行すると、この値は通常 100% へと増加します。一番下のサブプロットには、学習損失、つまり各ミニバッチの交差エントロピー損失が表示されます。学習が正常に進行すると、この値は通常ゼロへと減少します。
学習が収束しない場合、プロットは上方または下方の特定の方向に向かわず、値と値の間で振動する場合があります。この振動は学習精度が向上せず、学習損失が減少していないことを意味します。この状況は学習の最初から発生することもありますし、学習精度においてまずいくらか改善された後でプロットが停滞してしまうこともあります。多くの場合、学習オプションを変更するとネットワークが収束できるようになります。MiniBatchSize
が減少したり、InitialLearnRate
が減少したりすると、学習時間が長くなることがありますが、ネットワークの学習を改善できます。
分類器の学習精度は約 50% ~約 60% で振動し、10 エポックの最後で、既に学習に数分間かかっています。
学習精度とテスト精度の可視化
学習精度を計算します。これは学習した信号に対する分類器の精度を表します。まず、学習データを分類します。
trainPred = classify(net,XTrain,'SequenceLength',1000);
分類問題において、混同行列は真の値が既知である一連のデータに対する分類パフォーマンスの可視化に使用されます。ターゲット クラスは信号のグラウンド トゥルース ラベルで、出力クラスはネットワークによって信号に割り当てられるラベルです。座標軸のラベルは AFib (A) と正常 (N) のクラス ラベルを表します。
confusionchart
コマンドを使用して、テスト データ予測に対する全体の分類精度を計算します。真陽性率と偽陽性率を行要約に表示するため、'row-normalized'
として 'RowSummary'
を指定します。また、陽性の予測値と偽発見率を列要約に表示するため、'column-normalized'
として 'ColumnSummary'
を指定します。
LSTMAccuracy = sum(trainPred == YTrain)/numel(YTrain)*100
LSTMAccuracy = 61.7283
figure confusionchart(YTrain,trainPred,'ColumnSummary','column-normalized',... 'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
テスト データを同じネットワークで分類します。
testPred = classify(net,XTest,'SequenceLength',1000);
テスト精度を計算して、混同行列で分類性能を可視化します。
LSTMAccuracy = sum(testPred == YTest)/numel(YTest)*100
LSTMAccuracy = 66.2245
figure confusionchart(YTest,testPred,'ColumnSummary','column-normalized',... 'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
2 番目の試行: 特徴抽出によるパフォーマンスの改善
データからの特徴抽出によって、分類器の学習精度とテスト精度を向上させることができます。抽出する特徴を決定するために、この例では、スペクトログラムなどの時間-周波数イメージを計算するアプローチを適応させて、それらを使用して畳み込みニューラル ネットワーク (CNN) に学習させます [4]、[5]。
各タイプの信号のスペクトログラムを可視化します。
fs = 300; figure subplot(2,1,1); pspectrum(normal,fs,'spectrogram','TimeResolution',0.5) title('Normal Signal') subplot(2,1,2); pspectrum(aFib,fs,'spectrogram','TimeResolution',0.5) title('AFib Signal')
この例では、CNN の代わりに LSTM を使用するため、1 次元信号に作用するようにアプローチを変換することが重要です。時間-周波数 (TF) モーメントはスペクトログラムから情報を抽出します。各モーメントは 1 次元の特徴として使用し、LSTM に入力することができます。
時間領域の次の 2 つの TF モーメントを調査します。
瞬時周波数 (
instfreq
)スペクトル エントロピー (
pentropy
)
関数 instfreq
は、パワー スペクトログラムの最初のモーメントとして信号の時間依存周波数を推定します。関数は時間枠上の短時間フーリエ変換を使用してスペクトログラムを計算します。この例では、関数は 255 個の時間枠を使用します。関数の時間出力は、時間枠の中心に対応します。
各タイプの信号の瞬時周波数を可視化します。
[instFreqA,tA] = instfreq(aFib,fs); [instFreqN,tN] = instfreq(normal,fs); figure subplot(2,1,1); plot(tN,instFreqN) title('Normal Signal') xlabel('Time (s)') ylabel('Instantaneous Frequency') subplot(2,1,2); plot(tA,instFreqA) title('AFib Signal') xlabel('Time (s)') ylabel('Instantaneous Frequency')
cellfun
を使用して、学習セットとテスト セットの各セルに関数 instfreq
を適用します。
instfreqTrain = cellfun(@(x)instfreq(x,fs)',XTrain,'UniformOutput',false); instfreqTest = cellfun(@(x)instfreq(x,fs)',XTest,'UniformOutput',false);
スペクトル エントロピーは、信号のスペクトルがどの程度スパイキーでフラットであるかを計測します。正弦波の和などのスパイキー スペクトルの信号はスペクトル エントロピーが低くなります。ホワイト ノイズなどのフラット スペクトルの信号はスペクトル エントロピーが高くなります。関数 pentropy
は、パワー スペクトログラムに基づいてスペクトル エントロピーを推定します。瞬時周波数推定の場合と同様に、pentropy
では 255 個の時間枠を使用してスペクトログラムを計算します。関数の時間出力は、時間枠の中心に対応します。
各タイプの信号のスペクトル エントロピーを可視化します。
[pentropyA,tA2] = pentropy(aFib,fs); [pentropyN,tN2] = pentropy(normal,fs); figure subplot(2,1,1) plot(tN2,pentropyN) title('Normal Signal') ylabel('Spectral Entropy') subplot(2,1,2) plot(tA2,pentropyA) title('AFib Signal') xlabel('Time (s)') ylabel('Spectral Entropy')
cellfun
を使用して、学習セットとテスト セットの各セルに関数 pentropy
を適用します。
pentropyTrain = cellfun(@(x)pentropy(x,fs)',XTrain,'UniformOutput',false); pentropyTest = cellfun(@(x)pentropy(x,fs)',XTest,'UniformOutput',false);
新しい学習セットとテスト セットの各セルが 2 次元、つまり 2 つの特徴をもつように特徴を連結します。
XTrain2 = cellfun(@(x,y)[x;y],instfreqTrain,pentropyTrain,'UniformOutput',false); XTest2 = cellfun(@(x,y)[x;y],instfreqTest,pentropyTest,'UniformOutput',false);
新しい入力の形式を可視化します。各セルには既に、1 つの 9000 サンプル長の信号は含まれまていません。ここでは 2 つの 255 サンプル長の特徴が含まれます。
XTrain2(1:5)
ans=5×1 cell array
{2×255 double}
{2×255 double}
{2×255 double}
{2×255 double}
{2×255 double}
データの標準化
瞬時周波数とスペクトル エントロピーの平均値には、ほぼ 1 桁の差があります。さらに、瞬時周波数の平均値は LSTM には高すぎるため効果的に学習することができない可能性があります。平均値が大きくて値の範囲が広いデータに対してネットワークが近似される場合、入力量が大きいとネットワークの学習と収束の速度が低下します [6]。
mean(instFreqN)
ans = 5.5615
mean(pentropyN)
ans = 0.6326
学習セットの平均値と標準偏差を使用して学習セットとテスト セットを標準化します。標準化、つまり z スコアは学習中のネットワーク性能を向上させる一般的な方法です。
XV = [XTrain2{:}]; mu = mean(XV,2); sg = std(XV,[],2); XTrainSD = XTrain2; XTrainSD = cellfun(@(x)(x-mu)./sg,XTrainSD,'UniformOutput',false); XTestSD = XTest2; XTestSD = cellfun(@(x)(x-mu)./sg,XTestSD,'UniformOutput',false);
標準化した瞬時周波数とスペクトル エントロピーの平均値を表示します。
instFreqNSD = XTrainSD{1}(1,:); pentropyNSD = XTrainSD{1}(2,:); mean(instFreqNSD)
ans = -0.3211
mean(pentropyNSD)
ans = -0.2416
LSTM ネットワーク アーキテクチャの変更
信号にはそれぞれ 2 つの次元があるため、入力シーケンス サイズに 2 を指定してネットワーク アーキテクチャを変更する必要があります。出力サイズが 100 の双方向 LSTM 層を指定し、シーケンスの最後の要素を出力します。サイズが 2 の全結合層を含めることによって 2 個のクラスを指定し、その後にソフトマックス層と分類層を配置します。
layers = [ ... sequenceInputLayer(2) bilstmLayer(100,'OutputMode','last') fullyConnectedLayer(2) softmaxLayer classificationLayer ]
layers = 5x1 Layer array with layers: 1 '' Sequence Input Sequence input with 2 dimensions 2 '' BiLSTM BiLSTM with 100 hidden units 3 '' Fully Connected 2 fully connected layer 4 '' Softmax softmax 5 '' Classification Output crossentropyex
学習オプションを指定します。エポックの最大回数を 30 に設定すると、ネットワークが学習データから 30 個のパスを作れるようになります。
options = trainingOptions('adam', ... 'MaxEpochs',30, ... 'MiniBatchSize', 150, ... 'InitialLearnRate', 0.01, ... 'GradientThreshold', 1, ... 'ExecutionEnvironment',"auto",... 'plots','training-progress', ... 'Verbose',false);
時間-周波数の特徴を使用した LSTM ネットワークの学習
trainNetwork
を使用し、指定した学習オプションと層のアーキテクチャで LSTM ネットワークに学習させます。
net2 = trainNetwork(XTrainSD,YTrain,layers,options);
学習精度が大きく改善しています。交差エントロピーの損失は 0 に近くなっています。さらに、TF モーメントが生のシーケンスより短いため、学習に必要な時間が短くなっています。
学習精度とテスト精度の可視化
更新した LSTM ネットワークを使用して学習データを分類します。混同行列として分類性能を可視化します。
trainPred2 = classify(net2,XTrainSD); LSTMAccuracy = sum(trainPred2 == YTrain)/numel(YTrain)*100
LSTMAccuracy = 83.5962
figure confusionchart(YTrain,trainPred2,'ColumnSummary','column-normalized',... 'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
テスト データを更新したネットワークで分類します。混同行列をプロットして、テスト精度を調べます。
testPred2 = classify(net2,XTestSD); LSTMAccuracy = sum(testPred2 == YTest)/numel(YTest)*100
LSTMAccuracy = 80.1020
figure confusionchart(YTest,testPred2,'ColumnSummary','column-normalized',... 'RowSummary','row-normalized','Title','Confusion Chart for LSTM');
まとめ
この例では、分類器をビルドし、LSTM ネットワークを使用して ECG 信号の心房細動を検出する方法を示します。手順では、オーバーサンプリングを使用して、大半が健康な患者で構成される母集団の中で異常な状態を検出しようとするときに発生する分類バイアスを回避します。生の信号データを使用した LSTM ネットワークの学習では良い分類精度は得られません。各信号で 2 つの時間-周波数モーメントの特徴を使用してネットワークに学習させると、分類性能は大幅に改善されて学習時間も短くなります。
参考文献
[1] AF Classification from a Short Single Lead ECG Recording: the PhysioNet/Computing in Cardiology Challenge, 2017. https://physionet.org/challenge/2017/
[2] Clifford, Gari, Chengyu Liu, Benjamin Moody, Li-wei H. Lehman, Ikaro Silva, Qiao Li, Alistair Johnson, and Roger G. Mark."AF Classification from a Short Single Lead ECG Recording: The PhysioNet Computing in Cardiology Challenge 2017." Computing in Cardiology (Rennes: IEEE). Vol. 44, 2017, pp. 1–4.
[3] Goldberger, A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals". Circulation.Vol. 101, No. 23, 13 June 2000, pp. e215–e220. http://circ.ahajournals.org/content/101/23/e215.full
[4] Pons, Jordi, Thomas Lidy, and Xavier Serra."Experimenting with Musically Motivated Convolutional Neural Networks".14th International Workshop on Content-Based Multimedia Indexing (CBMI).June 2016.
[5] Wang, D. "Deep learning reinvents the hearing aid," IEEE Spectrum, Vol. 54, No. 3, March 2017, pp. 32–37. doi: 10.1109/MSPEC.2017.7864754.
[6] Brownlee, Jason.How to Scale Data for Long Short-Term Memory Networks in Python.7 July 2017. https://machinelearningmastery.com/how-to-scale-data-for-long-short-term-memory-networks-in-python/.
参考
関数
instfreq
(Signal Processing Toolbox) |pentropy
(Signal Processing Toolbox) |trainingOptions
|trainNetwork
|bilstmLayer
|lstmLayer