Main Content

カスタム学習ループを使用したネットワークの並列学習

この例では、ネットワークに並列学習させるためのカスタム学習ループを設定する方法を説明します。

深層ニューラル ネットワークの学習には計算コストがかかり、計算に何時間もかかる可能性があります。特に複数の GPU がある場合、学習を高速化するために、ネットワークの並列学習を行うことができます。

この例では、並列ワーカーによりミニバッチの一部で学習が実行されます。GPU がある場合、GPU 上で学習が行われます。学習中、DataQueue オブジェクトによって、学習の進行状況の情報が MATLAB クライアントに送り返されます。

データ セットの読み込み

Flowers データ セット [1] をダウンロードし、解凍します。Flowers データ セットには、5 つのクラス ("デイジー""タンポポ""バラ""ヒマワリ"、および "チューリップ") に属する 3670 個の花のイメージが格納されています。

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

dataFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(dataFolder,"dir")
    fprintf("Downloading Flowers data set (218 MB)... ")
    websave(filename,url);
    untar(filename,downloadFolder)
    fprintf("Done.\n")
end
Downloading Flowers data set (218 MB)... 
Done.

データ セットのイメージ データストアを作成します。データストアを学習データストアとテスト データストアにランダムに分割します。

imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

学習イメージのサイズを変更するには、augmentedImageDatastore を作成します。関数 shuffle を使用してデータをシャッフルします。

inputSize = [100 100 3];
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
augimdsTrain = shuffle(augimdsTrain);

学習データ セットに含まれる異なるクラスを判別します。

classes = categories(imdsTrain.Labels)
classes = 5×1 cell
    {'daisy'     }
    {'dandelion' }
    {'roses'     }
    {'sunflowers'}
    {'tulips'    }

numClasses = numel(classes)
numClasses = 5

ネットワークの定義

2 次元残差ネットワークを作成します。このネットワーク アーキテクチャには、バッチ正規化層が含まれています。この層は、データ セットの統計量である平均と分散を追跡します。並列学習の場合、各反復ステップの最後にすべてのワーカーからの統計量を結合して、ネットワークの状態が必ずミニバッチ全体を反映するようにします。そうでない場合、ネットワークの状態がワーカー間で異なる可能性があります。

net = resnetNetwork(inputSize,numClasses)
net = 
  dlnetwork with properties:

         Layers: [176×1 nnet.cnn.layer.Layer]
    Connections: [191×2 table]
     Learnables: [214×3 table]
          State: [106×3 table]
     InputNames: {'input'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

代わりに再帰型ニューラル ネットワーク (RNN)、たとえば LSTM 層や GRU 層を含むネットワークを使用する場合、再帰層は学習中に変化する状態プロパティをもつことになります。したがって、カスタム学習ループを使用してネットワークに並列学習させる場合は、それらの状態を管理するように注意しなければなりません。ステートレスに (つまり、学習反復間で再帰層の状態を維持することなく) RNN に学習させるには、各学習反復の最後で resetState を呼び出します。ワーカー全体でバッチ正規化の統計値を集計するためにコードを編集する必要はありません。

並列環境の設定

関数 canUseGPU を使用し、MATLAB で GPU が使用可能かどうかを判定します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

  • 使用できる GPU がある場合、GPU 上で学習を実行。GPU と同じ数のワーカーを使用して並列プールを作成。

  • 使用できる GPU がない場合、CPU 上で学習を実行。既定の数のワーカーを使用して並列プールを作成。

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    gpuDeviceTable
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
ans=4×5 table
    Index          Name           ComputeCapability    DeviceAvailable    DeviceSelected
    _____    _________________    _________________    _______________    ______________

      1      "NVIDIA TITAN Xp"          "6.1"               true              true      
      2      "NVIDIA TITAN Xp"          "6.1"               true              false     
      3      "NVIDIA TITAN Xp"          "6.1"               true              false     
      4      "NVIDIA TITAN Xp"          "6.1"               true              false     

Starting parallel pool (parpool) using the 'Processes' profile ...

並列プール内のワーカー数を取得します。

numWorkers = pool.NumWorkers;

モデルの学習

学習オプションを指定します。

numEpochs = 100;
miniBatchSize = 128;
velocity = [];

GPU を使用した学習では、GPU の数でミニバッチ サイズを線形にスケールアップし、各 GPU における作業負荷を一定に保つことを推奨します。関連するアドバイスの詳細については、MATLAB による複数の GPU での深層学習を参照してください。

if executionEnvironment == "gpu"
     miniBatchSize = miniBatchSize .* numWorkers
end

ミニバッチ全体のサイズをワーカー間で均等に配分し、各ワーカーのミニバッチ サイズを計算します。余りは最初のワーカー間で分配します。

workerMiniBatchSize = floor(miniBatchSize ./ repmat(numWorkers,1,numWorkers));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,numWorkers-remainder)]
workerMiniBatchSize = 128

