Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

ウェーブレット解析と深層学習を使用した時系列の分類

この例では、連続ウェーブレット変換 (CWT) と深層畳み込みニューラル ネットワーク (CNN) を使用した人間の心電図 (ECG) 信号を分類する方法を説明します。

深層 CNN をゼロから学習させるには、大量の計算と大量の学習データが必要です。さまざまなアプリケーションでは、十分な量の学習データが利用可能ではなく、新しい現実的な学習の例の合成は不可能です。これらの場合、概念的に類似したタスクについて大規模なデータ セットで学習させている既存のニューラル ネットワークを利用することが望まれます。既存のニューラル ネットワークの利用は転移学習と呼ばれます。この例では、イメージ認識用に事前学習済みの、2 つの深層 CNN、GoogLeNet および SqueezeNet を適応させて、時間-周波数表現を基にした ECG 波形を分類します。

GoogLeNet と SqueezeNet は、1000 カテゴリにイメージを分類するために最初に設計された深層 CNN です。時系列データの CWT からのイメージを基にした ECG 信号を分類するために CNN のネットワーク アーキテクチャを再利用します。この例で使用されているデータは、PhysioNet から公的に入手可能です。

データの説明

この例では、人の 3 つのグループから取得された ECG データを使用します。3 つのグループとは、心不整脈の患者 (ARR)、鬱血性心不全の患者 (CHF)、および正常洞調律の患者 (NSR) です。全体では、次の 3 つの PhysioNet データベースから 162 個の ECG 記録を使用します。MIT-BIH Arrhythmia Database [3][7]、MIT-BIH Normal Sinus Rhythm Database [3]、BIDMC Congestive Heart Failure Database [1][3]。具体的には、不整脈の患者の記録は 96 個、鬱血性心不全の患者の記録は 30 個、正常洞調律の患者の記録は 36 個あります。目標は、ARR、CHF、および NSR 間で分類器を学習させて区別することです。

データのダウンロード

1 番目のステップは、GitHub® リポジトリからデータをダウンロードすることです。データを Web サイトからダウンロードするには、[Code] をクリックして [Download ZIP] を選択します。書き込み権限のあるフォルダーに、ファイル physionet_ECG_data-main.zip を保存します。この例の手順では、ファイルを一時ディレクトリ (MATLAB® の tempdir) にダウンロードしているものと仮定します。tempdir とは異なるフォルダーにデータをダウンロードすることを選択した場合は、データの解凍および読み込みに関する後続の手順を変更してください。

データを GitHub からダウンロードした後、一時ディレクトリでそのファイルを解凍します。

unzip(fullfile(tempdir,"physionet_ECG_data-main.zip"),tempdir)

解凍すると、一時ディレクトリにフォルダー physionet-ECG_data-main が作成されます。このフォルダーには、テキスト ファイル README.mdECGData.zip が含まれます。ECGData.zip ファイルには次のものが含まれています。

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

ECGData.mat は、この例で使用されるデータを保持します。テキスト ファイルの Modified_physionet_data.txt は PhysioNet のコピー ポリシーで必要になり、データのソース属性、および ECG の各記録に適用される前処理手順の説明を提供します。

physionet-ECG_data-mainECGData.zip を解凍します。データ ファイルを MATLAB ワークスペースに読み込みます。

unzip(fullfile(tempdir,"physionet_ECG_data-main","ECGData.zip"), ...
    fullfile(tempdir,"physionet_ECG_data-main"))
load(fullfile(tempdir,"physionet_ECG_data-main","ECGData.mat"))

ECGData は、2 つのフィールド (Data および Labels) をもつ構造体配列です。Data フィールドは 162 行 65,536 列の行列で、各行は 128 Hz でサンプリングした ECG 記録です。Labels は 162 行 1 列の診断ラベルの cell 配列で、それぞれが Data の各行に対応します。3 つの診断カテゴリは、'ARR''CHF'、および 'NSR' です。

各カテゴリの前処理したデータを保存するには、最初に tempdir 内に ECG データ ディレクトリ dataDir を作成します。その後、各 ECG カテゴリに由来した 'data' に 3 つのサブディレクトリを作成します。補助関数 helperCreateECGDirectories がこれを実行します。helperCreateECGDirectories は、ECGData、ECG データ ディレクトリの名前、親ディレクトリの名前を入力引数として受け入れます。tempdir を書き込み権限のある別のディレクトリと置き換えることができます。この補助関数のソース コードは、この例の最後にあるサポート関数のセクションで見つけることができます。

