ドキュメンテーション

最新のリリースでは、このページがまだ翻訳されていません。 このページの最新版は英語でご覧になれます。

parfeval を使用した複数の深層学習ネットワークの学習

この例では、学習中に parfeval を使用して、深層学習ネットワークのネットワーク アーキテクチャの深さについてパラメーター スイープを行い、データを取得する方法を説明します。

深層学習での学習は多くの場合、数時間または数日を要し、良好なアーキテクチャの探索は困難なことがあります。並列計算を使用して、良好なモデルの探索の高速化および自動化ができます。複数のグラフィック処理ユニット (GPU) を搭載したマシンにアクセスできる場合は、ローカル並列プールを使用して、データセットのローカル コピーでこの例を実行できます。さらに多くのリソースが必要な場合は、深層学習をクラウドにスケール アップできます。この例では、parfeval を使用してクラウド上のクラスター内にあるネットワーク アーキテクチャの深さについてパラメーター スイープを実行する方法を説明します。parfeval を使用することで、MATLAB をブロックすることなくバックグラウンドで学習させることができ、結果に満足した場合の早期停止オプションがあります。スクリプトを編集して、他の任意のパラメーターについてパラメーター スイープを実行できます。また、この例では DataQueue を使用して計算中にワーカーからフィードバックを取得する方法も説明します。

要件

この例を実行するには、クラスターを構成し、データをクラウドにアップロードしなければなりません。MATLAB では、MATLAB デスクトップから直接クラウドにクラスターを作成できます。[ホーム] タブの [並列] メニューで [クラスターの作成と管理] を選択します。クラスター プロファイル マネージャーで、[クラウド クラスターの作成] をクリックします。あるいは、MathWorks Cloud Center を使用して、計算クラスターの作成およびアクセスができます。詳細については、Getting Started with Cloud Center を参照してください。この例では、MATLAB の [ホーム] タブの [並列][既定のクラスターの選択] で、使用するクラスターを確実に既定として設定します。その後、Amazon S3 バケットにデータをアップロードし、MATLAB から直接使用します。この例では、既に Amazon S3 に保存されている CIFAR-10 データセットのコピーを使用します。詳細については、クラウドへの深層学習データのアップロード (Deep Learning Toolbox)を参照してください。

クラウドからのデータセットの読み込み

imageDatastore を使用して、クラウドから学習データセットおよびテスト データセットを読み込みます。学習データセットを学習セットと検証セットに分割し、パラメーター スイープから最良のネットワークをテストするためにテスト データセットを保持しておきます。この例では、Amazon S3 に保存されている CIFAR-10 データセットのコピーを使用します。クラウド内のデータ ストアへのアクセス権をワーカーが確実にもつように、AWS 認証情報の環境変数が正しく設定されていることを確認してください。クラウドへの深層学習データのアップロード (Deep Learning Toolbox)を参照してください。

imds = imageDatastore('s3://cifar10cloud/cifar10/train', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

augmentedImageDatastore オブジェクトを作成し、拡張イメージ データを使用してネットワークに学習させます。ランダムな平行移動と水平方向の反射パターンを使用します。データ拡張は、ネットワークによる過適合と、学習イメージそのものの細部の記憶を防ぐ上で役立ちます。

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange);
augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ...
    'DataAugmentation',imageAugmenter, ...
    'OutputSizeMode','randcrop');

複数のネットワークで同時に学習

学習オプションを定義します。ミニバッチ サイズを設定し、ミニバッチ サイズに従って初期学習率を線形にスケーリングします。trainNetwork がエポックごとに 1 回ネットワークを検証するように、検証頻度を設定します。

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'MiniBatchSize',miniBatchSize, ... % Set the mini-batch size
    'Verbose',false, ... % Do not send command line output.
    'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
    'L2Regularization',1e-10, ...
    'MaxEpochs',30, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsValidation, ...
    'ValidationFrequency', validationFrequency);

パラメーター スイープの対象とするネットワーク アーキテクチャの深さを指定します。parfeval を使用して、並列パラメーター スイープ学習をいくつかのネットワークで同時に実行します。ループを使用して、スイープにおいて様々なネットワーク アーキテクチャを反復します。スクリプトの末尾に補助関数 createNetworkArchitecture を作成します。これは、ネットワークの深さを制御するための入力引数を受け入れ、CIFAR-10 用のアーキテクチャを作成します。parfeval を使用して、trainNetwork によって実行される計算をクラスター内のワーカーにオフロードします。parfeval は計算の終了時に学習済みネットワークおよび学習情報を保持する future 変数を返します。