このネットワークには、ネットワークに学習させているデータの平均と分散を追跡するバッチ正規化層が含まれています。各ワーカーは各反復中に各ミニバッチの一部を処理するため、平均と分散はすべてのワーカーにわたって集計しなければなりません。ネットワーク内のすべてのバッチ正規化層の名前を検索します。

layers = net.Layers;
batchNormLayersNames = string.empty;

for idx = 1:numel(layers)
    currentLayer = layers(idx);
    if isa(currentLayer,"nnet.cnn.layer.BatchNormalizationLayer")
        batchNormLayersNames(end+1) = currentLayer.Name;
    end
end

ネットワークの状態プロパティで、バッチ正規化層の平均と分散の状態パラメーターについて、インデックスを検索します。

state = net.State;
isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean";
isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";

TrainingProgressMonitor オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics=["TrainingLoss" "TrainingAccuracy"], ...
    Info=["Epoch" "Workers"], ...
    XLabel="Iteration");

[停止] ボタンがクリックされたときに学習を停止するフラグを送信するための Dataqueue オブジェクトをワーカーに作成します。

spmd
    stopTrainingEventQueue = parallel.pool.DataQueue;
end
stopTrainingQueue = stopTrainingEventQueue{1};

学習中にワーカーからデータを返すため、DataQueue オブジェクトを作成します。afterEach を使用して関数 displayTrainingProgress を設定し、ワーカーがデータを送信するたびに呼び出されるようにします。displayTrainingProgress は、TrainingProgressMonitor オブジェクトを更新して表示し、ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンがクリックされた場合にワーカーにフラグを送信するサポート関数 (この例の最後で定義) です。

dataQueue = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,numEpochs,numWorkers,monitor,stopTrainingQueue);
afterEach(dataQueue,displayFcn)

次の手順で詳しく説明するように、カスタム並列学習ループを使用してモデルに学習させます。すべてのワーカーで同時にコードを実行するには、spmd ブロックを使用します。spmd ブロック内の spmdIndex により、現在コードを実行しているワーカーのインデックスが与えられます。

学習前に、関数 partition を使用して各ワーカーのデータストアを分割します。分割されたデータストアを使用して、各ワーカーに minibatchqueue を作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例の最後に定義) を使用して、データを正規化し、ターゲット クラスを one-hot 符号化変数に変換し、ミニバッチ内の観測値の数を判定します。

  • イメージ データを次元ラベル "SSCB" (spatial、spatial、channel、batch) で形式を整えます。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。ターゲット クラスや観測数には形式を追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。

  • 関数 dlaccelerate を使用して、モデル損失関数を高速化します。深層学習関数の高速化の詳細については、Deep Learning Function Acceleration for Custom Training Loopsを参照してください。

