Main Content

カスタム学習ループを使用したネットワークの並列学習

この例では、ネットワークに並列学習させるためのカスタム学習ループを設定する方法を説明します。この例では、並列ワーカーによりミニバッチの一部で学習が実行されます。GPU がある場合、GPU 上で学習が行われます。学習中、DataQueue オブジェクトによって、学習の進行状況の情報が MATLAB クライアントに送り返されます。

データセットの読み込み

数字のデータセットを読み込み、このデータセットのイメージ データストアを作成します。データストアを学習データストアとテスト データストアにランダムに分割します。学習データを格納する augmentedImageDatastore を作成します。

digitDatasetPath = fullfile(toolboxdir("nnet"),"nndemos", ...
    "nndatasets","DigitDataset");
imds = imageDatastore(digitDatasetPath, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

inputSize = [28 28 1];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);

学習セットに含まれる異なるクラスを判別します。

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

ネットワークの定義

ネットワーク アーキテクチャを定義します。このネットワーク アーキテクチャには、バッチ正規化層が含まれています。この層は、データセットの統計量である平均と分散を追跡します。並列学習の場合、各反復ステップの最後にすべてのワーカーからの統計量を結合して、ネットワークの状態が必ずミニバッチ全体を反映するようにします。そうでない場合、ネットワークの状態がワーカー間で異なる可能性があります。たとえば、ステートフル再帰型ニューラル ネットワーク (RNN) の学習において、小さいシーケンスに分割されたシーケンス データを使用して LSTM 層または GRU 層を含むネットワークに学習させる場合、ワーカー間の状態の管理もしなければなりません。

layers = [
    imageInputLayer(inputSize,Normalization="none")
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

層配列から dlnetwork オブジェクトを作成します。dlnetwork オブジェクトにより、カスタム ループを使用した学習が可能になります。

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [12×1 nnet.cnn.layer.Layer]
    Connections: [11×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

並列環境の設定

関数 canUseGPU を使用し、MATLAB で GPU が使用可能かどうかを判定します。

  • 使用できる GPU がある場合、GPU 上で学習を実行。GPU と同じ数のワーカーを使用して並列プールを作成。

  • 使用できる GPU がない場合、CPU 上で学習を実行。既定の数のワーカーを使用して並列プールを作成。

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
Starting parallel pool (parpool) using the 'Processes' profile ...
Connected to the parallel pool (number of workers: 4).

並列プール内のワーカー数を取得します。この例では後ほど、この数に基づいて作業負荷を分割します。

numWorkers = pool.NumWorkers;

モデルの学習

学習オプションを指定します。

numEpochs = 20;
miniBatchSize = 128;
velocity = [];

GPU を使用した学習では、GPU の数でミニバッチ サイズを線形にスケールアップし、各 GPU における作業負荷を一定に保つことを推奨します。関連するアドバイスの詳細については、MATLAB による複数の GPU での深層学習を参照してください。

if executionEnvironment == "gpu"
     miniBatchSize = miniBatchSize .* numWorkers
end
miniBatchSize = 512

ミニバッチ全体のサイズをワーカー間で均等に配分し、各ワーカーのミニバッチ サイズを計算します。余りは最初のワーカー間で分配します。

workerMiniBatchSize = floor(miniBatchSize ./ repmat(numWorkers,1,numWorkers));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,numWorkers-remainder)]
workerMiniBatchSize = 1×4

   128   128   128   128

このネットワークには、ネットワークに学習させているデータの平均と分散を追跡するバッチ正規化層が含まれています。各ワーカーは各反復中に各ミニバッチの一部を処理するため、平均と分散はすべてのワーカーにわたって集計しなければなりません。ネットワークの状態プロパティで、バッチ正規化層の平均と分散の状態パラメーターについて、インデックスを検索します。

