Main Content

parfeval を使用した複数の深層学習ネットワークの学習

この例では、parfeval を使用して、深層学習ネットワークのネットワーク アーキテクチャの深さについてのパラメーター スイープを実行し、学習中にデータを取得する方法を説明します。

多くの場合、深層学習における学習には数時間または数日かかるため、適切なアーキテクチャを見つけるのが難しい場合があります。並列計算を使用すると、適切なモデルの探索を高速化および自動化できます。複数のグラフィックス処理装置 (GPU) があるマシンにアクセスできる場合は、ローカルの並列プールを使用してデータ セットのローカル コピーに対してこの例を完了させることができます。より多くのリソースを使用する必要がある場合は、深層学習における学習をクラウドにスケールアップできます。この例では、parfeval を使用して、クラウドのクラスターにおけるネットワーク アーキテクチャの深さについてのパラメーター スイープを実行する方法を説明します。parfeval を使用すると、MATLAB をブロックすることなくバックグラウンドで学習を実行でき、満足できる結果が得られた場合に早期の停止を選択できます。スクリプトを変更して、他のパラメーターについてのパラメーター スイープを実行できます。また、この例では、DataQueue を使用して計算中にワーカーからフィードバックを取得する方法も説明します。

要件

この例を実行する前に、クラスターを構成し、データをクラウドにアップロードする必要があります。MATLAB では、MATLAB デスクトップから直接、クラウドにクラスターを作成できます。[ホーム] タブの [並列] メニューで、[クラスターの作成と管理] を選択します。クラスター プロファイル マネージャーで、[クラウド クラスターの作成] をクリックします。または、MathWorks Cloud Center を使用して計算クラスターを作成し、そのクラスターにアクセスすることもできます。詳細については、Getting Started with Cloud Center を参照してください。この例では、MATLAB の [ホーム] タブの [並列][既定のクラスターの選択] で、クラスターが既定として設定されていることを確認します。その後、データを Amazon S3 バケットにアップロードして、MATLAB から直接使用します。この例では、Amazon S3 に既に格納されている CIFAR-10 データ セットのコピーを使用します。手順については、AWS での深層学習データの処理を参照してください。

クラウドからのデータ セットの読み込み

imageDatastore を使用して、学習データ セットとテスト データ セットをクラウドから読み込みます。学習データ セットは学習セットと検証セットに分割し、テスト データ セットはパラメーター スイープから得られる最適なネットワークをテストするために保持します。この例では、Amazon S3 に格納されている CIFAR-10 データ セットのコピーを使用します。ワーカーがクラウドのデータストアに確実にアクセスできるように、AWS 資格情報の環境変数が正しく設定されていることを確認してください。AWS での深層学習データの処理を参照してください。

imds = imageDatastore("s3://cifar10cloud/cifar10/train", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

imdsTest = imageDatastore("s3://cifar10cloud/cifar10/test", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);

augmentedImageDatastore オブジェクトを作成し、拡張イメージ データを使用してネットワークに学習させます。ランダムな平行移動と水平方向の反転を使用します。データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。

imageSize = [32 32 3];
pixelRange = [-4 4];
imageAugmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandXTranslation=pixelRange, ...
    RandYTranslation=pixelRange);
augmentedImdsTrain=augmentedImageDatastore(imageSize,imdsTrain, ...
    DataAugmentation=imageAugmenter, ...
    OutputSizeMode="randcrop");

複数のネットワークの同時学習

学習オプションを指定します。ミニバッチ サイズを設定し、ミニバッチ サイズに応じて初期学習率を線形にスケーリングします。trainnet でエポックごとに 1 回ネットワークが検証されるように、検証頻度を設定します。