関数 shuffle を使用し、エポックごとにデータストアをシャッフルします。エポック内のそれぞれの反復で次を行います。

  • 並列処理を開始する前に、すべてのワーカーに利用可能なデータがあることを確認します。関数 continueEpoch は、この例の最後で定義されているサポート関数であり、関数 spmdCat を使用してワーカー間で関数 hasdata の結果を連結することによって、すべてのワーカーが利用可能なデータを確実にもつようにします。

  • [停止] ボタンがクリックされたかどうかをチェックします。関数 continueEpoch は、stopTrainingEveneQueue の長さを監視することでこれをチェックします。

  • 関数 next を使用して、minibatchqueue からミニバッチを読み取ります。

  • 関数 modelLossdlfeval を呼び出すことによって、各ワーカーのネットワークの損失と勾配を計算します。関数 dlfeval は、自動微分を有効にして補助関数 modelLoss を評価し、modelLoss が損失の勾配を自動的に計算できるようにします。modelLoss (この例の最後で定義) は、ネットワーク、データのミニバッチ、およびターゲットを受け取り、損失と勾配を返します。

  • 全体的な損失を取得するために、すべてのワーカーの損失を集計します。この例では、損失関数のクロスエントロピーを使用します。集計された損失はすべての損失の合計です。集計する前に、ミニバッチ全体のうちワーカーが処理している割合で乗算し、各損失を正規化します。spmdPlus を使用してすべての損失を加算し、ワーカー全体にその結果を複製します。

  • すべてのワーカーの勾配を集計および更新するために、関数 aggregateAllGradients を使用します。aggregateAllGradients はこの例の最後で定義されているサポート関数です。この関数は、ミニバッチ全体のうち各ワーカーが処理している割合に基づいて正規化した後、spmdPlus を使用し、勾配を加算してワーカー全体に複製します。

  • 関数 aggregateState を使用して、すべてのワーカーのネットワークの状態を集約します。aggregateState は、この例の最後で定義されているサポート関数です。ネットワークのバッチ正規化層がデータの平均と分散を追跡します。ミニバッチ全体が複数のワーカーに分散されているため、各反復の後にネットワークの状態を集計し、ミニバッチ全体の平均と分散を計算します。

  • 最終勾配を計算した後、関数 sgdmupdate を使用し、ネットワークの学習可能なパラメーターを更新します。

  • Dataqueue オブジェクトと関数 send を使用して学習の進行状況の情報をクライアントに送り返します。すべてのワーカーが同じ損失情報をもっているため、必要なのは 1 つのワーカーを使用してデータを送り返すことだけです。データが確実に CPU 上にあり、GPU を搭載していないクライアント マシンがデータにアクセスできるようにするには、データをクライアントに送信する前に、dlarray に対して gather を使用します。

spmd
    % Partition the datastore.
    workerImds = partition(augimdsTrain,numWorkers,spmdIndex);

    % Create minibatchqueue using partitioned datastore on each worker.
    workerMbq = minibatchqueue(workerImds,3,...
        MiniBatchSize=workerMiniBatchSize(spmdIndex),...
        MiniBatchFcn=@preprocessMiniBatch,...
        MiniBatchFormat=["SSCB" "" ""]);

    % Use dlaccelerate on the modelLoss
    accModelLoss = dlaccelerate(@modelLoss);

    workerVelocity = velocity;
    epoch = 0;
    iteration = 0;

    while epoch < numEpochs
        epoch = epoch + 1;
        shuffle(workerMbq);

        % Loop over mini-batches.
        while continueEpoch(workerMbq,stopTrainingEventQueue)
            iteration = iteration + 1;

            % Read a mini-batch of data.
            [workerX,workerT,workerNumObservations] = next(workerMbq);

            % Evaluate the model loss and gradients on the worker.
            [workerLoss,workerGradients,workerState] = dlfeval(accModelLoss,net,workerX,workerT);

            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize;
            loss = spmdPlus(workerNormalizationFactor*extractdata(workerLoss));

            % Aggregate the network state on all workers.
            net.State = aggregateState(workerState,workerNormalizationFactor,...
                isBatchNormalizationStateMean,isBatchNormalizationStateVariance);

            % Aggregate the gradients on all workers.
            workerGradients.Value = aggregateAllGradients(workerGradients.Value,workerNormalizationFactor);
           
            % Update the network parameters using the SGDM optimizer.
            [net,workerVelocity] = sgdmupdate(net,workerGradients,workerVelocity);

            % Calculate the training accuracy and send training progress information to the client.
            if spmdIndex == 1
                scores = predict(net,workerX);
                labels = scores2label(workerT,classes);
                Y = scores2label(scores,classes);
                accuracy = mean(Y==labels);

                data = [epoch loss accuracy iteration];
                send(dataQueue,gather(data));
            end
        end

    end