batchNormLayers = arrayfun(@(l)isa(l,"nnet.cnn.layer.BatchNormalizationLayer"),net.Layers);
batchNormLayersNames = string({net.Layers(batchNormLayers).Name});
state = net.State;
isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean";
isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";

TrainingProgressMonitor オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics="TrainingLoss", ...
    Info=["Epoch" "Workers"], ...
    XLabel="Iteration");

ワーカーで Dataqueue オブジェクトを作成して、[停止] ボタンが押されたときに学習を停止するためのフラグを送信します。

spmd
    stopTrainingEventQueue = parallel.pool.DataQueue;
end
stopTrainingQueue = stopTrainingEventQueue{1};

学習中にワーカーからデータを返すため、DataQueue オブジェクトを作成します。afterEach を使用して関数 displayTrainingProgress を設定し、ワーカーがデータを送信するたびに呼び出されるようにします。displayTrainingProgress は、TrainingProgressMonitor オブジェクトを更新して表示し、ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンが押された場合にワーカーにフラグを送信するサポート関数 (この例の最後で定義) です。

dataQueue = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,numEpochs,numWorkers,monitor,stopTrainingQueue);
afterEach(dataQueue,displayFcn)

次の手順で詳しく説明するように、カスタム並列学習ループを使用してモデルに学習させます。すべてのワーカーで同時にコードを実行するには、spmd ブロックを使用します。spmd ブロック内の spmdIndex により、現在コードを実行しているワーカーのインデックスが与えられます。

学習前に、関数 partition を使用して各ワーカーのデータストアを分割します。分割されたデータストアを使用して、各ワーカーに minibatchqueue を作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、データを正規化し、ラベルを one-hot 符号化変数に変換し、ミニバッチ内の観測値の数を判定します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。クラス ラベルや観測数には書式を追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox) (Parallel Computing Toolbox) を参照してください。

関数 reset と関数 shuffle を使用し、エポックごとにデータストアをリセットしてシャッフルします。エポック内のそれぞれの反復で次を行います。

  • データの並列処理を開始する前に、spmdreduce を使用してグローバルな and 演算を関数 hasdata の結果に対して実行し、すべてのワーカーに利用可能なデータがあることを確認。

  • 関数 next を使用して、minibatchqueue からミニバッチを読み取ります。

  • 関数 modelLossdlfeval を呼び出すことによって、各ワーカーのネットワークの損失と勾配を計算。関数 dlfeval は、自動微分を有効にして補助関数 modelLoss を評価し、modelLoss が損失の勾配を自動的に計算できるようにします。modelLoss (この例の最後で定義) は、ネットワーク、データのミニバッチ、真のラベルを受け取り、損失と勾配を返します。

  • 全体的な損失を取得するために、すべてのワーカーの損失を集計。この例では、損失関数の交差エントロピーを使用します。集計された損失はすべての損失の合計です。集計する前に、ミニバッチ全体のうちワーカーが処理している割合で乗算し、各損失を正規化します。spmdPlus を使用してすべての損失を加算し、ワーカー全体にその結果を複製します。

  • すべてのワーカーの勾配を集計および更新するために、関数 aggregateGradients で関数 dlupdate を使用。aggregateGradients はこの例の終わりで定義するサポート関数です。この関数は、ミニバッチ全体のうち各ワーカーが処理している割合に基づいて正規化した後、spmdPlus を使用し、勾配を加算してワーカー全体に複製します。

  • 関数 aggregateState を使用して、すべてのワーカーのネットワークの状態を集約。aggregateState は、この例の最後で定義されているサポート関数です。ネットワークのバッチ正規化層がデータの平均と分散を追跡します。ミニバッチ全体が複数のワーカーに分散されているため、各反復の後にネットワークの状態を集計し、ミニバッチ全体の平均と分散を計算します。

  • 最終勾配を計算した後、関数 sgdmupdate を使用し、ネットワークの学習可能なパラメーターを更新。