miniBatchSize = 128;
initialLearnRate = 1e-1 * miniBatchSize/256;
validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
options = trainingOptions("sgdm", ...
    MiniBatchSize=miniBatchSize, ... % Set the mini-batch size
    Verbose=false, ... % Do not send command line output.
    Metrics="accuracy", ...
    InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
    L2Regularization=1e-10, ...
    MaxEpochs=30, ...
    Shuffle="every-epoch", ...
    ValidationData=imdsValidation, ...
    ValidationFrequency=validationFrequency);

パラメーター スイープを実行するネットワーク アーキテクチャの深さを指定します。parfeval を使用して複数のネットワークに同時に学習させる並列パラメーター スイープを実行します。ループを使用して、スイープ内で異なるネットワーク アーキテクチャについて繰り返します。ネットワークの深さを制御する入力引数を取り、CIFAR-10 のアーキテクチャを作成する補助関数 createNetworkArchitecture をスクリプトの末尾に作成します。parfeval を使用して、trainnet によって実行される計算の負荷をクラスター内のワーカーに移します。parfeval は、計算が完了すると、学習済みネットワークを保持する future 変数と学習情報を返します。

既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainnet は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

netDepths = 1:4;
numExperiments = numel(netDepths);
for idx = 1:numExperiments
    networksFuture(idx) = parfeval(@trainnet,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),"crossentropy",options);
end
Starting parallel pool (parpool) using the 'MyCluster' profile ...
Connected to parallel pool with 4 workers (PreferredPoolNumWorkers).

parfeval は MATLAB をブロックしないため、コマンドの実行を継続できます。この場合、networksFuture に対して fetchOutputs を使用して、学習済みネットワークとその学習情報を取得します。関数 fetchOutputs は、future 変数が終了するまで待機します。

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

trainingInfo 構造体にアクセスして、ネットワークの最終検証精度を求めます。

for idx = 1:numExperiments
    validationHistory = trainingInfo(idx).ValidationHistory;
    accuracies(idx) = validationHistory.Accuracy(end);
end

accuracies
accuracies = 1×4

   70.7200   78.8200   76.1000   78.0200

精度に基づいて最適なネットワークを選択します。

[~, I] = max(accuracies);
bestNetwork = trainedNetworks(I(1));

テスト データ セットに対してそのパフォーマンスをテストします。複数の観測値を使用して予測を行うには、関数minibatchpredictを使用します。予測スコアをラベルに変換するには、関数 scores2label を使用します。関数 minibatchpredict は利用可能な GPU がある場合に自動的にそれを使用します。

classNames = categories(imdsTest.Labels);
scores = minibatchpredict(bestNetwork,imdsTest);
Y = scores2label(scores,classNames);
accuracy = sum(Y == imdsTest.Labels)/numel(imdsTest.Labels)
accuracy = 0.7798

テスト データの混同行列を計算します。

figure
confusionchart(imdsTest.Labels,Y,RowSummary="row-normalized",ColumnSummary="column-normalized");

学習時のフィードバック データの送信

各ワーカーの学習の進行状況を示すプロットを準備して初期化します。animatedLine を使用すると、変化するデータを簡単に表示できます。

f = figure;
f.Visible = true;
for i=1:4
    subplot(2,2,i)
    xlabel("Iteration");
    ylabel("Training accuracy");
    lines(i) = animatedline;
end

DataQueue を使用して学習の進行状況データをワーカーからクライアントに送信してから、データをプロットします。afterEach を使用して、ワーカーから学習の進行状況フィードバックが送信されるたびにプロットを更新します。パラメーター opts には、ワーカー、学習の反復、および学習精度に関する情報が含まれます。

D = parallel.pool.DataQueue;
afterEach(D, @(opts) updatePlot(lines,opts{:}));

パラメーター スイープを実行するネットワーク アーキテクチャの深さを指定し、parfeval を使用して並列パラメーター スイープを実行します。スクリプトを添付ファイルとして現在のプールに追加して、このスクリプトでワーカーが補助関数にアクセスできるようにします。学習オプションで出力関数を定義して、学習の進行状況をワーカーからクライアントに送信します。学習オプションは、ワーカーのインデックスに依存するため、for ループ内に含めなければなりません。

