Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

ニューラル ODE ネットワークの学習

この例では、拡張ニューラル常微分方程式 (ODE) ネットワークに学習させる方法を示します。

ニューラル ODE [1] は、ODE の解を返す深層学習演算です。特に、入力が与えられると、ニューラル ODE 演算は、時間ホライズンを (t0,t1) とし、初期条件を y(t0)=y0 とする ODE y=f(t,y,θ ) の数値解を出力します。ここで、ty は ODE 関数の入力であり、θ は学習可能なパラメーターのセットです。通常、初期条件 y0 には、ネットワーク入力が使用されるか、この例のように別の深層学習演算の出力が使用されます。

"拡張" ニューラル ODE [2] 演算は、入力データを追加のチャネルで拡張し、ニューラル ODE 演算の後に拡張を破棄することにより、標準のニューラル ODE を改善します。経験的に、拡張ニューラル ODE は、ニューラル ODE よりも安定しており、汎化が良好で、計算コストが低くなります。

この例では、拡張ニューラル ODE 演算を使用して単純な畳み込みニューラル ネットワークに学習させます。

2021-04-20_18-34-50.png

ODE 関数は、それ自体がニューラル ネットワークです。この例では、モデルは畳み込み層と tanh 層をもつ次のようなネットワークを使用します。

2021-04-20_18-35-22.png

この例では、ニューラル ネットワークに学習させ、拡張ニューラル ODE 演算を使用して数字のイメージを分類する方法を示します。

学習データの読み込み

関数 digitTrain4DArrayData を使用して学習させるイメージとラベルを読み込みます。

load DigitsDataTrain

学習データのクラスの数を表示します。

TTrain = labelsTrain;
classNames = categories(TTrain);
numClasses = numel(classNames)
numClasses = 10

学習データからの一部のイメージを表示します。

numObservations = size(XTrain,4);
idx = randperm(numObservations,64);
I = imtile(XTrain(:,:,:,idx));
figure
imshow(I)

ニューラル ネットワーク アーキテクチャの定義

イメージを分類する次のネットワークを定義します。

  • ストライドが 2 の 3 行 3 列のフィルターを 8 個もつ convolution-ReLU ブロック

  • 出力が入力の 2 倍のチャネル数をもつように、ゼロの配列を入力に連結する拡張層

  • 3 行 3 列のフィルターを 16 個もつ convolution-tanh ブロックを含む ODE 関数を使用したニューラル ODE 演算

  • 出力が入力の半分のチャネル数をもつように、チャネル次元の末尾の要素をトリミングする拡張破棄層

  • 分類出力用の、サイズ 10 (クラスの数) の全結合演算とソフトマックス演算

2021-04-20_18-34-50.png

ニューラル ODE 層は、指定された ODE 関数の解を出力します。この例では、畳み込み層と tanh 層を含むニューラル ネットワークを ODE 関数として指定します。

2021-04-20_18-35-22.png

ニューラル ODE ネットワークの入力サイズと出力サイズは一致していなければなりません。ODE 層のニューラル ネットワークの入力サイズを計算する際、次の点に注意します。

  • イメージ分類ネットワークの入力データは、28×28×1 のイメージの配列。

  • イメージは、2 分の 1 にダウンサンプリングする 8 つのフィルターをもつ畳み込み層を通過する。

  • 畳み込み層の出力は、チャネル次元の数を 2 倍にする拡張層を通過する。

これは、ニューラル ODE 層への入力が 14×14×16 の配列であり、空間次元のサイズが 14、チャネル次元のサイズが 16 であることを意味します。畳み込み層が 28×28 のイメージを 2 分の 1 にダウンサンプリングするため、空間サイズは 14 になります。畳み込み層が 8 チャネル (畳み込み層のフィルター数) を出力し、拡張層がチャネル数を 2 倍にするため、チャネル サイズは 16 になります。

ニューラル ODE 層に使用するニューラル ネットワークを作成します。ネットワークには入力層がないため、ネットワークは初期化しません。

numFilters = 8;

layersODE = [
    convolution2dLayer(3,2*numFilters,Padding="same")
    tanhLayer];

netODE = dlnetwork(layersODE,Initialize=false);