parentDir = tempdir;
dataDir = "data";
helperCreateECGDirectories(ECGData,parentDir,dataDir)

各 ECG カテゴリの表現をプロットします。補助関数 helperPlotReps がこれを実行します。helperPlotRepsECGData を入力として受け入れます。この補助関数のソース コードは、この例の最後にあるサポート関数のセクションで見つけることができます。

helperPlotReps(ECGData)

時間-周波数表現の作成

フォルダー作成後、ECG 信号の時間-周波数表現を作成します。これらの表現はスカログラムと呼ばれます。スカログラムは、信号の CWT 係数の絶対値です。

スカログラムを作成するには、CWT フィルター バンクを事前に計算します。CWT フィルター バンクの事前計算は、同じパラメーターを使用して多数の信号の CWT を取得するときの推奨方法です。

スカログラムを生成する前に、そのうちの 1 つを調べます。1,000 サンプルをもつ信号にcwtfilterbankを使用して CWT フィルター バンクを作成します。フィルター バンクを使用して、信号の最初の 1,000 サンプルの CWT を取って、係数からスカログラムを取得します。

Fs = 128;
fb = cwtfilterbank(SignalLength=1000, ...
    SamplingFrequency=Fs, ...
    VoicesPerOctave=12);
sig = ECGData.Data(1,1:1000);
[cfs,frq] = wt(fb,sig);
t = (0:999)/Fs;
figure
pcolor(t,frq,abs(cfs))
set(gca,"yscale","log")
shading interp
axis tight
title("Scalogram")
xlabel("Time (s)")
ylabel("Frequency (Hz)")

補助関数 helperCreateRGBfromTF を使用して、スカログラムを RGB イメージとして作成し、それらを dataDir の適切なサブディレクトリに書き込みます。この補助関数のソース コードは、この例の最後にあるサポート関数のセクションにあります。GoogLeNet アーキテクチャと互換性をもたせるために、各 RGB イメージは 224 x 224 x 3 のサイズの配列になります。

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

学習データと検証データへの分割

スカログラム イメージをイメージ データストアとして読み込みます。関数imageDatastoreは、フォルダー名に基づいてイメージに自動的にラベルを付け、このデータを ImageDatastore オブジェクトとして格納します。イメージ データストアを使用すると、メモリに収まらないデータなどの大きなイメージ データを格納し、CNN の学習中にイメージをバッチ単位で効率的に読み取ることができます。

allImages = imageDatastore(fullfile(parentDir,dataDir), ...
    "IncludeSubfolders",true, ...
    "LabelSource","foldernames");

イメージを 2 つのグループ (学習用と検証用) にランダムに分割します。イメージの 80% を学習に使用し、残りを検証に使用します。

[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,"randomized");
disp("Number of training images: "+num2str(numel(imgsTrain.Files)))
Number of training images: 130
disp("Number of validation images: "+num2str(numel(imgsValidation.Files)))
Number of validation images: 32

GoogLeNet

読み込み

事前学習済みの GoogLeNet ニューラル ネットワークを読み込みます。Deep Learning Toolbox™ Model for GoogLeNet Network サポート パッケージがインストールされていない場合、ソフトウェアによってアドオン エクスプローラーに必要なサポート パッケージへのリンクが表示されます。サポート パッケージをインストールするには、リンクをクリックして、[インストール] をクリックします。

net = googlenet;

ネットワークから層グラフを抽出して表示します。

lgraph = layerGraph(net);
numberOfLayers = numel(lgraph.Layers);
figure("Units","normalized","Position",[0.1 0.1 0.8 0.8])
plot(lgraph)
title("GoogLeNet Layer Graph: "+num2str(numberOfLayers)+" Layers")

ネットワーク Layers プロパティの 1 番目の要素を検査します。GoogLeNet にはサイズ 224 x 224 x 3 の RGB イメージが必要であることを確認します。

lgraph.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [224 224 3]
        SplitComplexInputs: 0

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

GoogLeNet ネットワーク パラメーターの変更