netDepths = 1:4;
addAttachedFiles(gcp,mfilename);
for idx = 1:numel(netDepths)
    
    miniBatchSize = 128;
    initialLearnRate = 1e-1 * miniBatchSize/256; % Scale the learning rate according to the mini-batch size.
    validationFrequency = floor(numel(imdsTrain.Labels)/miniBatchSize);
    
    options = trainingOptions("sgdm", ...
        OutputFcn=@(state) sendTrainingProgress(D,idx,state), ... % Set an output function to send intermediate results to the client.
        MiniBatchSize=miniBatchSize, ... % Set the corresponding MiniBatchSize in the sweep.
        Verbose=false, ... % Do not send command line output.
        InitialLearnRate=initialLearnRate, ... % Set the scaled learning rate.
        Metrics="accuracy", ...
        L2Regularization=1e-10, ...
        MaxEpochs=30, ...
        Shuffle="every-epoch", ...
        ValidationData=imdsValidation, ...
        ValidationFrequency=validationFrequency);
    
    networksFuture(idx) = parfeval(@trainnet,2, ...
        augmentedImdsTrain,createNetworkArchitecture(netDepths(idx)),"crossentropy",options);
end

parfeval は、クラスター内のワーカーに対して trainnet を呼び出します。計算はバックグラウンドで行われるため、MATLAB での作業を続行できます。parfeval による計算を停止する場合は、対応する future 変数に対して cancel を呼び出します。たとえば、あるネットワークのパフォーマンスが低下していることがわかった場合、そのネットワークの future をキャンセルできます。これを行うと、次にキューに入っている future 変数の計算が開始されます。

この場合、future 変数に対して fetchOutputs を呼び出して、学習済みネットワークとその学習情報を取得します。

[trainedNetworks,trainingInfo] = fetchOutputs(networksFuture);

各ネットワークの最終検証精度を求めます。

for idx = 1:numExperiments
    validationHistory = trainingInfo(idx).ValidationHistory;
    accuracies(idx) = validationHistory.Accuracy(end);
end

accuracies
accuracies = 1×4

   71.4600   78.3600   74.4000   79.3800

補助関数

関数を使用して CIFAR-10 データ セットのネットワーク アーキテクチャを定義し、入力引数を使用してネットワークの深さを調整します。コードを簡略化するために、入力を畳み込む畳み込みブロックを使用します。プーリング層は空間次元をダウンサンプリングします。

function layers = createNetworkArchitecture(netDepth)
imageSize = [32 32 3];
netWidth = round(16/sqrt(netDepth)); % netWidth controls the number of filters in a convolutional block

layers = [
    imageInputLayer(imageSize)
    
    convolutionalBlock(netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(2*netWidth,netDepth)
    maxPooling2dLayer(2,Stride=2)
    convolutionalBlock(4*netWidth,netDepth)
    averagePooling2dLayer(8)
    
    fullyConnectedLayer(10)
    softmaxLayer
    ];
end

ネットワーク アーキテクチャで畳み込みブロックを作成する関数を定義します。

function layers = convolutionalBlock(numFilters,numConvLayers)
layers = [
    convolution2dLayer(3,numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    ];

layers = repmat(layers,numConvLayers,1);
end

DataQueue を通じて学習の進行状況をクライアントに送信する関数を定義します。

function stop = sendTrainingProgress(D,idx,info)
if info.State == "iteration" && ~isempty(info.TrainingAccuracy)
    send(D,{idx,info.Iteration,info.TrainingAccuracy});
end
stop = false;
end

ワーカーが中間結果を送信したときにプロットを更新する更新関数を定義します。

function updatePlot(lines,idx,iter,acc)
addpoints(lines(idx),iter,acc);
drawnow limitrate nocallbacks
end

参考

(Parallel Computing Toolbox) | | | | |

関連するトピック