ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

ベイズ最適化を使用した深層学習

この例では、深層学習にベイズ最適化を適用して、畳み込みニューラル ネットワークに最適なネットワーク ハイパーパラメーターと学習オプションを求める方法を説明します。

深層ニューラル ネットワークに学習させるには、ニューラル ネットワークのアーキテクチャおよび学習アルゴリズムのオプションを指定しなければなりません。これらのハイパーパラメーターの選択と調整は難しく時間がかかることがあります。ベイズ最適化は、分類モデルと回帰モデルのハイパーパラメーターの最適化に非常に適したアルゴリズムです。ベイズ最適化を使用して、微分不可能な関数、不連続な関数、および評価に時間のかかる関数を最適化できます。このアルゴリズムでは、目的関数のガウス過程モデルを内部に保持し、目的関数の評価を使用してこのモデルに学習をさせます。

この例では、以下の方法を説明します。

  • ネットワーク学習用の CIFAR-10 データセットをダウンロードして準備する。このデータセットは、イメージ分類モデルのテストに最も広く使用されているデータセットの 1 つです。

  • ベイズ最適化を使用して最適化する変数を指定する。これらの変数は学習アルゴリズムのオプションであり、ネットワーク アーキテクチャ自体のパラメーターでもあります。

  • 入力として最適化変数の値を取り、ネットワーク アーキテクチャと学習オプションを指定し、ネットワークの学習と検証を行い、学習済みネットワークをディスクに保存する目的関数を定義する。目的関数の定義は、このスクリプトの終わりで行います。

  • 検証セットの分類誤差を最小化することにより、ベイズ最適化を実行する。

  • ディスクから最適なネットワークを読み込み、テスト セットに対して評価する。

データの準備

CIFAR-10 データセット [1] をダウンロードします。このデータセットには 60,000 個のイメージが格納されており、各イメージのサイズは 32 x 32 で 3 つのカラー チャネル (RGB) があります。データセット全体のサイズは 175 MB です。インターネット接続の速度によっては、ダウンロード プロセスに時間がかかることがあります。

datadir = tempdir;
downloadCIFARData(datadir);

CIFAR-10 データセットを学習用のイメージとラベル、およびテスト用のイメージとラベルとして読み込みます。ネットワークの検証を有効にするには、検証用に 5000 個のテスト イメージを使用します。

[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);

idx = randperm(numel(YTest),5000);
XValidation = XTest(:,:,:,idx);
XTest(:,:,:,idx) = [];
YValidation = YTest(idx);
YTest(idx) = [];

学習イメージのサンプルを表示します。

figure;
idx = randperm(numel(YTrain),20);
for i = 1:numel(idx)
    subplot(4,5,i);
    imshow(XTrain(:,:,:,idx(i)));
end

最適化する変数の選択

ベイズ最適化を使用して最適化する変数を選択し、検索範囲を指定します。さらに、変数が整数であるかどうか、検索区間を対数空間にするかどうかを指定します。次の変数を最適化します。

  • ネットワーク セクション深さ。このパラメーターはネットワークの深さを制御します。ネットワークには 3 つのセクションがあり、各セクションには SectionDepth 個の同一の畳み込み層があります。そのため、畳み込み層の総数は 3*SectionDepth になります。後でスクリプトで定義する目的関数が取る各層の畳み込みフィルターの数は 1/sqrt(SectionDepth) に比例します。その結果、パラメーター数と各反復に必要な計算量は、セクション深さが異なる場合でもほぼ同じになります。

  • 初期学習率。最適な学習率は、学習させるネットワークにも依りますが、データにも左右されます。

  • 確率的勾配降下モーメンタム。モーメンタムは、前の反復での更新に比例する寄与を現在の更新に含めることによって、パラメーター更新に慣性を追加します。これにより、パラメーター更新がより円滑に行われるようになり、確率的勾配降下に固有のノイズが低減されます。

  • L2 正則化強度。正則化を使用して過適合を防止します。正則化強度の空間を検索して、適切な値を求めます。データ拡張やバッチ正規化も、ネットワークの正則化に役立ちます。