ネットワーク アーキテクチャの各層はフィルターと見なすことができます。初期の層は、ブロブ、エッジ、および色など、より一般的なイメージの特徴を識別します。後続の層は、カテゴリを区別するためにより具体的な特徴に焦点を当てます。GoogLeNet は事前学習済みで、イメージを 1000 個のオブジェクト カテゴリに分類できます。GoogLeNet を、ECG 分類問題用に再学習させなければなりません。

ネットワークの最後の 5 つの層を検査します。

lgraph.Layers(end-4:end)
ans = 
  5×1 Layer array with layers:

     1   'pool5-7x7_s1'        2-D Global Average Pooling   2-D global average pooling
     2   'pool5-drop_7x7_s1'   Dropout                      40% dropout
     3   'loss3-classifier'    Fully Connected              1000 fully connected layer
     4   'prob'                Softmax                      softmax
     5   'output'              Classification Output        crossentropyex with 'tench' and 999 other classes

過適合を防止するには、ドロップアウト層を使用します。ドロップアウト層は、与えられた確率でランダムに、入力要素をゼロに設定します。詳細については、dropoutLayer (Deep Learning Toolbox)を参照してください。既定の確率は 0.5 です。ネットワーク内の最後のドロップアウト層 pool5-drop_7x7_s1 を、確率 0.6 のドロップアウト層に置き換えます。

newDropoutLayer = dropoutLayer(0.6,"Name","new_Dropout");
lgraph = replaceLayer(lgraph,"pool5-drop_7x7_s1",newDropoutLayer);

ネットワークの畳み込み層は、入力イメージを分類するために、最後の学習可能な層と最終分類層が使用するイメージの特徴を抽出します。GoogLeNet のこれらの 2 つの層 loss3-classifier および output は、ネットワークによって抽出された特徴を組み合わせてクラス確率、損失値、および予測ラベルにまとめる方法に関する情報を含んでいます。RGB イメージを分類するために GoogLeNet を再学習させるには、これら 2 つの層をデータに適応させた新しい層に置き換えます。

全結合層 loss3-classifier を、クラスの数と同じ数のフィルターを持つ新しい全結合層に置き換えます。新しい層での学習速度を転移された層より速くするには、全結合層の学習率係数を大きくします。

numClasses = numel(categories(imgsTrain.Labels));
newConnectedLayer = fullyConnectedLayer(numClasses,"Name","new_fc", ...
    "WeightLearnRateFactor",5,"BiasLearnRateFactor",5);
lgraph = replaceLayer(lgraph,"loss3-classifier",newConnectedLayer);

分類層はネットワークの出力クラスを指定します。分類層をクラス ラベルがない新しい分類層に置き換えます。trainNetwork は、学習時に層の出力クラスを自動的に設定します。

newClassLayer = classificationLayer("Name","new_classoutput");
lgraph = replaceLayer(lgraph,"output",newClassLayer);

最後の 5 つの層を検査します。ドロップアウト層、畳み込み層、および出力層を置き換えたことを確認します。

lgraph.Layers(end-4:end)
ans = 
  5×1 Layer array with layers:

     1   'pool5-7x7_s1'      2-D Global Average Pooling   2-D global average pooling
     2   'new_Dropout'       Dropout                      60% dropout
     3   'new_fc'            Fully Connected              3 fully connected layer
     4   'prob'              Softmax                      softmax
     5   'new_classoutput'   Classification Output        crossentropyex

学習オプションの設定および GoogLeNet の学習

ニューラル ネットワークの学習は、損失関数の最小化を含む反復プロセスです。損失関数を最小化するには、勾配降下アルゴリズムが使用されます。各反復では、損失関数の勾配が評価されて、降下アルゴリズムの重みが更新されます。

学習はさまざまなオプションを設定することによって調整できます。InitialLearnRate は損失関数の負の勾配方向の初期ステップ サイズを指定します。MiniBatchSize は各反復で使用するために学習セットのサブセットの大きさを指定します。1 エポックは、学習セット全体に対する学習アルゴリズムを一巡することです。MaxEpochs は学習に使用するエポックの最大回数を指定します。エポックの正しい数の選択は自明のタスクではありません。エポックの数が減少するとモデルが適合不足になります。エポックの数が増加すると過適合になります。