netDepths = 1:4;
for idx = 1:numel(netDepths)
    networksFuture(idx) = parfeval(@trainNetwork,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options);
end
Starting parallel pool (parpool) using the 'MyCluster' profile ...
Connected to the parallel pool (number of workers: 4).

parfeval は MATLAB をブロックしません。つまり、コマンドの実行を続けることができます。この例では、fetchOutputsnetworksFuture に対して使用して、学習済みネットワークおよびその学習情報を取得します。関数 fetchOutputs は future 変数が完了するまで待機します。

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

trainingInfo 構造体にアクセスし、特定のネットワークの最終的な検証精度を取得します。たとえば、1 番目のネットワークの精度を取得します。

accuracy = trainingInfo(1).ValidationAccuracy(end)
accuracy = 72.7600

最終的な検証精度をすべて取得するには、cellfun を使用します。

accuracies = cellfun(@(x) x(end),{trainingInfo.ValidationAccuracy})
accuracies = 1×4

   72.7600   77.7000   77.5000   76.1200

精度が最良のネットワークを選択します。テスト データセットに対するそのパフォーマンスをテストします。

[~, I] = max(accuracies);
bestNetwork = trainedNetworks(I(1));
YPredicted = classify(bestNetwork,imdsTest);
accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7732

テスト データの混同行列を計算します。

figure('Units','normalized','Position',[0.2 0.2 0.4 0.4]);
confusionchart(imdsTest.Labels,YPredicted,'RowSummary','row-normalized','ColumnSummary','column-normalized');

学習中のフィードバック データの送信

各ワーカーの学習の進行状況を示すプロットを準備して初期化します。変化するデータを表示する便利な方法である animatedLine を使用します。

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel('Iteration');
    ylabel('Training accuracy');
    lines(i) = animatedline;
end

DataQueue を使用して、ワーカーからクライアントに学習の進行状況データを送信してから、データをプロットします。afterEach を使用して、ワーカーから学習の進行状況のフィードバックが送信されるたびにプロットを更新します。パラメーター opts はワーカー、学習反復、学習精度に関する情報を含みます。

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines, opts{:}));

パラメーター スイープを実行する対象のネットワーク アーキテクチャの深さを指定し、parfeval を使用して並列パラメーター スイープを実行します。現在のプールにスクリプトを添付ファイルとして追加して、ワーカーがこのスクリプト内のすべての補助関数にアクセスできるようにします。学習オプションに、ワーカーからクライアントに学習の進行状況を送信する出力関数を定義します。学習オプションはワーカーのインデックスに依存するため、for ループ内に含めなければなりません。

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)
    
    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
    
    options = trainingOptions('sgdm', ...
        'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        'Verbose',false, ... % Do not send command line output.
        'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate.
        'L2Regularization',1e-10, ...
        'MaxEpochs',30, ...
        'Shuffle','every-epoch', ...
        'ValidationData',imdsValidation, ...
        'ValidationFrequency', validationFrequency);
    
    networksFuture(idx) = parfeval(@trainNetwork,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),options);
end

parfeval は、クラスターのワーカー上で trainNetwork を呼び出します。計算はバックグラウンドで行われるため、MATLAB での作業を続行できます。parfeval の計算を停止するには、対応する future 変数に対して cancel を呼び出します。たとえば、あるネットワークのパフォーマンスが低いと確認された場合、その future をキャンセルできます。これを行うと、キュー内の次の future 変数がその計算を開始します。

この例では、future 変数に対して fetchOutputs を呼び出すことにより、学習済みネットワークおよびその学習情報を取得します。

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

cellfun を使用して、各ネットワークの最終的な検証精度を取得します。

accuracies = cellfun(@(x) x(end),{trainingInfo.ValidationAccuracy})
accuracies = 1×4

   72.9200   77.4800   76.9200   77.0400

補助関数

関数を使用して CIFAR-10 データセットのネットワーク アーキテクチャを定義し、入力引数を使用してネットワークの深さを調整します。コードを簡略化するために、入力を畳み込む畳み込みブロックを使用します。プーリング層は空間次元をダウンサンプリングします。

function layers = createNetworkArchitecture(netDepth)
imageSize = [32 32 3];
netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,'Stride',2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer
    ];
end

ネットワーク アーキテクチャ内で畳み込みブロックを作成する関数を定義します。

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,'Padding','same')
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end

DataQueue を介して学習の進行状況をクライアントに送信する関数を定義します。

function sendTrainingProgress(D,idx,info)
if info.State == "iteration"
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
end

ワーカーが中間結果を送信したときにプロットを更新する、更新関数を定義します。

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end

参考

| | | |

関連するトピック