optimVars = [
    optimizableVariable('SectionDepth',[1 3],'Type','integer')
    optimizableVariable('InitialLearnRate',[1e-2 1],'Transform','log')
    optimizableVariable('Momentum',[0.8 0.98])
    optimizableVariable('L2Regularization',[1e-10 1e-2],'Transform','log')];

ベイズ最適化の実行

学習データと検証データを入力として使用して、ベイズ オプティマイザーの目的関数を作成します。目的関数は畳み込みニューラル ネットワークに学習をさせ、検証セットに対してその分類誤差を返します。この関数の定義は、このスクリプトの終わりで行います。bayesopt は検証セットの誤差率を使用して最適なモデルを選択するため、最終的なネットワークが検証セットに過適合する可能性があります。そのため、最終的に選択されたモデルを独立したテスト セットでテストして、汎化誤差を推定します。

ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation);

検証セットの分類誤差を最小化することにより、ベイズ最適化を実行する。合計最適化時間を秒単位で指定します。ベイズ最適化の能力を最大限に活用するには、目的関数の評価を 30 回以上行う必要があります。複数の GPU でネットワークの並列学習を行うには、'UseParallel' の値を true に設定します。GPU が 1 つで、'UseParallel' の値を true に設定した場合、すべてのワーカーがその GPU を共有するため、学習を高速化できず、GPU のメモリ不足が発生する可能性が高まります。

ネットワークの学習が完了するたびに、bayesopt によってコマンド ウィンドウに結果が表示されます。その後、関数 bayesoptBayesObject.UserDataTrace にファイル名を返します。目的関数は学習済みネットワークをディスクに保存し、bayesopt にファイル名を返します。