関数trainingOptions (Deep Learning Toolbox)を使用して学習オプションを指定します。MiniBatchSize を 15、MaxEpochs を 20、InitialLearnRate を 0.0001 に設定します。Plotstraining-progress に設定して、学習の進行状況を可視化します。モーメンタム項付き確率的勾配降下法オプティマイザーを使用します。既定では、GPU が利用可能な場合、学習は GPU で行われます。GPU を使用するには Parallel Computing Toolbox™ が必要です。サポートされている GPU については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

options = trainingOptions("sgdm", ...
    MiniBatchSize=15, ...
    MaxEpochs=20, ...
    InitialLearnRate=1e-4, ...
    ValidationData=imgsValidation, ...
    ValidationFrequency=10, ...
    Verbose=1, ...
    Plots="training-progress");

ネットワークに学習をさせます。学習プロセスは、通常、デスクトップ CPU 上で 1 ~ 5 分かかります。GPU を使用できる場合、実行時間は速くなります。実行中にコマンド ウィンドウに学習情報を表示します。結果には、エポック数、反復回数、経過時間、ミニバッチの精度、検証の精度、検証データの損失関数値が含まれます。

trainedGN = trainNetwork(imgsTrain,lgraph,options);
Training on single GPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:03 |       33.33% |       59.38% |       2.5151 |       1.1989 |      1.0000e-04 |
|       2 |          10 |       00:00:08 |       53.33% |       62.50% |       1.4702 |       0.8987 |      1.0000e-04 |
|       3 |          20 |       00:00:12 |       53.33% |       84.38% |       1.0575 |       0.5156 |      1.0000e-04 |
|       4 |          30 |       00:00:17 |       60.00% |       78.12% |       0.7547 |       0.4531 |      1.0000e-04 |
|       5 |          40 |       00:00:22 |       60.00% |       93.75% |       0.7951 |       0.3317 |      1.0000e-04 |
|       7 |          50 |       00:00:26 |      100.00% |       78.12% |       0.1318 |       0.3585 |      1.0000e-04 |
|       8 |          60 |       00:00:31 |       80.00% |       93.75% |       0.4344 |       0.2507 |      1.0000e-04 |
|       9 |          70 |       00:00:36 |      100.00% |       87.50% |       0.1106 |       0.2495 |      1.0000e-04 |
|      10 |          80 |       00:00:41 |       80.00% |      100.00% |       0.5102 |       0.1971 |      1.0000e-04 |
|      12 |          90 |       00:00:46 |      100.00% |      100.00% |       0.1180 |       0.1714 |      1.0000e-04 |
|      13 |         100 |       00:00:51 |       93.33% |      100.00% |       0.3106 |       0.1463 |      1.0000e-04 |
|      14 |         110 |       00:00:56 |      100.00% |      100.00% |       0.1004 |       0.1145 |      1.0000e-04 |
|      15 |         120 |       00:01:00 |       93.33% |      100.00% |       0.1732 |       0.1132 |      1.0000e-04 |
|      17 |         130 |       00:01:05 |       93.33% |       96.88% |       0.1391 |       0.1294 |      1.0000e-04 |
|      18 |         140 |       00:01:10 |      100.00% |      100.00% |       0.1139 |       0.1001 |      1.0000e-04 |
|      19 |         150 |       00:01:15 |      100.00% |      100.00% |       0.0320 |       0.0914 |      1.0000e-04 |
|      20 |         160 |       00:01:19 |      100.00% |       93.75% |       0.0611 |       0.1264 |      1.0000e-04 |
|======================================================================================================================|
Training finished: Max epochs completed.

学習済みネットワークの最後の層を検査します。3 つのクラスを含む分類出力層を確認します。

trainedGN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
    ClassWeights: 'none'
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

GoogLeNet 精度の評価

検証データを使用してネットワークを評価します。

