このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。
parfor を使用した複数の深層学習ネットワークの学習
この例では、parfor
ループを使用して、学習オプションについてのパラメーター スイープを実行する方法を説明します。
多くの場合、深層学習における学習には数時間または数日かかるため、適切な学習オプションを見つけるのが難しい場合があります。並列計算を使用すると、適切なモデルの探索を高速化および自動化できます。複数のグラフィックス処理装置 (GPU) があるマシンにアクセスできる場合は、ローカルの parpool を使用してデータセットのローカル コピーに対してこの例を完了させることができます。より多くのリソースを使用する必要がある場合は、深層学習における学習をクラウドにスケールアップできます。この例では、parfor ループを使用して、クラウドのクラスターにおける学習オプション MiniBatchSize
についてのパラメーター スイープを実行する方法を説明します。スクリプトを変更して、他の学習オプションについてのパラメーター スイープを実行できます。また、この例では、DataQueue
を使用して計算中にワーカーからフィードバックを取得する方法も説明します。スクリプトをバッチ ジョブとしてクラスターに送信することもできるため、作業を続行したり、MATLAB を閉じて後で結果を取得したりできます。詳細については、深層学習バッチ ジョブのクラスターへの送信を参照してください。
要件
この例を実行する前に、クラスターを構成し、データをクラウドにアップロードする必要があります。MATLAB では、MATLAB デスクトップから直接、クラウドにクラスターを作成できます。[ホーム] タブの [並列] メニューで、[クラスターの作成と管理] を選択します。クラスター プロファイル マネージャーで、[クラウド クラスターの作成] をクリックします。または、MathWorks Cloud Center を使用して計算クラスターを作成し、そのクラスターにアクセスすることもできます。詳細については、Getting Started with Cloud Center を参照してください。この例では、MATLAB の [ホーム] タブの [並列]、[既定のクラスターの選択] で、クラスターが既定として設定されていることを確認します。その後、データを Amazon S3 バケットにアップロードして、MATLAB から直接使用します。この例では、Amazon S3 に既に格納されている CIFAR-10 データセットのコピーを使用します。手順については、クラウドへの深層学習データのアップロードを参照してください。
クラウドからのデータセットの読み込み
imageDatastore
を使用して、学習データセットとテスト データセットをクラウドから読み込みます。学習データセットは学習セットと検証セットに分割し、テスト データセットはパラメーター スイープから得られる最適なネットワークをテストするために保持します。この例では、Amazon S3 に格納されている CIFAR-10 データセットのコピーを使用します。ワーカーがクラウドのデータストアに確実にアクセスできるように、AWS 資格情報の環境変数が正しく設定されていることを確認してください。クラウドへの深層学習データのアップロードを参照してください。
imds = imageDatastore('s3://cifar10cloud/cifar10/train', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); imdsTest = imageDatastore('s3://cifar10cloud/cifar10/test', ... 'IncludeSubfolders',true, ... 'LabelSource','foldernames'); [imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
augmentedImageDatastore
オブジェクトを作成し、拡張イメージ データを使用してネットワークに学習させます。ランダムな平行移動と水平方向の反転を使用します。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。
imageSize = [32 32 3]; pixelRange = [-4 4]; imageAugmenter = imageDataAugmenter( ... 'RandXReflection',true, ... 'RandXTranslation',pixelRange, ... 'RandYTranslation',pixelRange); augmentedImdsTrain = augmentedImageDatastore(imageSize,imdsTrain, ... 'DataAugmentation',imageAugmenter, ... 'OutputSizeMode','randcrop');
ネットワーク アーキテクチャの定義
CIFAR-10 データセット用のネットワーク アーキテクチャを定義します。コードを簡略化するために、入力を畳み込む畳み込みブロックを使用します。プーリング層は空間次元をダウンサンプリングします。
imageSize = [32 32 3]; netDepth = 2; % netDepth controls the depth of a convolutional block netWidth = 16; % netWidth controls the number of filters in a convolutional block layers = [ imageInputLayer(imageSize) convolutionalBlock(netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(2*netWidth,netDepth) maxPooling2dLayer(2,'Stride',2) convolutionalBlock(4*netWidth,netDepth) averagePooling2dLayer(8) fullyConnectedLayer(10) softmaxLayer classificationLayer ];
複数のネットワークの同時学習
パラメーター スイープを実行するミニバッチ サイズを指定します。結果として得られるネットワークおよび精度のための変数を割り当てます。
miniBatchSizes = [64 128 256 512]; numMiniBatchSizes = numel(miniBatchSizes); trainedNetworks = cell(numMiniBatchSizes,1); accuracies = zeros(numMiniBatchSizes,1);
parfor
ループ内で複数のネットワークに学習させ、ミニバッチ サイズを変化させる並列パラメーター スイープを実行します。クラスター内のワーカーは複数のネットワークに同時に学習させ、学習が完了すると、学習済みネットワークと精度を送り返します。学習が適切に行われていることを確認する場合は、学習オプションで Verbose
を true
に設定します。ワーカーは個別に計算を行うため、コマンド ライン出力は反復と同じ順序にはならないことに注意してください。
parfor idx = 1:numMiniBatchSizes miniBatchSize = miniBatchSizes(idx); initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size. % Define the training options. Set the mini-batch size. options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep. 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.1, ... 'LearnRateDropPeriod',25); % Train the network in a worker in the cluster. net = trainNetwork(augmentedImdsTrain,layers,options); % To obtain the accuracy of this network, use the trained network to % classify the validation images on the worker and compare the predicted labels to the % actual labels. YPredicted = classify(net,imdsValidation); accuracies(idx) = sum(YPredicted == imdsValidation.Labels)/numel(imdsValidation.Labels); % Send the trained network back to the client. trainedNetworks{idx} = net; end
Starting parallel pool (parpool) using the 'MyClusterInTheCloud' profile ... Connected to the parallel pool (number of workers: 4).
parfor
が終了すると、trainedNetworks
にはワーカーによる学習の結果として得られたネットワークが含まれます。学習済みネットワークとその精度を表示します。
trainedNetworks
trainedNetworks = 4×1 cell array
{1×1 SeriesNetwork}
{1×1 SeriesNetwork}
{1×1 SeriesNetwork}
{1×1 SeriesNetwork}
accuracies
accuracies = 4×1
0.8188
0.8232
0.8162
0.8050
精度に基づいて最適なネットワークを選択します。テスト データセットに対してそのパフォーマンスをテストします。
[~, I] = max(accuracies); bestNetwork = trainedNetworks{I(1)}; YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.8173
学習時のフィードバック データの送信
各ワーカーの学習の進行状況を示すプロットを準備して初期化します。animatedLine
を使用すると、変化するデータを簡単に表示できます。
f = figure; f.Visible = true; for i=1:4 subplot(2,2,i) xlabel('Iteration'); ylabel('Training accuracy'); lines(i) = animatedline; end
DataQueue
を使用して学習の進行状況データをワーカーからクライアントに送信してから、データをプロットします。afterEach
を使用して、ワーカーから学習の進行状況フィードバックが送信されるたびにプロットを更新します。パラメーター opts
には、ワーカー、学習の反復、および学習精度に関する情報が含まれます。
D = parallel.pool.DataQueue; afterEach(D, @(opts) updatePlot(lines, opts{:}));
異なるミニバッチ サイズによる parfor ループ内で複数のネットワークに学習させる並列パラメーター スイープを実行します。反復ごとに学習の進行状況をクライアントに送信するために、学習オプションで OutputFcn
を使用することに注意してください。この図は、4 つの異なるワーカーについて、下記のコードの実行時における学習の進行状況を示しています。
parfor idx = 1:numel(miniBatchSizes) miniBatchSize = miniBatchSizes(idx); initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the miniBatchSize. % Define the training options. Set an output function to send data back % to the client each iteration. options = trainingOptions('sgdm', ... 'MiniBatchSize',miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep. 'Verbose',false, ... % Do not send command line output. 'InitialLearnRate',initialLearnRate, ... % Set the scaled learning rate. 'OutputFcn',@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client. 'L2Regularization',1e-10, ... 'MaxEpochs',30, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropFactor',0.1, ... 'LearnRateDropPeriod',25); % Train the network in a worker in the cluster. The workers send % training progress information during training to the client. net = trainNetwork(augmentedImdsTrain,layers,options); % To obtain the accuracy of this network, use the trained network to % classify the validation images on the worker and compare the predicted labels to the % actual labels. YPredicted = classify(net,imdsValidation); accuracies(idx) = sum(YPredicted == imdsValidation.Labels)/numel(imdsValidation.Labels); % Send the trained network back to the client. trainedNetworks{idx} = net; end
Analyzing and transferring files to the workers ...done.
parfor
が終了すると、trainedNetworks
にはワーカーによる学習の結果として得られたネットワークが含まれます。学習済みネットワークとその精度を表示します。
trainedNetworks
trainedNetworks = 4×1 cell array
{1×1 SeriesNetwork}
{1×1 SeriesNetwork}
{1×1 SeriesNetwork}
{1×1 SeriesNetwork}
accuracies
accuracies = 4×1
0.8214
0.8172
0.8132
0.8084
精度に基づいて最適なネットワークを選択します。テスト データセットに対してそのパフォーマンスをテストします。
[~, I] = max(accuracies); bestNetwork = trainedNetworks{I(1)}; YPredicted = classify(bestNetwork,imdsTest); accuracy = sum(YPredicted == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.8187
補助関数
ネットワーク アーキテクチャで畳み込みブロックを作成する関数を定義します。
function layers = convolutionalBlock(numFilters,numConvLayers) layers = [ convolution2dLayer(3,numFilters,'Padding','same') batchNormalizationLayer reluLayer ]; layers = repmat(layers,numConvLayers,1); end
DataQueue
を通じて学習の進行状況をクライアントに送信する関数を定義します。
function sendTrainingProgress(D,idx,info) if info.State == "iteration" send(D,{idx,info.Iteration,info.TrainingAccuracy}); end end
ワーカーが中間結果を送信したときにプロットを更新する更新関数を定義します。
function updatePlot(lines,idx,iter,acc) addpoints(lines(idx),iter,acc); drawnow limitrate nocallbacks end
参考
trainNetwork
| parallel.pool.DataQueue
(Parallel Computing Toolbox) | imageDatastore
関連する例
詳細
- 並列 for ループ (parfor) (Parallel Computing Toolbox)