end

モデルのテスト

学習が完了すると、各ワーカーがもつ完全な学習済みネットワークはすべて同じになります。それらのいずれかを取得します。

netFinal = net{1};

ネットワークに学習させた後、その精度をテストします。

テスト イメージを分類します。複数の観測値を使用して予測を行うには、関数 minibatchpredict を使用します。予測スコアをラベルに変換するには、関数 scores2label を使用します。関数 minibatchpredict は利用可能な GPU がある場合に自動的にそれを使用します。そうでない場合、関数は CPU を使用します。

labels = imdsTest.Labels;
imdsTestResized = transform(imdsTest,@(X) {imresize(X,inputSize(1:2))});
X = readall(imdsTestResized);
X = cat(4,X{:});
X = single(X) ./ 255;

scores = minibatchpredict(netFinal,X);
Y = scores2label(scores,classes);

ネットワークの精度を計算します。

accuracy = mean(Y==labels)
accuracy = 0.6649

ミニ バッチ前処理関数

関数 preprocessMiniBatch は、次の手順を使用して予測子とターゲット クラスのミニバッチを前処理します。

  1. ミニバッチ内の観測数を判定します。

  2. 関数 preprocessMiniBatchPredictors を使用してイメージを前処理します。

  3. 入力 cell 配列からターゲット クラスのデータを抽出し、2 番目の次元に沿って categorical 配列に連結します。

  4. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,Y,numObs] = preprocessMiniBatch(XCell,YCell)

numObs = numel(YCell);

% Preprocess predictors.
X = preprocessMiniBatchPredictors(XCell);

% Extract class data from cell and concatenate.
Y = cat(2,YCell{1:end});

% One-hot encode classes.
Y = onehotencode(Y,1);

end

ミニバッチ予測子前処理関数

関数 preprocessMiniBatchPredictors は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。その後、イメージが正規化されます。

function X = preprocessMiniBatchPredictors(XCell)

X = cat(4,XCell{:});

X = single(X) ./ 255;

end

モデル損失関数

関数 modelLoss は、ネットワークの学習可能なパラメーターに関する損失の勾配を計算します。この関数は、forward を使用してミニバッチ X に対するネットワークの出力を計算し、クロス エントロピーを使用して、ターゲット T が与えられたときの損失を計算します。dlfeval と共にこの関数を呼び出すと、自動微分が有効になり、dlgradient は学習可能なパラメーターについての損失の勾配を自動的に計算できます。

function [loss,gradients,state] = modelLoss(net,X,T)

[Y,state] = forward(net,X);

loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);

end

学習の進行状況を表示する関数

関数 displayTrainingProgress は、ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンがクリックされたかどうかをチェックします。[停止] ボタンがクリックされると、学習を停止する必要があることを示すフラグがワーカーに送信されます。この例では、ワーカーがデータを送信するたびに DataQueue によってこの関数が呼び出されます。

function displayTrainingProgress(data,numEpochs,numWorkers,monitor,stopTrainingQueue)

% Extract training information from array.
epoch = data(1);
loss = data(2);
accuracy = data(3);
iteration = data(4);

% Update training progress monitor.
recordMetrics(monitor,iteration,TrainingLoss=loss,TrainingAccuracy=accuracy);
updateInfo(monitor,Epoch=epoch + " of " + numEpochs, Workers= numWorkers);
monitor.Progress = 100 * epoch/numEpochs;

% Send flag if the Stop button is clicked.
if monitor.Stop
    send(stopTrainingQueue,true);
end

end

勾配集計関数