[YPred,~] = classify(trainedGN,imgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp("GoogLeNet Accuracy: "+num2str(100*accuracy)+"%")
GoogLeNet Accuracy: 93.75%

精度は、学習可視化の図で報告された検証精度と同じになります。スカログラムは学習コレクションと検証コレクションに分割されました。両方のコレクションは GoogLeNet の学習に使用されました。学習の結果を評価する理想的な方法は、ネットワークで確認されていないデータを分類することです。学習、検証、およびテストに分割するためのデータが十分にないため、計算された検証精度をネットワーク精度として取り扱います。

GoogLeNet 活性化の調査

CNN の各層は入力イメージに対する応答または活性化を生成します。ただし、CNN 内でイメージの特徴抽出に適している層は数層しかありません。学習済みネットワークの最初の 5 つの層を検査します。

trainedGN.Layers(1:5)
ans = 
  5×1 Layer array with layers:

     1   'data'             Image Input                   224×224×3 images with 'zerocenter' normalization
     2   'conv1-7x7_s2'     2-D Convolution               64 7×7×3 convolutions with stride [2  2] and padding [3  3  3  3]
     3   'conv1-relu_7x7'   ReLU                          ReLU
     4   'pool1-3x3_s2'     2-D Max Pooling               3×3 max pooling with stride [2  2] and padding [0  1  0  1]
     5   'pool1-norm1'      Cross Channel Normalization   cross channel normalization with 5 channels per element

ネットワークの始まりにある層が、エッジやブロブのようなイメージの基本的特徴を捉えます。これを確認するには、最初の畳み込み層からネットワーク フィルターの重みを可視化します。最初の層に 64 個の重みの個々のセットがあります。

wghts = trainedGN.Layers(2).Weights;
wghts = rescale(wghts);
wghts = imresize(wghts,8);
figure
I = imtile(wghts,GridSize=[8 8]);
imshow(I)
title("First Convolutional Layer Weights")

活性化を調べ、活性化の領域を元のイメージと比較して、GoogLeNet が学習する特徴を確認できます。詳細については、畳み込みニューラル ネットワークの活性化の可視化 (Deep Learning Toolbox)畳み込みニューラル ネットワークの特徴の可視化 (Deep Learning Toolbox)を参照してください。

ARR クラスからイメージに対して活性化する畳み込み層の領域を確認します。元のイメージの対応する領域と比較します。畳み込みニューラル ネットワークの各層は、"チャネル" と呼ばれる多数の 2 次元配列で構成されています。イメージをネットワークに渡して、最初の畳み込み層である conv1-7x7_s2 の出力活性化を確認します。

convLayer = "conv1-7x7_s2";

imgClass = "ARR";
imgName = "ARR_10.jpg";
imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));

trainingFeaturesARR = activations(trainedGN,imarr,convLayer);
sz = size(trainingFeaturesARR);
trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]);
figure
I = imtile(rescale(trainingFeaturesARR),GridSize=[8 8]);
imshow(I)
title(imgClass+" Activations")

このイメージに最も強いチャネルを確認します。最も強いチャネルを元のイメージと比較します。

imgSize = size(imarr);
imgSize = imgSize(1:2);
[~,maxValueIndex] = max(max(max(trainingFeaturesARR)));
arrMax = trainingFeaturesARR(:,:,:,maxValueIndex);
arrMax = rescale(arrMax);
arrMax = imresize(arrMax,imgSize);
figure
I = imtile({imarr,arrMax});
imshow(I)
title("Strongest "+imgClass+" Channel: "+num2str(maxValueIndex))

SqueezeNet

SqueezeNet は、アーキテクチャがサイズ 227 x 227 x 3 のイメージをサポートする深層 CNN です。イメージの次元が GoogLeNet で異なっているにもかかわらず、SqueezeNet の次元で新しい RGB イメージを生成する必要はありません。元の RGB イメージを使用できます。

読み込み

事前学習済みの SqueezeNet ニューラル ネットワークを読み込みます。Deep Learning Toolbox™ Model for SqueezeNet Network サポート パッケージがインストールされていない場合、ソフトウェアによってアドオン エクスプローラーに必要なサポート パッケージへのリンクが表示されます。サポート パッケージをインストールするには、リンクをクリックして、[インストール] をクリックします。

sqz = squeezenet;

ネットワークから層グラフを抽出します。SqueezeNet は GoogLeNet より層が少ないことを確認します。また、SqueezeNet がサイズ 227×227×3 のイメージ用に構成されていることも確認します。

