ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

parfeval を使用した複数の深層学習ネットワークの学習

この例では、ネットワーク アーキテクチャの深さについてのパラメーター スイープに parfeval を使用する方法を説明します。多くの場合、深層学習における学習には数時間または数日かかるため、適切なアーキテクチャを見つけるのが難しい場合があります。並列計算を使用して、適切なモデルの探索を高速化および自動化することができます。複数の GPU があるマシンにアクセスできる場合は、ローカルの parpool を使用してデータセットのローカル コピーに対してこのスクリプトを実行できます。より多くのリソースを使用する必要がある場合は、深層学習における学習をクラウドにスケールアップできます。この例では、parfeval を使用して、クラウドのクラスターにおけるネットワーク アーキテクチャの深さについてのパラメーター スイープを実行する方法を説明します。parfeval を使用すると、MATLAB をブロックすることなくバックグラウンドで学習を実行でき、結果が十分適切な場合に早期の停止を選択できます。このスクリプトを変更して、他のパラメーターについてのパラメーター スイープを実行できます。また、この例では、DataQueue を使用して計算中にワーカーからフィードバックを取得する方法も説明します。

この例を実行する前に、クラスターを構成し、データをクラウドにアップロードする必要があります。クラウドでの作業を始めるには、Cloud Center を設定して、それを Amazon Web Services (AWS) アカウントにリンクし、クラスターを作成します。手順については、Getting Started with Cloud Center を参照してください。その後、データを Amazon S3 バケットにアップロードして、MATLAB から直接使用します。例については、クラウドへの深層学習データのアップロードを参照してください。

クラウドからのデータセットの読み込み

imageDatastore を使用して、学習データセットとテスト データセットをクラウドから読み込みます。学習データセットは学習用と検証用に分割し、テスト データセットはパラメーター スイープから得られる最適なネットワークをテストするために保持します。この例では、Amazon S3 に既に格納されている CIFAR-10 データのコピーを使用する方法を説明します。ワーカーがクラウドのデータストアに確実にアクセスできるように、AWS 資格情報の環境変数が正しく設定されていることを確認してください。手順については、クラウドへの深層学習データのアップロードを参照してください。

imds = imageDatastore('s3://cifar10cloud/cifar10/train', ...
 'IncludeSubfolders',true, ...
 'LabelSource','foldernames');

imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ...
 'IncludeSubfolders',true, ...
 'LabelSource','foldernames');

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

拡張イメージ データを使用してネットワークに学習させるには、augmentedImageDatastore オブジェクトを作成します。ランダムな平行移動と水平方向の反転を使用します。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');

複数のネットワークの同時学習

学習オプションを定義します。ミニバッチ サイズを設定し、ミニバッチ サイズに応じて初期学習率を線形にスケーリングします。trainNetwork でエポックごとに 1 回ネットワークが検証されるように、検証頻度を設定します。

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size
    'Verbose',false, ... % Do not send command line output.
    'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
    'L2Regularization',1e-10, ...
    'MaxEpochs',40, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationPatience',Inf, ...
    'ValidationFrequency', validationFrequency, ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropFactor',0.1, ...
    'LearnRateDropPeriod',35);

パラメーター スイープを実行するネットワーク アーキテクチャの深さを選択します。parfeval を使用して複数のネットワークに同時に学習させる並列パラメーター スイープを実行します。ループを使用して、スイープ内で異なるネットワーク アーキテクチャについて繰り返します。ネットワークの深さを制御する入力引数を取り、CIFAR-10 のアーキテクチャを作成する補助関数 createNetworkArchitecture を作成します。parfeval を使用して、trainNetwork によって実行される計算の負荷をクラスター内のワーカーに移します。parfeval は、計算が完了すると、学習済みネットワークを保持する future 変数を返します。ワーカーから学習の進行状況プロットの形式でフィードバックを取得するには、以下の「学習時のフィードバック データの送信」を参照してください。

netDepths = 1:4;
for idx = 1:numel(netDepths)
    networksFuture(idx) = parfeval(@trainNetwork, 1, ...
        augmentedImdsTrain, createNetworkArchitecture(netDepths(idx)), options);
end

parfeval は MATLAB をブロックしないため、コマンドの実行を継続できます。fetchNext を使用すると、次のネットワークの学習が完了するのを待って、残りのネットワークの学習が行われている間にそのネットワークを使用できます。たとえば、その精度を確認して、目的のしきい値を超えている場合、future 配列に対して cancel を呼び出して、残りの future をキャンセルできます。または、関数 afterEach を使用して、ネットワークの学習が終了した直後に学習済みのネットワークに対してある関数を自動的に呼び出すこともできます。次のコードは、各ネットワークの学習が終了した直後、残りのネットワークの学習中に精度を計算します。ネットワークの精度を求めるには、検証データセットを使用して分類し、予測ラベルを検証データセットのラベルと比較します。

accuraciesFuture = afterEach(networksFuture, @(network) mean(classify(network,imdsValidation) == imdsValidation.Labels), 1);

学習済みネットワークとその精度を取得するには、networksFutureaccuraciesFuture に対して fetchOutputs を使用します。fetchOutputs は、future の計算が終了するまで待機します。cell 配列出力を取得するには、名前と値のペアの引数 UniformOutputfalse に設定します。