BayesObject = bayesopt(ObjFcn,optimVars, ...
    'MaxTime',14*60*60, ...
    'IsObjectiveDeterministic',false, ...
    'UseParallel',false);
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | SectionDepth | InitialLearn-|     Momentum | L2Regulariza-|
|      | result |             | runtime     | (observed)  | (estim.)    |              | Rate         |              | tion         |
|===================================================================================================================================|
|    1 | Best   |        0.19 |        2201 |        0.19 |        0.19 |            3 |     0.012114 |       0.8354 |    0.0010624 |
|    2 | Accept |      0.3224 |      1734.1 |        0.19 |     0.19636 |            1 |     0.066481 |      0.88231 |    0.0026626 |
|    3 | Accept |      0.2076 |      1688.7 |        0.19 |     0.19374 |            2 |     0.022346 |      0.91149 |    8.242e-10 |
|    4 | Accept |      0.1908 |      2167.2 |        0.19 |      0.1904 |            3 |      0.97586 |      0.83613 |   4.5143e-08 |
|    5 | Accept |      0.1972 |      2157.4 |        0.19 |     0.19274 |            3 |      0.21193 |      0.97995 |   1.4691e-05 |
|    6 | Accept |      0.2594 |      2152.8 |        0.19 |        0.19 |            3 |      0.98723 |      0.97931 |   2.4847e-10 |
|    7 | Best   |      0.1882 |      2257.5 |      0.1882 |     0.18819 |            3 |       0.1722 |       0.8019 |   4.2149e-06 |
|    8 | Accept |      0.8116 |      1989.7 |      0.1882 |     0.18818 |            3 |      0.42085 |      0.95355 |    0.0092026 |
|    9 | Accept |      0.1986 |        1836 |      0.1882 |     0.18821 |            2 |     0.030291 |      0.94711 |   2.5062e-05 |
|   10 | Accept |      0.2146 |      1909.4 |      0.1882 |     0.18816 |            2 |     0.013379 |       0.8785 |   7.6354e-09 |
|   11 | Accept |      0.2194 |        1562 |      0.1882 |     0.18815 |            1 |      0.14682 |      0.86272 |   8.6242e-09 |
|   12 | Accept |      0.2246 |      1591.2 |      0.1882 |     0.18813 |            1 |      0.70438 |      0.82809 |   1.0102e-06 |
|   13 | Accept |      0.2648 |      1621.8 |      0.1882 |     0.18824 |            1 |     0.010109 |      0.89989 |   1.0481e-10 |
|   14 | Accept |      0.2222 |        1562 |      0.1882 |     0.18812 |            1 |      0.11058 |      0.97432 |   2.4101e-07 |
|   15 | Accept |      0.2364 |      1625.7 |      0.1882 |     0.18813 |            1 |     0.079381 |       0.8292 |   2.6722e-05 |
|   16 | Accept |        0.26 |      1706.2 |      0.1882 |     0.18815 |            1 |     0.010041 |      0.96229 |   1.1066e-05 |
|   17 | Accept |      0.1986 |      2188.3 |      0.1882 |     0.18635 |            3 |      0.35949 |      0.97824 |    3.153e-07 |
|   18 | Accept |      0.1938 |      2169.6 |      0.1882 |     0.18817 |            3 |     0.024365 |      0.88464 |   0.00024507 |
|   19 | Accept |      0.3588 |      1713.7 |      0.1882 |     0.18216 |            1 |     0.010177 |      0.89427 |    0.0090342 |
|   20 | Accept |      0.2224 |      1721.4 |      0.1882 |     0.18193 |            1 |      0.09804 |      0.97947 |   1.0727e-10 |
|===================================================================================================================================|
| Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   | SectionDepth | InitialLearn-|     Momentum | L2Regulariza-|
|      | result |             | runtime     | (observed)  | (estim.)    |              | Rate         |              | tion         |
|===================================================================================================================================|
|   21 | Accept |      0.1904 |      2184.7 |      0.1882 |     0.18498 |            3 |     0.017697 |      0.95057 |   0.00022247 |
|   22 | Accept |      0.1928 |      2184.4 |      0.1882 |     0.18527 |            3 |      0.06813 |       0.9027 |   1.3521e-09 |
|   23 | Accept |      0.1934 |      2183.6 |      0.1882 |      0.1882 |            3 |     0.018269 |      0.90432 |    0.0003573 |
|   24 | Accept |       0.303 |      1707.9 |      0.1882 |     0.18809 |            1 |     0.010157 |      0.88226 |   0.00088737 |
|   25 | Accept |       0.194 |      2189.1 |      0.1882 |     0.18808 |            3 |     0.019354 |      0.94156 |   9.6197e-07 |
|   26 | Accept |      0.2192 |      1752.2 |      0.1882 |     0.18809 |            1 |      0.99324 |      0.91165 |   1.1521e-08 |
|   27 | Accept |      0.1918 |        2185 |      0.1882 |     0.18813 |            3 |      0.05292 |       0.8689 |   1.2449e-05 |

__________________________________________________________
Optimization completed.
MaxTime of 50400 seconds reached.
Total function evaluations: 27
Total elapsed time: 51962.3666 seconds.
Total objective function evaluation time: 51942.8833

Best observed feasible point:
    SectionDepth    InitialLearnRate    Momentum    L2Regularization
    ____________    ________________    ________    ________________

         3               0.1722          0.8019        4.2149e-06   

Observed objective function value = 0.1882
Estimated objective function value = 0.18813
Function evaluation time = 2257.4627

Best estimated feasible point (according to models):
    SectionDepth    InitialLearnRate    Momentum    L2Regularization
    ____________    ________________    ________    ________________

         3               0.1722          0.8019        4.2149e-06   

Estimated objective function value = 0.18813
Estimated function evaluation time = 2166.2402

最終的なネットワークの評価

最適化で見つかった最適なネットワークと、その検証精度を読み込みます。

bestIdx = BayesObject.IndexOfMinimumTrace(end);
fileName = BayesObject.UserDataTrace{bestIdx};
savedStruct = load(fileName);
valError = savedStruct.valError
valError = 0.1882