lgraphSqz = layerGraph(sqz);
disp("Number of Layers: "+num2str(numel(lgraphSqz.Layers)))
Number of Layers: 68
lgraphSqz.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'data'
                 InputSize: [227 227 3]
        SplitComplexInputs: 0

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [1×1×3 single]

SqueezeNet ネットワーク パラメーターの変更

SqueezeNet の再学習を行って新しいイメージを分類するには、GoogLeNet で行った変更と類似の変更を加えます。

最後の 6 つのネットワーク層を検査します。

lgraphSqz.Layers(end-5:end)
ans = 
  6×1 Layer array with layers:

     1   'drop9'                             Dropout                      50% dropout
     2   'conv10'                            2-D Convolution              1000 1×1×512 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'                       ReLU                         ReLU
     4   'pool10'                            2-D Global Average Pooling   2-D global average pooling
     5   'prob'                              Softmax                      softmax
     6   'ClassificationLayer_predictions'   Classification Output        crossentropyex with 'tench' and 999 other classes

ネットワーク内の最後のドロップアウト層を、確率 0.6 のドロップアウト層に置き換えます。

tmpLayer = lgraphSqz.Layers(end-5);
newDropoutLayer = dropoutLayer(0.6,"Name","new_dropout");
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);

GoogLeNet とは異なり、SqueezeNet の最後の学習可能な層は 1×1 畳み込み層 conv10 であり、全結合層ではありません。この層を、クラスの数と同じ数のフィルターを持つ新しい畳み込み層に置き換えます。GoogLeNet のときと同様に、新しい層の学習率係数を大きくします。

numClasses = numel(categories(imgsTrain.Labels));
tmpLayer = lgraphSqz.Layers(end-4);
newLearnableLayer = convolution2dLayer(1,numClasses, ...
        "Name","new_conv", ...
        "WeightLearnRateFactor",10, ...
        "BiasLearnRateFactor",10);
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);

分類層をクラス ラベルがない新しい分類層に置き換えます。

tmpLayer = lgraphSqz.Layers(end);
newClassLayer = classificationLayer("Name","new_classoutput");
lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);

ネットワークの最後の 6 つの層を検査します。ドロップアウト層、畳み込み層、および出力層が変更されていることを確認します。

lgraphSqz.Layers(end-5:end)
ans = 
  6×1 Layer array with layers:

     1   'new_dropout'       Dropout                      60% dropout
     2   'new_conv'          2-D Convolution              3 1×1 convolutions with stride [1  1] and padding [0  0  0  0]
     3   'relu_conv10'       ReLU                         ReLU
     4   'pool10'            2-D Global Average Pooling   2-D global average pooling
     5   'prob'              Softmax                      softmax
     6   'new_classoutput'   Classification Output        crossentropyex

SqueezeNet 用 RGB データの準備

RGB のイメージには GoogLeNet アーキテクチャに適した次元があります。SqueezeNet アーキテクチャ用の既存 RGB イメージを自動的にサイズ変更する拡張イメージのデータストアを作成します。詳細については、augmentedImageDatastore (Deep Learning Toolbox)を参照してください。

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain);
augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

学習オプションの設定および SqueezeNet の学習

SqueezeNet で使用する新しい一連の学習オプションを作成し、ネットワークに学習させます。

ilr = 3e-4;
miniBatchSize = 10;
maxEpochs = 15;
valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize);
opts = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ...
    MaxEpochs=maxEpochs, ...
    InitialLearnRate=ilr, ...
    ValidationData=augimgsValidation, ...
    ValidationFrequency=valFreq, ...
    Verbose=1, ...
    Plots="training-progress");

trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);
Training on single GPU.
Initializing input data normalization.
|======================================================================================================================|
|  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  |
|         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       |
|======================================================================================================================|
|       1 |           1 |       00:00:02 |       10.00% |       56.25% |       4.0078 |       2.3776 |          0.0003 |
|       1 |          13 |       00:00:03 |       50.00% |       68.75% |       0.9363 |       0.9029 |          0.0003 |
|       2 |          26 |       00:00:05 |       70.00% |       78.12% |       0.8365 |       0.7137 |          0.0003 |
|       3 |          39 |       00:00:07 |       70.00% |       81.25% |       0.7844 |       0.5915 |          0.0003 |
|       4 |          50 |       00:00:08 |       70.00% |              |       0.5947 |              |          0.0003 |
|       4 |          52 |       00:00:09 |       70.00% |       84.38% |       0.5064 |       0.5000 |          0.0003 |
|       5 |          65 |       00:00:10 |       90.00% |       81.25% |       0.3023 |       0.3732 |          0.0003 |
|       6 |          78 |       00:00:12 |      100.00% |       87.50% |       0.0815 |       0.2651 |          0.0003 |
|       7 |          91 |       00:00:14 |      100.00% |       90.62% |       0.0644 |       0.2409 |          0.0003 |
|       8 |         100 |       00:00:15 |       80.00% |              |       0.2182 |              |          0.0003 |
|       8 |         104 |       00:00:16 |      100.00% |       96.88% |       0.1349 |       0.1818 |          0.0003 |
|       9 |         117 |       00:00:18 |       90.00% |       93.75% |       0.1663 |       0.1920 |          0.0003 |
|      10 |         130 |       00:00:19 |       90.00% |       90.62% |       0.1899 |       0.2258 |          0.0003 |
|      11 |         143 |       00:00:21 |      100.00% |       90.62% |       0.0962 |       0.1869 |          0.0003 |
|      12 |         150 |       00:00:22 |      100.00% |              |       0.0075 |              |          0.0003 |
|      12 |         156 |       00:00:23 |       90.00% |       93.75% |       0.2934 |       0.2201 |          0.0003 |
|      13 |         169 |       00:00:25 |      100.00% |       90.62% |       0.0958 |       0.3039 |          0.0003 |
|      14 |         182 |       00:00:27 |      100.00% |       93.75% |       0.0089 |       0.1430 |          0.0003 |
|      15 |         195 |       00:00:28 |      100.00% |       93.75% |       0.0068 |       0.2061 |          0.0003 |
|======================================================================================================================|
Training finished: Max epochs completed.

ネットワークの最後の層を検査します。3 つのクラスを含む分類出力層を確認します。

trainedSN.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'new_classoutput'
         Classes: [ARR    CHF    NSR]
    ClassWeights: 'none'
      OutputSize: 3

   Hyperparameters
    LossFunction: 'crossentropyex'

SqueezeNet 精度の評価

検証データを使用してネットワークを評価します。

[YPred,probs] = classify(trainedSN,augimgsValidation);
accuracy = mean(YPred==imgsValidation.Labels);
disp("SqueezeNet Accuracy: "+num2str(100*accuracy)+"%")
SqueezeNet Accuracy: 93.75%

まとめ

この例では、転移学習と連続ウェーブレット解析を使用して、事前学習済みの CNN、GoogLeNet および SqueezeNet を利用することによって ECG 信号の 3 つのクラスを分類する方法を説明します。ECG 信号のウェーブレットベースの時間-周波数表現を使用してスカログラムを作成します。スカログラムの RGB イメージが生成されます。イメージは両方の深層 CNN を微調整するために使用されます。また、異なるネットワーク層の活性化についても調査しました。

この例は、事前学習済みの CNN モデルを使用して信号を分類するために使用できる 1 つのワークフローを示しています。その他のワークフローも使用できます。ウェーブレット解析と深層学習を使用した NVIDIA Jetson への信号分類器の展開およびRaspberry Pi におけるウェーブレットおよび深層学習を使用した信号分類器の展開では信号を分類するためのコードをハードウェアに展開する方法を示しています。GoogLeNet および SqueezeNet は ImageNet データベースのサブセットで事前学習済みのモデルです [10]。これは、ImageNet Large-Scale Visual Recognition Challenge (ILSVRC) で使用されます [8]。ImageNet コレクションには、魚、鳥、機器、および菌類など、実世界のオブジェクトのイメージが含まれます。スカログラムは実世界のオブジェクトのクラスの範囲外にあります。GoogLeNet および SqueezeNet アーキテクチャに適合させるために、スカログラムはデータ削減も行いました。事前学習済みの CNN を微調整してスカログラムの異なるクラスを区別する代わりに、元のスカログラムの次元でゼロから CNN を学習する作業はオプションです。

