ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

ウェーブレット解析と深層学習を使用した時系列の分類

この例では、連続ウェーブレット変換 (CWT) と深層畳み込みニューラル ネットワーク (CNN) を使用した人間の心電図 (ECG) 信号を分類する方法を説明します。

深層 CNN をゼロから学習させるには、大量の計算と大量の学習データが必要です。さまざまなアプリケーションでは、十分な量の学習データが利用可能ではなく、新しい現実的な学習の例の合成は不可能です。これらの場合、概念的に類似したタスクについて大規模なデータセットで学習されている既存のニューラル ネットワークを利用することが望まれます。既存のニューラル ネットワークの利用は転移学習と呼ばれます。この例では、イメージ認識用に事前学習済みの、2 つの深層 CNN、GoogLeNet および AlexNet を適応させて、時間-周波数表現を基にした ECG 波形を分類します。

GoogLeNet と AlexNet は、1000 カテゴリにイメージを分類するために最初に設計された深層 CNN です。時系列データの CWT からのイメージを基にした ECG 信号を分類するために CNN のネットワーク アーキテクチャを再利用します。この例を実行するには、Wavelet Toolbox™、Image Processing Toolbox™、Deep Learning Toolbox™、Deep Learning Toolbox™ Model for GoogLeNet Network サポート パッケージ、および Deep Learning Toolbox™ Model for AlexNet Network サポート パッケージが必要です。サポート パッケージを検索してインストールするには、MATLAB™ アドオン エクスプローラーを使用します。この例内のオプションは、学習プロセスを CPU で実行するために設定されます。マシンに GPU と Parallel Computing Toolbox™ がある場合は、オプションの設定によって学習プロセスを高速化して GPU 上で実行できます。この例で使用されているデータは、PhysioNet から公的に入手可能です。

データの説明

この例では、人の 3 つのグループから取得された ECG データを使用します。3 つのグループとは、心不整脈の患者 (ARR)、鬱血性心不全の患者 (CHF)、および正常洞調律の患者 (NSR) です。全体では、この例では、3 つの PhysioNet データベース (MIT-BIH Arrhythmia Database [3][7]、MIT-BIH Normal Sinus Rhythm Database [3]、BIDMC Congestive Heart Failure Database [1][3]) から 162 個の ECG 記録を使用します。具体的には、不整脈の患者の記録は 96 個、鬱血性心不全の患者の記録は 30 個、正常洞調律の患者の記録は 36 個あります。目標は、ARR、CHF、および NSR 間で分類器を学習させて区別することです。

データのダウンロード

1 番目のステップは、GitHub リポジトリからデータをダウンロードすることです。データを Web サイトからダウンロードするには、[Clone or download] をクリックして [Download ZIP] を選択します。書き込み権限のあるフォルダーに、ファイル physionet_ECG_data-master.zip を保存します。この例の手順では、ファイルを一時ディレクトリ (MATLAB の tempdir) にダウンロードしているものと仮定します。tempdir とは異なるフォルダーにデータをダウンロードすることを選択した場合は、データの解凍および読み込みに関する後続の手順を変更してください。Git に精通している場合は、最新バージョンのツール (git) をダウンロードし、git clone https://github.com/mathworks/physionet_ECG_data/ を使用してシステム コマンド プロンプトからデータを取得できます。

データを GitHub からダウンロードした後、一時ディレクトリでそのファイルを解凍します。

unzip(fullfile(tempdir,'physionet_ECG_data-master.zip'),tempdir)

解凍すると、一時ディレクトリにフォルダー physionet-ECG_data-master が作成されます。このフォルダーには、テキスト ファイル README.md と ECGData.zip が含まれます。ECGData.zip ファイルには次のものが含まれています

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat は、この例で使用されるデータを保持します。.txt ファイルの Modified_physionet_data.txt は PhysioNet のコピー ポリシーで必要になり、データのソース属性、および ECG の各記録に適用される前処理手順の説明を提供します。

physionet-ECG_data-master の ECGData.zip を解凍します。データ ファイルを MATLAB ワークスペースに読み込みます。