関数 aggregateAllGradients は、すべてのワーカーの勾配を加算して集計します。spmdPlus は、ワーカー上ですべての勾配を加算して複製します。加算する前に、ミニバッチ全体のうちワーカーが処理している割合を表す係数を勾配に乗算し、それらを正規化します。dlarray の内容を取得するには、extractdata を使用します。

function gradients = aggregateAllGradients(gradients,normalizationFactor)

% Inspect array of gradients and create cell array for storing aggregated
% data.
numArrays = numel(gradients);
aggregationData = cell(numArrays);
arraySizes = cell(numArrays,1);
numElements = zeros(numArrays,1);

% Extract the data from all the arrays
for idxArray = 1:numArrays
    data = gradients{idxArray};
    gradients{idxArray} = [];

    % Extract data from dlarray.
    data = extractdata(data);

    % Save the size of the array.
    arraySizes{idxArray} = size(data);
    numElements(idxArray) = numel(data);

    % Flatten the array to prepare for concatenation.
    aggregationData{idxArray} = data(:);
end

% Concatenate all arrays.
aggregationData = cat(1,aggregationData{:});

% Aggregate the data from the workers.
aggregationData = spmdPlus(normalizationFactor.*aggregationData);

% Reconstruct the gradient arrays.
i = 1;
for idxArray = 1:numArrays
    n = numElements(idxArray);
    if n > 0
        % Reshape the flattened data to the original size.
        data = reshape(aggregationData(i:(i+n-1)),arraySizes{idxArray});
        % Reinsert the aggregated data as a dlarray.
        gradients{idxArray} = dlarray(data);
        i = i + n;
    end
end
end

状態集計関数

関数 aggregateState は、すべてのワーカーでネットワークの状態を集計します。このネットワークの状態には、データ セットの学習済みバッチ正規化統計量が含まれます。各ワーカーが処理するのはミニバッチの一部のみなので、すべてのデータの統計を表すように、ネットワークの状態を集計します。ミニバッチごとに、統合平均が、各反復のワーカー全体の平均に対する加重平均として計算されます。統合分散は、次の式に従って計算されます。

sc2=1Mj=1Nmj[sj2+(xj-xc)2]

ここで、N はワーカーの合計数、M はミニバッチの観測値の合計数、mjj 番目のワーカーで処理された観測値の数、xjsj2 はそのワーカーで計算された平均と分散の統計、xc はすべてのワーカー全体の統合平均です。

function state = aggregateState(state,normalizationFactor,...
    isBatchNormalizationStateMean,isBatchNormalizationStateVariance)

stateMeans = state.Value(isBatchNormalizationStateMean);
stateVariances = state.Value(isBatchNormalizationStateVariance);

for j = 1:numel(stateMeans)
    meanVal = stateMeans{j};
    varVal = stateVariances{j};

    % Calculate combined mean.
    combinedMean = spmdPlus(normalizationFactor*meanVal);

    % Calculate combined variance terms to sum.
    varTerm = normalizationFactor.*(varVal + (meanVal - combinedMean).^2);

    % Update state.
    stateMeans{j} = combinedMean;
    stateVariances{j} = spmdPlus(varTerm);
end

state.Value(isBatchNormalizationStateMean) = stateMeans;
state.Value(isBatchNormalizationStateVariance) = stateVariances;

end

エポック継続関数

関数 continueEpoch は、各ワーカーのミニバッチ キューに残っているデータがあるかどうかをチェックし、[停止] ボタンが押されたかどうかをチェックします。

function tf = continueEpoch(workerMbq,stopTrainingEventQueue)

% Create a struct that will be concatenated across the workers.
info.HasData = hasdata(workerMbq);
info.StopRequested = stopTrainingEventQueue.QueueLength > 0;

% Use spmdCat to aggregate the info from all the workers.
info = spmdCat(info);

% Continue training if all the workers have data, and if we were not asked to stop.
stopRequest = any([info.StopRequested]);
tf = ~stopRequest && all([info.HasData]);

end

参考文献

参考

| | | | | | | | |

関連するトピック