参考文献

  1. Baim, D. S., W. S. Colucci, E. S. Monrad, H. S. Smith, R. F. Wright, A. Lanoue, D. F. Gauthier, B. J. Ransil, W. Grossman, and E. Braunwald."Survival of patients with severe congestive heart failure treated with oral milrinone."Journal of the American College of Cardiology.Vol. 7, Number 3, 1986, pp. 661–670.

  2. Engin, M. "ECG beat classification using neuro-fuzzy network." Pattern Recognition Letters. Vol. 25, Number 15, 2004, pp.1715–1722.

  3. Goldberger A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, and H. E. Stanley. "PhysioBank, PhysioToolkit,and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals." Circulation. Vol. 101, Number 23: e215–e220. [Circulation Electronic Pages; http://circ.ahajournals.org/content/101/23/e215.full]; 2000 (June 13). doi: 10.1161/01.CIR.101.23.e215.

  4. Leonarduzzi, R. F., G. Schlotthauer, and M. E. Torres."Wavelet leader based multifractal analysis of heart rate variability during myocardial ischaemia."In Engineering in Medicine and Biology Society (EMBC), Annual International Conference of the IEEE, 110–113.Buenos Aires, Argentina: IEEE, 2010.

  5. Li, T., and M. Zhou."ECG classification using wavelet packet entropy and random forests."Entropy.Vol. 18, Number 8, 2016, p.285.

  6. Maharaj, E. A., and A. M. Alonso."Discriminant analysis of multivariate time series: Application to diagnosis based on ECG signals." Computational Statistics and Data Analysis.Vol. 70, 2014, pp. 67–87.

  7. Moody, G. B., and R. G. Mark."The impact of the MIT-BIH Arrhythmia Database." IEEE Engineering in Medicine and Biology Magazine. Vol. 20. Number 3, May-June 2001, pp. 45–50. (PMID: 11446209)

  8. Russakovsky, O., J. Deng, and H. Su et al. "ImageNet Large Scale Visual Recognition Challenge." International Journal of Computer Vision.Vol. 115, Number 3, 2015, pp. 211–252.

  9. Zhao, Q., and L. Zhang."ECG feature extraction and classification using wavelet transform and support vector machines."In IEEE International Conference on Neural Networks and Brain, 1089–1092.Beijing, China: IEEE, 2005.

  10. ImageNet. http://www.image-net.org

サポート関数

helperCreateECGDataDirectories は、親ディレクトリ内にデータ ディレクトリを作成してから、データ ディレクトリ内に 3 つのサブディレクトリを作成します。サブディレクトリには ECGData で見つかった ECG 信号の各クラスに由来した名前が付けられます。

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

rootFolder = parentFolder;
localFolder = dataFolder;
mkdir(fullfile(rootFolder,localFolder))

folderLabels = unique(ECGData.Labels);
for i = 1:numel(folderLabels)
    mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i))));
end
end

helperPlotReps は、ECGData で見つかった ECG 信号の各クラスの代表的なものから最初の 1000 サンプルをプロットします。

function helperPlotReps(ECGData)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

folderLabels = unique(ECGData.Labels);

for k=1:3
    ecgType = folderLabels{k};
    ind = find(ismember(ECGData.Labels,ecgType));
    subplot(3,1,k)
    plot(ECGData.Data(ind(1),1:1000));
    grid on
    title(ecgType)
end
end

helperCreateRGBfromTFcwtfilterbankを使用して、ECG 信号の連続ウェーブレット変換を取得し、ウェーブレット係数からスカログラムを生成します。補助関数はスカログラムのサイズを変更し、jpeg イメージとしてディスクに書き込みます。

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder)
% This function is only intended to support the ECGAndDeepLearningExample.
% It may change or be removed in a future release.

imageRoot = fullfile(parentFolder,childFolder);

data = ECGData.Data;
labels = ECGData.Labels;

[~,signalLength] = size(data);

fb = cwtfilterbank(SignalLength=signalLength,VoicesPerOctave=12);
r = size(data,1);

for ii = 1:r
    cfs = abs(fb.wt(data(ii,:)));
    im = ind2rgb(round(rescale(cfs,0,255)),jet(128));
    
    imgLoc = fullfile(imageRoot,char(labels(ii)));
    imFileName = char(labels(ii))+"_"+num2str(ii)+".jpg";
    imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName));
end
end

参考

| (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | | (Deep Learning Toolbox)

関連するトピック