unzip(fullfile(tempdir,'physionet_ECG_data-master','ECGData.zip'),...
    fullfile(tempdir,'physionet_ECG_data-master'))
load(fullfile(tempdir,'physionet_ECG_data-master','ECGData.mat'))

ECGData は、2 つのフィールド (Data および Labels) をもつ構造体配列です。Data フィールドは 162 行 65,536 列の行列で、各行は 128 Hz でサンプリングした ECG 記録です。Labels は 162 行 1 列の診断ラベルの cell 配列で、それぞれが Data の各行に対応します。3 つの診断カテゴリは、'ARR'、'CHF'、および 'NSR' です。

各カテゴリの前処理したデータを保存するには、最初に tempdir 内に ECG データ ディレクトリ dataDir を作成します。その後、各 ECG カテゴリに由来した 'data' に 3 つのサブディレクトリを作成します。補助関数 helperCreateECGDirectories がこれを実行します。helperCreateECGDirectories は、ECGData、ECG データ ディレクトリの名前、親ディレクトリの名前を入力引数として受け入れます。tempdir を書き込み権限のある別のディレクトリと置き換えることができます。この補助関数のソース コードは、この例の最後にある「サポート関数」の節で見つけることができます。

parentDir = tempdir;
dataDir = 'data';
helperCreateECGDirectories(ECGData,parentDir,dataDir)

各 ECG カテゴリの表現をプロットします。補助関数 helperPlotReps がこれを実行します。helperPlotRepsECGData を入力として受け入れます。この補助関数のソース コードは、この例の最後にある「サポート関数」の節で見つけることができます。

helperPlotReps(ECGData)

時間-周波数表現の作成

フォルダー作成後、ECG 信号の時間-周波数表現を作成します。これらの表現はスカログラムと呼ばれます。スカログラムは、信号の CWT 係数の絶対値です。

スカログラムを作成するには、CWT フィルター バンクを事前に計算します。CWT フィルター バンクの事前計算は、同じパラメーターを使用して多数の信号の CWT を取得するときの推奨方法です。

スカログラムを生成する前に、そのうちの 1 つを調べます。1,000 サンプルをもつ信号にcwtfilterbankを使用して CWT フィルター バンクを作成します。フィルター バンクを使用して、信号の最初の 1,000 サンプルの CWT を取って、係数からスカログラムを取得します。

Fs = 128;
fb = cwtfilterbank('SignalLength',1000,...
    'SamplingFrequency',Fs,...
    'VoicesPerOctave',12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;figure;pcolor(t,frq,abs(cfs))
set(gca,'yscale','log');shading interp;axis tight;
title('Scalogram');xlabel('Time (s)');ylabel('Frequency (Hz)')

補助関数 helperCreateRGBfromTF を使用して、スカログラムを RGB イメージとして作成し、それらを dataDir の適切なサブディレクトリに書き込みます。この補助関数のソース コードは、この例の最後にある「サポート関数」の節にあります。GoogLeNet アーキテクチャと互換性をもたせるために、各 RGB イメージは 224 x 224 x 3 のサイズの配列になります。

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

学習データと検証データへの分割

スカログラム イメージをイメージ データストアとして読み込みます。関数 imageDatastore は、フォルダー名に基づいてイメージに自動的にラベルを付け、データを ImageDatastore オブジェクトとして格納します。イメージ データストアを使用すると、メモリに収まらないデータなどの大きなイメージ データを格納し、CNN の学習中にイメージをバッチ単位で効率的に読み取ることができます。

allImages = imageDatastore(fullfile(parentDir,dataDir),...
    'IncludeSubfolders',true,...
    'LabelSource','foldernames');

イメージを 2 つのグループ (学習用と検証用) にランダムに分割します。イメージの 80% を学習に使用し、残りを検証に使用します。再現性を得るために、乱数シードを既定値に設定します。

rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
Number of training images: 130
Number of validation images: 32

GoogLeNet

読み込み

事前学習済みの GoogLeNet ニューラル ネットワークを読み込みます。Deep Learning Toolbox™ Model for GoogLeNet Network サポート パッケージがインストールされていない場合、ソフトウェアによってアドオン エクスプローラーに必要なサポート パッケージへのリンクが表示されます。サポート パッケージをインストールするには、リンクをクリックして、[インストール] をクリックします。

net = googlenet;

ネットワークから層グラフを抽出して、層グラフをプロットします。

lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]);
plot(lgraph)
title(['GoogLeNet Layer Graph: ',num2str(numberOfLayers),' Layers']);