イメージ分類ネットワークを作成します。拡張層と拡張破棄層の場合は、それぞれこの例のチャネル拡張関数セクションおよびチャネル拡張破棄関数セクションにリストされている関数 channelAugmentation および関数 discardChannelAugmentation とともに関数層を使用します。これらの関数にアクセスするには、例をライブ スクリプトとして開きます。

inputSize = size(XTrain,1:3);
filterSize = 3;
tspan = [0 0.1];

layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(filterSize,numFilters)
    functionLayer(@channelAugmentation,Acceleratable=true,Formattable=true)
    neuralODELayer(netODE,tspan,GradientMode="adjoint")
    functionLayer(@discardChannelAugmentation,Acceleratable=true,Formattable=true)
    fullyConnectedLayer(numClasses)
    softmaxLayer];

学習オプションの指定

学習オプションを指定します。オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、Experiment Managerアプリを使用できます。

  • Adam ソルバーを使用して学習させます。

  • 学習率を 0.01 にして学習を行います。

  • すべてのエポックでデータをシャッフルします。

  • 学習の進行状況をプロットで監視し、精度を表示します。

  • 詳細出力を無効にします。

options = trainingOptions("adam", ...
    InitialLearnRate=0.01, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

関数trainnetを使用してニューラル ネットワークに学習させます。分類には、クロスエントロピー損失を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainnet は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

モデルのテスト

真のラベルをもつホールドアウトされたテスト セットの予測を比較して、モデルの分類精度をテストします。

テスト データを読み込みます。

load DigitsDataTest
TTest = labelsTest;

学習後に新しいデータについて予測を行う際、ラベルは必要ありません。テスト データの予測子のみを含む minibatchqueue オブジェクトを、次のようにして作成します。

  • ミニバッチ キューの出力数を 1 に設定します。

  • 例のミニバッチ予測前処理関数セクションにリストされている関数 preprocessPredictors を使用して予測子を前処理します。

  • データストアの単一の出力では、ミニバッチの形式 "SSCB" (spatial、spatial、channel、batch) を指定します。

dsTest = arrayDatastore(XTest,IterationDimension=4);

mbqTest = minibatchqueue(dsTest,1, ...
    MiniBatchFormat="SSCB", ...
    MiniBatchFcn=@preprocessPredictors);

ミニバッチをループ処理し、例のモデル予測関数セクションにリストされている関数 modelPredictions を使用してシーケンスを分類します。

YTest = modelPredictions(net,mbqTest,classNames);

混同行列で予測を可視化します。

figure
confusionchart(TTest,YTest)

分類精度を計算します。

accuracy = mean(TTest==YTest)
accuracy = 0.9262

チャネル拡張関数

関数 channelAugmentation は、出力のチャネル数が 2 倍になるように、入力データ X のチャネル次元をパディングします。

function Z = channelAugmentation(X)

idxC = finddim(X,"C");
szC = size(X,idxC);
Z = paddata(X,2*szC,Dimension=idxC);

end

チャネル拡張破棄関数

関数 discardChannelAugmentation は、出力のチャネル数が半分になるように、入力データ X のチャネル次元をトリミングします。

function Z = discardChannelAugmentation(X)

idxC = finddim(X,"C");
szC = size(X,idxC);
Z = trimdata(X,floor(szC/2),Dimension=idxC);

end

モデル予測関数

関数 modelPredictions は、ニューラル ネットワーク、入力データのミニバッチ キュー mbq、およびクラス名を入力として受け取り、すべてのデータを反復処理してモデル予測を計算します。この関数は、関数 onehotdecode を使用して、スコアが最も高い予測されたクラスを見つけます。

function predictions = modelPredictions(net,mbq,classNames)

predictions = [];

while hasdata(mbq)
    X = next(mbq);
    Y = predict(net,X);
    Y = onehotdecode(Y,classNames,1)';
    predictions = [predictions; Y];
end

end

予測子前処理関数

関数 preprocessPredictors は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列にデータを連結します。グレースケール入力の場合、4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、大きさが 1 のチャネル次元として使用されます。

function X = preprocessPredictors(dataX)

X = cat(4,dataX{:});

end

参考文献

  1. Chen, Ricky T. Q., Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. “Neural Ordinary Differential Equations.” Preprint, submitted June 19, 2018. https://arxiv.org/abs/1806.07366.

  2. Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. “Augmented Neural ODEs.” Preprint, submitted October 26, 2019. https://arxiv.org/abs/1904.01681.

参考

| | |

関連するトピック