各エポックの後、[停止] ボタンが押されたかどうかを確認し、Dataqueue オブジェクトと関数 send を使用して学習の進行状況の情報をクライアントに送り返します。すべてのワーカーが同じ損失情報をもっているため、必要なのは 1 つのワーカーを使用してデータを送り返すことだけです。データが確実に CPU 上にあり、GPU を搭載していないクライアント マシンがデータにアクセスできるようにするには、データをクライアントに送信する前に、dlarray に対して gather を使用します。各エポックの後にワーカー間の通信が発生するため、[停止] をクリックして、現在のエポックの最後で学習を停止します。各反復の最後に [停止] ボタンで学習を停止させる場合、[停止] ボタンが押されたかどうかを確認し、反復ごとに学習の進行状況の情報をクライアントに送り返せますが、通信オーバーヘッドが増加します。

spmd
    % Reset and shuffle the datastore.
    reset(augimdsTrain);
    augimdsTrain = shuffle(augimdsTrain);

    % Partition datastore.
    workerImds = partition(augimdsTrain,numWorkers,spmdIndex);

    % Create minibatchqueue using partitioned datastore on each worker
    workerMbq = minibatchqueue(workerImds,3,...
        MiniBatchSize=workerMiniBatchSize(spmdIndex),...
        MiniBatchFcn=@preprocessMiniBatch,...
        MiniBatchFormat=["SSCB" "" ""]);

    workerVelocity = velocity;
    epoch = 0;
    iteration = 0;
    stopRequest = false;

    while epoch < numEpochs && ~stopRequest
        epoch = epoch + 1;
        shuffle(workerMbq);

        % Loop over mini-batches
        while spmdReduce(@and,hasdata(workerMbq)) && ~stopRequest
            iteration = iteration + 1;

            % Read a mini-batch of data
            [workerX,workerT,workerNumObservations] = next(workerMbq);

            % Evaluate the model loss and gradients on the worker
            [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerT);

            % Aggregate the losses on all workers
            workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize;
            loss = spmdPlus(workerNormalizationFactor*extractdata(workerLoss));

            % Aggregate the network state on all workers
            net.State = aggregateState(workerState,workerNormalizationFactor,...
                isBatchNormalizationStateMean,isBatchNormalizationStateVariance);

            % Aggregate the gradients on all workers
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});

            % Update the network parameters using the SGDM optimizer
            [net,workerVelocity] = sgdmupdate(net,workerGradients,workerVelocity);
        end

        % Stop training if the Stop button has been clicked
        stopRequest = spmdPlus(stopTrainingEventQueue.QueueLength);

        % Send training progress information to the client
        if spmdIndex == 1
            data = [epoch loss iteration];
            send(dataQueue,gather(data));
        end
    end

end

モデルのテスト

ネットワークに学習させた後、その精度をテストできます。

readall を使用してテスト データストアにあるテスト イメージをメモリに読み込み、それらを連結して正規化します。

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
TTest = imdsTest.Labels;

学習が完了すると、各ワーカーがもつ完全な学習済みネットワークはすべて同じになります。それらのいずれかを取得します。

netFinal = net{1};

dlnetwork オブジェクトを使用してイメージを分類するには、dlarray に対して関数 predict を使用します。

YTest = predict(netFinal,dlarray(XTest,"SSCB"));

関数 max を使用し、予測スコアからスコアが最も高いクラスを見つけます。これを行う前に、関数 extractdata を使用して dlarray からデータを抽出します。

[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);

モデルの分類精度を取得するには、テスト セットにおける予測を真のラベルと比較します。

accuracy = mean(YTest==TTest)
accuracy = 0.9440

ミニ バッチ前処理関数

関数 preprocessMiniBatch は、次の手順を使用して予測子とラベルのミニバッチを前処理します。

  1. ミニバッチ内の観測数を判定します。

  2. 関数 preprocessMiniBatchPredictors を使用してイメージを前処理します。

  3. 入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結します。

  4. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,Y,numObs] = preprocessMiniBatch(XCell,YCell)