ネットワーク層プロパティの 1 番目の要素を検査します。GoogLeNet にはサイズ 224 x 224 x 3 の RGB イメージが必要であることに注意してください。

net.Layers(1)
ans = 

  ImageInputLayer with properties:

                Name: 'data'
           InputSize: [224 224 3]

   Hyperparameters
    DataAugmentation: 'none'
       Normalization: 'zerocenter'
        AverageImage: [224×224×3 single]

GoogLeNet ネットワーク パラメーターの変更

ネットワーク アーキテクチャの各層はフィルターと見なすことができます。初期の層は、ブロブ、エッジ、および色など、より一般的なイメージの特徴を識別します。後の層は、カテゴリを区別するためにより具体的な特徴に焦点を当てます。

ECG 分類問題に GoogLeNet を再学習するには、ネットワークの最後の 4 つの層を置き換えます。4 つの層の最初の層、'pool5-drop_7x7_s1' は、ドロップアウト層です。ドロップアウト層は、与えられた確率でランダムに、入力要素をゼロに設定します。ドロップアウト層は過適合を防止するために使用されます。既定の確率は 0.5 です。詳細は、dropoutLayerを参照してください。残りの 3 つの層 'loss3-classifier'、'prob' および 'output' は、ネットワークがクラス確率とラベルに抽出する特徴を組み合わせる方法に関する情報を含んでいます。既定では、最後の 3 つの層は、1000 個のカテゴリに対して構成されています。

新しい 4 つの層 (ドロップアウトの確率が 60% のドロップアウト層、全結合層、ソフトマックス層、および分類出力層) を層グラフに追加します。最後の全結合層のサイズが新しいデータセットのクラスの数 (この例では 3) と同じになるように設定します。新しい層での学習速度を転移された層より速くするには、全結合層の学習率係数を大きくします。GoogLeNet イメージ次元を inputSize に保存します。

lgraph = removeLayers(lgraph,{'pool5-drop_7x7_s1','loss3-classifier','prob','output'});