テスト セットのラベルを予測して、テスト誤差を計算します。テスト セットの各イメージの分類を特定の成功確率を持つ独立事象として扱います。これは、誤って分類されるイメージの数が二項分布に従うことを意味します。これを使用して、標準誤差 (testErrorSE) と、汎化誤差率の近似した 95% 信頼区間 (testError95CI) を計算します。この方法は通常、"Wald 法" と呼ばれます。bayesopt は、ネットワークにテスト セットを当てることなく、検証セットを使用して最適なネットワークを決定します。これにより、テスト誤差が検証誤差より高くなる可能性があります。

[YPredicted,probs] = classify(savedStruct.trainedNet,XTest);
testError = 1 - mean(YPredicted == YTest)
testError = 0.1864
NTest = numel(YTest);
testErrorSE = sqrt(testError*(1-testError)/NTest);
testError95CI = [testError - 1.96*testErrorSE, testError + 1.96*testErrorSE]
testError95CI = 1×2

    0.1756    0.1972

テスト データの混同行列をプロットします。列と行の要約を使用して、各クラスの適合率と再現率を表示します。

figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
cm = confusionchart(YTest,YPredicted);
cm.Title = 'Confusion Matrix for Test Data';
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';

いくつかのテスト イメージを、予測されたクラスとそのクラスである確率と共に表示します。

figure
idx = randperm(numel(YTest),9);
for i = 1:numel(idx)
    subplot(3,3,i)
    imshow(XTest(:,:,:,idx(i)));
    prob = num2str(100*max(probs(idx(i),:)),3);
    predClass = char(YPredicted(idx(i)));
    label = [predClass,', ',prob,'%'];
    title(label)
end

最適化の目的関数

最適化の目的関数を定義します。この関数は、以下の手順を実行します。

  1. 入力として最適化変数の値を取ります。bayesopt は、各列の名前が変数名と等しい table に含まれる最適化変数の現在の値を使用して、目的関数を呼び出します。たとえば、ネットワーク セクション深さの現在の値は optVars.SectionDepth です。

  2. ネットワーク アーキテクチャと学習オプションを定義します。

  3. ネットワークの学習と検証を行います。

  4. 学習済みネットワーク、検証誤差、および学習オプションをディスクに保存します。

  5. 検証誤差と保存されたネットワークのファイル名を返します。

function ObjFcn = makeObjFcn(XTrain,YTrain,XValidation,YValidation)
ObjFcn = @valErrorFun;
    function [valError,cons,fileName] = valErrorFun(optVars)

畳み込みニューラル ネットワーク アーキテクチャを定義します。

  • 空間の出力サイズが入力サイズと常に同じになるように、畳み込み層にパディングを追加します。

  • 最大プーリング層を使用して係数 2 で空間次元をダウンサンプリングするたびに、フィルターの数を 2 倍ずつ増加させます。それにより、各畳み込み層で必要な計算量がほぼ同じになります。

  • フィルターの数を 1/sqrt(SectionDepth) に比例するように選択し、深さの異なるネットワークにほぼ同じ数のパラメーターがあり、反復ごとにほぼ同じ量の計算が必要になるようにします。ネットワーク パラメーターの数を増やし、ネットワーク全体の柔軟性を高めるには、numF を大きくします。さらに深いネットワークに学習させるには、変数 SectionDepth の範囲を変更します。

  • convBlock(filterSize,numFilters,numConvLayers) を使用して、numConvLayers 個の畳み込み層のブロックを作成します。各層には filterSizenumFilters 個のフィルターを指定し、それぞれの後にバッチ正規化層と ReLU 層を配置します。関数 convBlock の定義は、この例の終わりで行います。

        imageSize = [32 32 3];
        numClasses = numel(unique(YTrain));
        numF = round(16/sqrt(optVars.SectionDepth));
        layers = [
            imageInputLayer(imageSize)
            
            % The spatial input and output sizes of these convolutional
            % layers are 32-by-32, and the following max pooling layer
            % reduces this to 16-by-16.
            convBlock(3,numF,optVars.SectionDepth)
            maxPooling2dLayer(3,'Stride',2,'Padding','same')
            
            % The spatial input and output sizes of these convolutional
            % layers are 16-by-16, and the following max pooling layer
            % reduces this to 8-by-8.
            convBlock(3,2*numF,optVars.SectionDepth)
            maxPooling2dLayer(3,'Stride',2,'Padding','same')
            
            % The spatial input and output sizes of these convolutional
            % layers are 8-by-8. The global average pooling layer averages
            % over the 8-by-8 inputs, giving an output of size
            % 1-by-1-by-4*initialNumFilters. With a global average
            % pooling layer, the final classification output is only
            % sensitive to the total amount of each feature present in the
            % input image, but insensitive to the spatial positions of the
            % features.
            convBlock(3,4*numF,optVars.SectionDepth)
            averagePooling2dLayer(8)
            
            % Add the fully connected layer and the final softmax and
            % classification layers.
            fullyConnectedLayer(numClasses)
            softmaxLayer
            classificationLayer];