trainedNetworks = fetchOutputs(networksFuture, 'UniformOutput', false)
accuracies = fetchOutputs(accuraciesFuture)
trainedNetworks =

  4×1 cell array

    {1×1 SeriesNetwork}
    {1×1 SeriesNetwork}
    {1×1 SeriesNetwork}
    {1×1 SeriesNetwork}


accuracies =

    0.7540
    0.8000
    0.8054
    0.8060

精度に基づいて最適なネットワークを選択します。テスト データセットに対してそのパフォーマンスをテストします。

[~, I] = max(accuracies);
bestNetwork = trainedNetworks{I(1)};
YPredicted = classify(bestNetwork,imdsTest);
accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy =

    0.8033

テスト データの混同行列を計算し、ヒートマップとして表示します。

figure
[cmat,classNames] = confusionmat(imdsTest.Labels,YPredicted);
h = heatmap(classNames,classNames,cmat);
xlabel('Predicted Class');
ylabel('True Class');
title('Confusion Matrix');

学習時のフィードバック データの送信

各ワーカーの学習の進行状況を示すプロットを準備して初期化します。animatedLine を使用すると、変化するデータを簡単に表示できます。

f = figure;
for i=1:4
    subplot(2,2,i)
    xlabel('Iteration');
    ylabel('Training accuracy');
    lines(i) = animatedline;
end

DataQueue を使用して学習の進行状況データをワーカーからクライアントに送信し、プロットします。afterEach を使用して、ワーカーから学習の進行状況フィードバックが送信されるたびにプロットを更新します。パラメーター opts には、ワーカー、学習の反復、および学習精度に関する情報が含まれます。

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines, opts{:}));

パラメーター スイープを実行するネットワーク アーキテクチャの深さを選択し、parfeval を使用して並列パラメーター スイープを実行します。このスクリプトでワーカーが補助関数にアクセスできるようにするために、スクリプトを添付ファイルとして現在のプールに追加します。学習オプションで出力関数を定義して、学習の進行状況をワーカーからクライアントに送信します。学習オプションは、ワーカーのインデックスに依存するため、for ループ内に含める必要があります。

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)

    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);

    options = trainingOptions('sgdm', ...
        'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        'Verbose',false, ... % Do not send command line output.
        'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
        'L2Regularization',1e-10, ...
        'MaxEpochs',40, ...
        'Shuffle','every-epoch', ...
        'ValidationData',imdsValidation, ...
        'ValidationPatience',Inf, ...
        'ValidationFrequency', validationFrequency, ...
        'LearnRateSchedule','piecewise', ...
        'LearnRateDropFactor',0.1, ...
        'LearnRateDropPeriod',35);

    networksFuture(idx) = parfeval(@trainNetwork, 1, ...
        augmentedImdsTrain, createNetworkArchitecture(netDepths(idx)), options);
end

afterEach を使用して、自動的に分類を行い、各ネットワークの学習終了後にその精度を求めます。

accuraciesFuture = afterEach(networksFuture, @(network) mean(classify(network,imdsValidation) == imdsValidation.Labels), 1);

future 変数に対して fetchOutputs を呼び出して、学習済みネットワークを取得し、その精度を求めます。

trainedNetworks = fetchOutputs(networksFuture, 'UniformOutput', false)
accuracies = fetchOutputs(accuraciesFuture)
trainedNetworks =

  4×1 cell array

    {1×1 SeriesNetwork}
    {1×1 SeriesNetwork}
    {1×1 SeriesNetwork}
    {1×1 SeriesNetwork}


accuracies =

    0.7540
    0.8000
    0.8054
    0.8060

補助関数

CIFAR-10 用のネットワーク アーキテクチャを定義します。コードを簡略化するために、入力を畳み込む複数の畳み込み層を含む、畳み込みブロックを使用します。プーリング層は空間次元をダウンサンプリングします。入力引数を使用してネットワークの深さを制御します。

function layers = createNetworkArchitecture(netDepth)
    imageSize = [32 32 3];
    netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

    layers = [
        imageInputLayer(imageSize)

        convolutionalBlock(netWidth,netDepth)
        maxPooling2dLayer(2,'Stride',2)
        convolutionalBlock(2*netWidth,netDepth)
        maxPooling2dLayer(2,'Stride',2)
        convolutionalBlock(4*netWidth,netDepth)
        averagePooling2dLayer(8)

        fullyConnectedLayer(10)
        softmaxLayer
        classificationLayer
    ];
end

ネットワーク アーキテクチャでの畳み込みブロックの作成を容易にする関数を定義します。

function layers = convolutionalBlock(numFilters,numConvLayers)
    layers = [
        convolution2dLayer(3,numFilters,'Padding','same')
        batchNormalizationLayer
        reluLayer
    ];

    layers = repmat(layers,numConvLayers,1);
end

DataQueue を通じて学習の進行状況をクライアントに送信する関数を定義します。

function sendTrainingProgress(D,idx,info)
    if info.State == "iteration"
        send(D,{idx,info.Iteration,info.TrainingAccuracy});
    end
end

ワーカーが中間結果を送信したときにプロットを更新する更新関数を定義します。

function updatePlot(lines,idx,iter,acc)
    addpoints(lines(idx),iter,acc);
    drawnow limitrate nocallbacks
end

参考

| | | |

関連するトピック