numClasses = numel(categories(imgsTrain.Labels));
newLayers = [
    dropoutLayer(0.6,'Name','newDropout')
    fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',5,'BiasLearnRateFactor',5)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);

lgraph = connectLayers(lgraph,'pool5-7x7_s1','newDropout');
inputSize = net.Layers(1).InputSize;

学習オプションの設定および GoogLeNet の学習

ニューラル ネットワークの学習は、損失関数の最小化を含む反復プロセスです。損失関数を最小化するには、勾配降下アルゴリズムが使用されます。各反復では、損失関数の勾配が評価されて、降下アルゴリズムの重みが更新されます。

学習はさまざまなオプションを設定することによって調整できます。InitialLearnRate は損失関数の負の勾配方向の初期ステップ サイズを指定します。MiniBatchSize は各反復で使用するために学習セットのサブセットの大きさを指定します。1 エポックは、学習セット全体に対する学習アルゴリズムの完全なパスです。MaxEpochs は学習に使用するエポックの最大回数を指定します。エポックの正しい数の選択は自明のタスクではありません。エポックの数が減少するとモデルが適合不足になります。エポックの数が増加すると過適合になります。

関数trainingOptionsを使用して学習オプションを指定します。MiniBatchSize を 10、MaxEpochs を 10、InitialLearnRate を 0.0001 に設定します。Plotstraining-progress に設定して、学習の進行状況を可視化します。モーメンタム項付き確率的勾配降下法オプティマイザーを使用します。既定では、学習は利用可能な GPU がある場合、GPU で行われます (Parallel Computing Toolbox™、および Compute Capability 3.0 以上の CUDA® 対応 GPU が必要)。再現性を得るために、ExecutionEnvironmentcpu に設定することによって、1 つの CPU のみを使用してネットワークを学習して、乱数シードを既定値に設定します。GPU を使用できる場合、実行時間は速くなります。

options = trainingOptions('sgdm',...
    'MiniBatchSize',15,...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4,...
    'ValidationData',imgsValidation,...
    'ValidationFrequency',10,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');

学習プロセスは、通常、デスクトップ CPU 上で 1 ~ 5 分かかります。実行中にコマンド ウィンドウに学習情報を表示します。結果には、エポック数、反復回数、経過時間、ミニバッチの精度、検証の精度、検証データの損失関数値が含まれます。

rng default
trainedGN = trainNetwork(imgsTrain,lgraph,options);
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:05 |        6.67% |       18.75% |       4.9207 |       2.4153 |      1.0000e-04 |
|       2 |          10 |       00:00:25 |       66.67% |       62.50% |       0.9578 |       1.3203 |      1.0000e-04 |
|       3 |          20 |       00:00:46 |       46.67% |       75.00% |       1.2938 |       0.5938 |      1.0000e-04 |
|       4 |          30 |       00:01:06 |       53.33% |       78.13% |       0.7139 |       0.4628 |      1.0000e-04 |
|       5 |          40 |       00:01:27 |       73.33% |       84.38% |       0.4740 |       0.3422 |      1.0000e-04 |
|       7 |          50 |       00:01:47 |       93.33% |       84.38% |       0.2818 |       0.2945 |      1.0000e-04 |
|       8 |          60 |       00:02:08 |       80.00% |       87.50% |       0.3610 |       0.2482 |      1.0000e-04 |
|       9 |          70 |       00:02:28 |       86.67% |       84.38% |       0.3397 |       0.2574 |      1.0000e-04 |
|      10 |          80 |       00:02:49 |      100.00% |       96.88% |       0.0718 |       0.1922 |      1.0000e-04 |
|      12 |          90 |       00:03:10 |       86.67% |      100.00% |       0.2872 |       0.1726 |      1.0000e-04 |
|      13 |         100 |       00:03:30 |       86.67% |       96.88% |       0.4367 |       0.1650 |      1.0000e-04 |
|      14 |         110 |       00:03:51 |       86.67% |      100.00% |       0.3139 |       0.1589 |      1.0000e-04 |
|      15 |         120 |       00:04:12 |       93.33% |       96.88% |       0.1491 |       0.1524 |      1.0000e-04 |
|      17 |         130 |       00:04:32 |      100.00% |      100.00% |       0.0553 |       0.1368 |      1.0000e-04 |
|      18 |         140 |       00:04:53 |       93.33% |       96.88% |       0.0997 |       0.1414 |      1.0000e-04 |
|      19 |         150 |       00:05:13 |       93.33% |       93.75% |       0.1621 |       0.1339 |      1.0000e-04 |
|      20 |         160 |       00:05:34 |       93.33% |       96.88% |       0.0881 |       0.1176 |      1.0000e-04 |
|======================================================================================================================|

学習済みネットワークの最後の 3 つの層を検査します。分類出力層は 3 つのラベルを示すことに注意してください。

trainedGN.Layers(end-2:end)
cNames = trainedGN.Layers(end).ClassNames
ans = 

  3x1 Layer array with layers:

     1   'fc'            Fully Connected         3 fully connected layer
     2   'softmax'       Softmax                 softmax
     3   'classoutput'   Classification Output   crossentropyex with 'ARR' and 2 other classes

cNames =

  3×1 cell array

    {'ARR'}
    {'CHF'}
    {'NSR'}

GoogLeNet 精度の評価

検証データを使用してネットワークを評価します。

[YPred,probs] = classify(trainedGN,imgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
display(['GoogLeNet Accuracy: ',num2str(accuracy)])
GoogLeNet Accuracy: 0.96875

精度は、学習可視化の図で報告された検証精度と同じになります。スカログラムは学習コレクションと検証コレクションに分割されました。両方のコレクションは GoogLeNet の学習に使用されました。学習の結果を評価する理想的な方法は、ネットワークで確認されていないデータを分類することです。学習、検証、およびテストに分割するためのデータが十分にないため、計算された検証精度をネットワーク精度として取り扱います。

GoogLeNet 活性化の調査

CNN の各層は入力イメージに対する応答またはアクティベーションを生成します。ただし、CNN 内でイメージの特性抽出に適している層は数層しかありません。ネットワークの始まりにある層が、エッジやブロブのようなイメージの基本的特徴を捉えます。これを確認するには、最初の畳み込み層からネットワーク フィルターの重みを可視化します。最初の層に 64 個の重みの個々のセットがあります。

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,5);
figure
montage(wghts)
title('First Convolutional Layer Weights')

活性化を調べ、活性化の領域を元のイメージと比較して、GoogLeNet が学習する特徴を確認できます。詳細は、畳み込みニューラル ネットワークの活性化の可視化 (Deep Learning Toolbox)と畳み込みニューラル ネットワークの特徴の可視化 (Deep Learning Toolbox)を参照してください。

ARR クラスからイメージに対して活性化する畳み込み層の領域を確認します。元のイメージの対応する領域と比較します。畳み込みニューラル ネットワークの各層は、"チャネル" と呼ばれる多数の 2 次元配列で構成されています。イメージをネットワークに渡して、最初の畳み込み層である 'conv1-7x7_s2' の出力活性化を確認します。

convLayer = 'conv1-7x7_s2';

imgClass = 'ARR';
imgName = 'ARR_10.jpg';
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = activations(trainedGN,imarr,convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
montage(rescale(trainingFeaturesARR),'Size',[8 8])
title([imgClass,' Activations'])

このイメージに最も強いチャネルを確認します。最も強いチャネルを元のイメージと比較します。

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure;
imshowpair(imarr,arrMax,'montage')
title(['Strongest ',imgClass,' Channel: ',num2str(maxValueIndex)])

AlexNet

AlexNet は、アーキテクチャがサイズ 227 x 227 x 3 のイメージをサポートする深層 CNN です。イメージの次元が GoogLeNet で異なっているにもかかわらず、AlexNet の次元で新しい RGB イメージを生成する必要はありません。元の RGB イメージを使用できます。

読み込み

事前学習済みの AlexNet ニューラル ネットワークを読み込みます。Deep Learning Toolbox™ Model for AlexNet Network サポート パッケージがインストールされていない場合、ソフトウェアによってアドオン エクスプローラーに必要なサポート パッケージへのリンクが表示されます。サポート パッケージをインストールするには、リンクをクリックして、[インストール] をクリックします。

alex = alexnet;

ネットワーク アーキテクチャを確認します。最初の層はイメージの入力サイズを 227 x 227 x 3 に指定しており、AlexNet は GoogLeNet より層が少ないことに注意してください。

layers = alex.Layers
layers = 

  25x1 Layer array with layers:

     1   'data'     Image Input                   227x227x3 images with 'zerocenter' normalization
     2   'conv1'    Convolution                   96 11x11x3 convolutions with stride [4  4] and padding [0  0  0  0]
     3   'relu1'    ReLU                          ReLU
     4   'norm1'    Cross Channel Normalization   cross channel normalization with 5 channels per element
     5   'pool1'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0  0  0]
     6   'conv2'    Grouped Convolution           2 groups of 128 5x5x48 convolutions with stride [1  1] and padding [2  2  2  2]
     7   'relu2'    ReLU                          ReLU
     8   'norm2'    Cross Channel Normalization   cross channel normalization with 5 channels per element
     9   'pool2'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0  0  0]
    10   'conv3'    Convolution                   384 3x3x256 convolutions with stride [1  1] and padding [1  1  1  1]
    11   'relu3'    ReLU                          ReLU
    12   'conv4'    Grouped Convolution           2 groups of 192 3x3x192 convolutions with stride [1  1] and padding [1  1  1  1]
    13   'relu4'    ReLU                          ReLU
    14   'conv5'    Grouped Convolution           2 groups of 128 3x3x192 convolutions with stride [1  1] and padding [1  1  1  1]
    15   'relu5'    ReLU                          ReLU
    16   'pool5'    Max Pooling                   3x3 max pooling with stride [2  2] and padding [0  0  0  0]
    17   'fc6'      Fully Connected               4096 fully connected layer
    18   'relu6'    ReLU                          ReLU
    19   'drop6'    Dropout                       50% dropout
    20   'fc7'      Fully Connected               4096 fully connected layer
    21   'relu7'    ReLU                          ReLU
    22   'drop7'    Dropout                       50% dropout
    23   'fc8'      Fully Connected               1000 fully connected layer
    24   'prob'     Softmax                       softmax
    25   'output'   Classification Output         crossentropyex with 'tench' and 999 other classes

AlexNet ネットワーク パラメーターの変更

AlexNet の再学習を行って新しいイメージを分類するには、GoogLeNet で行った変更と類似の変更を加えます。

既定では、AlexNet の最後の 3 つの層は、1000 個のカテゴリに対して構成されています。これらの層を、ECG 分類問題に対して微調整しなければなりません。層 23 の全結合層のサイズは ECG データのカテゴリ数と同じになるように設定する必要があります。層 24 は、ECG 分類問題と共に変更する必要はありません。Softmax はソフトマックス関数を入力に適用します。詳細は、softmaxLayerを参照してください。層 25 の分類出力層は、ネットワーク ラベルとクラス ラベルを学習するために使用される損失関数の名前を保持します。3 つの ECG カテゴリがあるため、層 23 を全結合層のサイズが 3 と同じになるように、層 25 を分類出力層になるように設定します。

layers(23) = fullyConnectedLayer(3);
layers(25) = classificationLayer;

AlexNet 用 RGB データの準備

RGB のイメージには GoogLeNet アーキテクチャに適した次元があります。最初の AlexNet 層から AlexNet によって使用されたイメージの次元を取得します。それらの次元を使用して、AlexNet アーキテクチャ用の既存 RGB イメージを自動的にサイズ変更する拡張イメージのデータストアを作成します。詳細については、augmentedImageDatastoreを参照してください。

inputSize = alex.Layers(1).InputSize;
augimgsTrain = augmentedImageDatastore(inputSize(1:2),imgsTrain);
augimgsValidation = augmentedImageDatastore(inputSize(1:2),imgsValidation);

学習オプションの設定および AlexNet の学習

GoogLeNet に使用されたものと一致させるために学習オプションを設定します。次に、AlexNet に学習させます。学習プロセスは、通常、デスクトップ CPU 上で 1 ~ 5 分かかります。

rng default
mbSize = 10;
mxEpochs = 10;
ilr = 1e-4;
plt = 'training-progress';

opts = trainingOptions('sgdm',...
    'InitialLearnRate',ilr, ...
    'MaxEpochs',mxEpochs ,...
    'MiniBatchSize',mbSize, ...
    'ValidationData',augimgsValidation,...
    'ExecutionEnvironment','cpu',...
    'Plots',plt);

trainedAN = trainNetwork(augimgsTrain,layers,opts);
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       20.00% |       18.75% |       2.6121 |       1.9324 |      1.0000e-04 |
|       4 |          50 |       00:00:34 |      100.00% |       87.50% |       0.0961 |       0.3066 |      1.0000e-04 |
|       8 |         100 |       00:01:07 |      100.00% |       96.88% |       0.1020 |       0.1581 |      1.0000e-04 |
|      10 |         130 |       00:01:27 |      100.00% |       93.75% |       0.0414 |       0.1236 |      1.0000e-04 |
|======================================================================================================================|

検証精度は 93.75% です。学習済みの AlexNet ネットワークの最後の 3 つの層を検査します。3 つのラベルを示す分類出力層を観察します。

trainedAN.Layers(end-2:end)
ans = 

  3x1 Layer array with layers:

     1   'fc'            Fully Connected         3 fully connected layer
     2   'prob'          Softmax                 softmax
     3   'classoutput'   Classification Output   crossentropyex with 'ARR' and 2 other classes

まとめ

この例では、転移学習と連続ウェーブレット解析を使用して、事前学習済みの CNN、GoogLeNet および AlexNet を利用することによって ECG 信号の 3 つのクラスを分類する方法を説明します。ECG 信号のウェーブレットベースの時間-周波数表現を使用してスカログラムを作成します。スカログラムの RGB イメージが生成されます。イメージは両方の深層 CNN を微調整するために使用されます。また、異なるネットワーク層の活性化についても調査しました。

この例は、事前学習済みの CNN モデルを使用して信号を分類するために使用できる 1 つのワークフローを示しています。その他のワークフローを使用できます。GoogLeNet および AlexNet は ImageNet データベースのサブセットで事前学習済みのモデルです [10]。これは、ImageNet Large-Scale Visual Recognition Challenge (ILSVRC) で使用されます [8]。ImageNet コレクションには、魚、鳥、機器、および菌類など、実世界のオブジェクトのイメージが含まれます。スカログラムは実世界のオブジェクトのクラスの範囲外にあります。GoogLeNet および AlexNet アーキテクチャに適合させるために、スカログラムはデータ削減も行いました。事前学習済みの CNN を微調整してスカログラムの異なるクラスを区別する代わりに、元のスカログラムの次元でゼロから CNN を学習する作業はオプションです。

参考文献

  1. Baim, D. S., W. S. Colucci, E. S. Monrad, H. S. Smith, R. F. Wright, A. Lanoue, D. F. Gauthier, B. J. Ransil, W. Grossman, and E. Braunwald."Survival of patients with severe congestive heart failure treated with oral milrinone."Journal of the American College of Cardiology.Vol. 7, Number 3, 1986, pp. 661–670.

  2. Engin, M. "ECG beat classification using neuro-fuzzy network."Pattern Recognition Letters. Vol. 25, Number 15, 2004, pp.1715–1722.

  3. Goldberger A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K.Peng, and H. E. Stanley."PhysioBank, PhysioToolkit,and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals."Circulation. Vol. 101, Number 23: e215–e220.[Circulation Electronic Pages; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (June 13). doi: 10.1161/01.CIR.101.23.e215.

  4. Leonarduzzi, R. F., G. Schlotthauer, and M. E. Torres."Wavelet leader based multifractal analysis of heart rate variability during myocardial ischaemia."In Engineering in Medicine and Biology Society (EMBC), Annual International Conference of the IEEE, 110–113.Buenos Aires, Argentina: IEEE, 2010.

  5. Li, T., and M. Zhou."ECG classification using wavelet packet entropy and random forests."Entropy.Vol. 18, Number 8, 2016, p.285.

  6. Maharaj, E. A., and A. M. Alonso."Discriminant analysis of multivariate time series: Application to diagnosis based on ECG signals."Computational Statistics and Data Analysis.Vol. 70, 2014, pp. 67–87.

  7. Moody, G. B., and R. G. Mark."The impact of the MIT-BIH Arrhythmia Database."IEEE Engineering in Medicine and Biology Magazine.Vol. 20. Number 3, May-June 2001, pp. 45–50.(PMID: 11446209)

  8. Russakovsky, O., J. Deng, and H. Su et al. "ImageNet Large Scale Visual Recognition Challenge."International Journal of Computer Vision.Vol. 115, Number 3, 2015, pp. 211–252.

  9. Zhao, Q., and L. Zhang."ECG feature extraction and classification using wavelet transform and support vector machines."In IEEE International Conference on Neural Networks and Brain, 1089–1092.Beijing, China: IEEE, 2005.

  10. ImageNet.http://www.image-net.org

サポート関数

helperCreateECGDataDirectories は、親ディレクトリ内にデータ ディレクトリを作成してから、データ ディレクトリ内に 3 つのサブディレクトリを作成します。サブディレクトリには ECGData で見つかった ECG 信号の各クラスに由来した名前が付けられます。

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps は、ECGData で見つかった ECG 信号の各クラスの代表的なものから最初の 1000 サンプルをプロットします。

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRGBfromTFcwtfilterbank を使用して、ECG 信号の連続ウェーブレット変換を取得して、ウェーブレット係数からスカログラムを生成します。補助関数はスカログラムのサイズを変更し、jpeg イメージとしてディスクに書き込みます。

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank('SignalLength',signalLength,'VoicesPerOctave',12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(im2uint8(rescale(cfs)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = strcat(char(labels(ii)),'_',num2str(ii),'.jpg');
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end