ネットワーク学習のオプションを指定します。初期学習率、SGD モーメンタム、および L2 正則化強度を最適化します。

検証データを指定し、trainNetwork によってエポックごとに 1 回ネットワークが検証されるように、'ValidationFrequency' の値を選択します。一定のエポック数の学習を行い、最後の方のエポックでは学習率を 10 分の 1 に下げます。これにより、パラメーター更新のノイズが減少し、ネットワーク パラメーターが安定して損失関数が最小値に近づくようになります。

        miniBatchSize = 256;
        validationFrequency = floor(numel(YTrain)/miniBatchSize);
        options = trainingOptions('sgdm', ...
            'InitialLearnRate',optVars.InitialLearnRate, ...
            'Momentum',optVars.Momentum, ...
            'MaxEpochs',60, ...
            'LearnRateSchedule','piecewise', ...
            'LearnRateDropPeriod',40, ...
            'LearnRateDropFactor',0.1, ...
            'MiniBatchSize',miniBatchSize, ...
            'L2Regularization',optVars.L2Regularization, ...
            'Shuffle','every-epoch', ...
            'Verbose',false, ...
            'Plots','training-progress', ...
            'ValidationData',{XValidation,YValidation}, ...
            'ValidationFrequency',validationFrequency);

データ拡張を使用して、縦軸に沿って学習イメージをランダムに反転させ、水平方向および垂直方向に最大 4 ピクセルだけランダムに平行移動させます。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。

        pixelRange = [-4 4];
        imageAugmenter = imageDataAugmenter( ...
            'RandXReflection',true, ...
            'RandXTranslation',pixelRange, ...
            'RandYTranslation',pixelRange);
        datasource = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',imageAugmenter);

ネットワークに学習させ、学習中にその進行状況をプロットします。学習が終了したら、すべての学習プロットを閉じます。

        trainedNet = trainNetwork(datasource,layers,options);
        close(findall(groot,'Tag','NNET_CNN_TRAININGPLOT_FIGURE'))

検証セットに対して学習済みネットワークを評価して、予測イメージ ラベルを計算し、検証データにおける誤差率を計算します。

        YPredicted = classify(trainedNet,XValidation);
        valError = 1 - mean(YPredicted == YValidation);

検証誤差を含むファイル名を作成し、ネットワーク、検証誤差、および学習オプションをディスクに保存します。目的関数は出力引数として fileName を返し、bayesoptBayesObject.UserDataTrace にすべてのファイル名を返します。別途要求された出力引数 cons は、変数間の制約を指定します。変数の制約はありません。

        fileName = num2str(valError) + ".mat";
        save(fileName,'trainedNet','valError','options')
        cons = [];
        
    end
end

関数 convBlock は、numConvLayers 畳み込み層のブロックを作成します。各層には filterSizenumFilters 個のフィルターを指定し、それぞれの後にバッチ正規化層と ReLU 層を配置します。

function layers = convBlock(filterSize,numFilters,numConvLayers)
layers = [
    convolution2dLayer(filterSize,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer];
layers = repmat(layers,numConvLayers,1);
end

参照

[1] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

参考

| |

関連するトピック