numObs = numel(YCell);

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract label data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode labels.
Y = onehotencode(Y,1);

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。グレースケール入力では、4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。その後、データが正規化されます。

function X = preprocessMiniBatchPredictors(XCell)

% Concatenate.
X = cat(4,XCell{1:end});

% Normalize.
X =  X ./ 255;

end

モデル損失関数

ネットワークの学習可能なパラメーターについて損失の勾配を計算する関数 modelLoss を定義します。この関数は、forward を使用してミニバッチ X に対するネットワークの出力を計算し、クロス エントロピーを使用して、本来の出力が与えられたときの損失を計算します。dlfeval と共にこの関数を呼び出すと、自動微分が有効になり、dlgradient は学習可能なパラメーターについての損失の勾配を自動的に計算できます。

function [loss,gradients,state] = modelLoss(net,X,T)

[Y,state] = forward(net,X);

loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

学習の進行状況を表示する関数

ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンがクリックされたかどうかを確認する関数を定義します。[停止] ボタンがクリックされると、学習を停止する必要があることを示すフラグがワーカーに送信されます。この例では、ワーカーがデータを送信するたびに DataQueue によってこの関数が呼び出されます。

function displayTrainingProgress(data,numEpochs,numWorkers,monitor,stopTrainingQueue)

epoch = data(1);
loss = data(2);
iteration = data(3);

recordMetrics(monitor,iteration,TrainingLoss=loss);
updateInfo(monitor,Epoch=epoch + " of " + numEpochs, Workers= numWorkers);
monitor.Progress = 100 * epoch/numEpochs;

if monitor.Stop
    send(stopTrainingQueue,true);
end

end

勾配集計関数

すべてのワーカーの勾配を加算して集計する関数を定義します。spmdPlus は、ワーカー上ですべての勾配を加算して複製します。加算する前に、ミニバッチ全体のうちワーカーが処理している割合を表す係数を勾配に乗算し、それらを正規化します。dlarray の内容を取得するには、extractdata を使用します。

function gradients = aggregateGradients(gradients,factor)

gradients = extractdata(gradients);
gradients = spmdPlus(factor*gradients);

end

状態集計関数

すべてのワーカーでネットワークの状態を集計する関数を定義します。このネットワークの状態には、データ セットの学習済みバッチ正規化統計量が含まれます。各ワーカーが処理するのはミニバッチの一部のみなので、すべてのデータの統計を表すように、ネットワークの状態を集計します。ミニバッチごとに、統合平均が、各反復のワーカー全体の平均に対する加重平均として計算されます。統合分散は、次の式に従って計算されます。

sc2=1Mj=1Nmj[sj2+(xj-xc)2]

ここで、N はワーカーの合計数、M はミニバッチの観測値の合計数、mjj 番目のワーカーで処理された観測値の数、xjsj2 はそのワーカーで計算された平均と分散の統計、xc はすべてのワーカー全体の統合平均です。

function state = aggregateState(state,factor,...
    isBatchNormalizationStateMean,isBatchNormalizationStateVariance)

stateMeans = state.Value(isBatchNormalizationStateMean);
stateVariances = state.Value(isBatchNormalizationStateVariance);

for j = 1:numel(stateMeans)
    meanVal = stateMeans{j};
    varVal = stateVariances{j};

    % Calculate combined mean
    combinedMean = spmdPlus(factor*meanVal);

    % Calculate combined variance terms to sum
    varTerm = factor.*(varVal + (meanVal - combinedMean).^2);

    % Update state
    stateMeans{j} = combinedMean;
    stateVariances{j} = spmdPlus(varTerm);
end

state.Value(isBatchNormalizationStateMean) = stateMeans;
state.Value(isBatchNormalizationStateVariance) = stateVariances;

end

参考

| | | | | | | | |

関連するトピック