Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

カスタム学習ループを使用したネットワークの並列学習

この例では、ネットワークに並列学習させるためのカスタム学習ループを設定する方法を説明します。この例では、並列ワーカーによりミニバッチの一部で学習が実行されます。GPU がある場合、GPU 上で学習が行われます。学習中、DataQueue オブジェクトによって、学習の進行状況の情報が MATLAB クライアントに送り返されます。

データセットの読み込み

数字のデータセットを読み込み、このデータセットのイメージ データストアを作成します。データストアを学習データストアとテスト データストアにランダムに分割します。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

学習セットに含まれる異なるクラスを判別します。

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

ネットワークの定義

ネットワーク アーキテクチャを定義し、関数 layerGraph を使用してこれを層グラフにします。このネットワーク アーキテクチャには、バッチ正規化層が含まれています。この層は、データセットの統計量である平均と分散を追跡します。並列学習の場合、各反復ステップの最後にすべてのワーカーからの統計量を結合して、ネットワークの状態が必ずミニバッチ全体を反映するようにします。そうでない場合、ネットワークの状態がワーカー間で異なる可能性があります。たとえば、ステートフル再帰型ニューラル ネットワーク (RNN) の学習において、小さいシーケンスに分割されたシーケンス データを使用して LSTM 層または GRU 層を含むネットワークに学習させる場合、ワーカー間の状態の管理もしなければなりません。

layers = [
    imageInputLayer([28 28 1],'Name','input','Normalization','none')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];

lgraph = layerGraph(layers);

層グラフから dlnetwork オブジェクトを作成します。dlnetwork オブジェクトにより、カスタム ループを使用した学習が可能になります。

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [11×1 nnet.cnn.layer.Layer]
    Connections: [10×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'fc'}

並列環境の設定

関数 canUseGPU を使用し、MATLAB で GPU が使用可能かどうかを判定します。

  • 使用できる GPU がある場合、GPU 上で学習を実行。GPU と同じ数のワーカーを使用して並列プールを作成。

  • 使用できる GPU がない場合、CPU 上で学習を実行。既定の数のワーカーを使用して並列プールを作成。

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount;
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 2).

並列プール内のワーカー数を取得します。この例では後ほど、この数に基づいて作業負荷を分割します。

N = pool.NumWorkers;

学習中にワーカーからデータを返すため、DataQueue オブジェクトを作成します。afterEach を使用して関数 displayTrainingProgress を設定し、ワーカーがデータを送信するたびに呼び出されるようにします。displayTrainingProgress は、ワーカーから送信される学習の進行状況の情報を表示するサポート関数です。この関数の定義は、この例の終わりで行います。

Q = parallel.pool.DataQueue;
afterEach(Q,@displayTrainingProgress);

モデルの学習

学習オプションを指定します。

numEpochs = 20;
miniBatchSize = 128;
velocity = [];

GPU を使用した学習では、GPU の数でミニバッチ サイズを線形にスケールアップし、各 GPU における作業負荷を一定に保つことを推奨します。関連するアドバイスの詳細については、複数の GPU による学習を参照してください。

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N
end
miniBatchSize = 256

ミニバッチ全体のサイズをワーカー間で均等に配分し、各ワーカーのミニバッチ サイズを計算します。余りは最初のワーカー間で分配します。

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
workerMiniBatchSize = 1×2

   128   128

次の手順で詳しく説明するように、カスタム並列学習ループを使用してモデルに学習させます。すべてのワーカーで同時にコードを実行するには、spmd ブロックを使用します。spmd ブロック内の labindex により、現在コードを実行しているワーカーのインデックスが与えられます。

学習前に、関数 partition を使用して各ワーカーのデータストアを分割し、ワーカーのバッチサイズを ReadSize に設定します。

関数 reset と関数 shuffle を使用し、エポックごとにデータストアをリセットしてシャッフルします。

エポック内のそれぞれの反復で次を行います。

  • データの並列処理を開始する前に、グローバルな and 演算 (gop) を関数 hasdata の結果に対して実行し、すべてのワーカーに利用可能なデータがあることを確認。

  • 関数 read を使用してデータストアからミニバッチを読み取り、取得したイメージをイメージの 4 次元配列に連結。ピクセルの値が 0 から 1 の間となるようにイメージを正規化。

  • ラベルを変換し、観測値にラベルを付けるダミー変数の行列にする。観測値のラベルに対してはダミー変数に 1、それ以外の場合は 0 を格納します。

  • データのミニバッチを、基となる型が single の dlarray オブジェクトに変換し、次元ラベル 'SSCB' (spatial、spatial、channel、batch) を指定。GPU で学習する場合、データを gpuArray に変換。

  • 関数 modelGradientsdlfeval を呼び出すことによって、各ワーカーのネットワークの勾配と損失を計算。関数 dlfeval は、自動微分を有効にして補助関数 modelGradients を評価し、modelGradients が損失の勾配を自動的に計算できるようにします。modelGradients (この例の最後に定義) は、ネットワーク、データのミニバッチ、真のラベルを受け取り、損失と勾配を返します。

  • 全体的な損失を取得するために、すべてのワーカーの損失を集計。この例では、損失関数の交差エントロピーを使用します。集計された損失はすべての損失の合計です。集計する前に、ミニバッチ全体のうちワーカーが処理している割合で乗算し、各損失を正規化します。gplus を使用してすべての損失を加算し、ワーカー全体にその結果を複製します。

  • すべてのワーカーの勾配を集計および更新するために、関数 aggregateGradients で関数 dlupdate を使用。aggregateGradients はこの例の終わりで定義するサポート関数です。この関数は、ミニバッチ全体のうち各ワーカーが処理している割合に基づいて正規化した後、gplus を使用し、勾配を加算してワーカー全体に複製します。

  • ネットワークの状態を集計および更新するために、関数 aggregateState で関数 dlupdate を使用。aggregateState はこの例の終わりで定義するサポート関数です。この関数は、ミニバッチ全体のうち各ワーカーが処理している割合に基づいて正規化した後、gplus を使用し、状態を加算してワーカー全体に複製します。

  • すべてのワーカーでネットワークの状態を集計。ネットワークのバッチ正規化層がデータの平均と分散を追跡します。ミニバッチ全体が複数のワーカーに分散されているため、各反復の後にネットワークの状態を集計し、ミニバッチ全体の平均と分散を計算します。

  • 最終勾配を計算した後、関数 sgdmupdate を使用し、ネットワークの学習可能なパラメーターを更新。

  • DataQueue と関数 send を使用し、学習の進行状況の情報をクライアントに返送。すべてのワーカーが同じ損失情報をもっているため、ワーカーを 1 つだけ使用してデータを送信します。GPU を搭載していないクライアント マシンがデータにアクセスできるよう、データを送信する前に dlarray に対して gather を使用し、データが確実に CPU 上にあるようにします。

spmd
    % Partition datastore.
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);
    
    workerVelocity = velocity;
   
    iteration = 0;
    
    for epoch = 1:numEpochs
        % Reset and shuffle the datastore.
        reset(workerImds);
        workerImds = shuffle(workerImds);
        
        % Loop over mini-batches.
        while gop(@and,hasdata(workerImds))
            iteration = iteration + 1;
            
            % Read a mini-batch of data.
            [workerXBatch,workerTBatch] = read(workerImds);
            workerXBatch = cat(4,workerXBatch{:});
            workerNumObservations = numel(workerTBatch.Label);

            % Normalize the images.
            workerXBatch =  single(workerXBatch) ./ 255;
            
            % Convert the labels to dummy variables.
            workerY = zeros(numClasses,workerNumObservations,'single');
            for c = 1:numClasses
                workerY(c,workerTBatch.Label==classes(c)) = 1;
            end
            
            % Convert the mini-batch of data to dlarray.
            dlworkerX = dlarray(workerXBatch,'SSCB');
            
            % If training on GPU, then convert data to gpuArray.
            if executionEnvironment == "gpu"
                dlworkerX = gpuArray(dlworkerX);
            end
            
            % Evaluate the model gradients and loss on the worker.
            [workerGradients,dlworkerLoss,workerState] = dlfeval(@modelGradients,dlnet,dlworkerX,workerY);
            
            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
            loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));
            
            % Aggregate the network state on all workers
            workerState.Value = dlupdate(@aggregateState,workerState.Value,{workerNormalizationFactor});
            dlnet.State = workerState;
            
            % Aggregate the gradients on all workers.
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
            
            % Update the network parameters using the SGDM optimizer.
            [dlnet.Learnables,workerVelocity] = sgdmupdate(dlnet.Learnables,workerGradients,workerVelocity);
        end
        
       % Display training progress information.
       if labindex == 1
           data = [epoch loss];
           send(Q,gather(data)); 
       end
    end
end

モデルのテスト

ネットワークに学習させた後、その精度をテストできます。

readall を使用してテスト データストアにあるテスト イメージをメモリに読み込み、それらを連結して正規化します。

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

学習が完了すると、各ワーカーがもつ完全な学習済みネットワークはすべて同じになります。それらのいずれかを取得します。

dlnetFinal = dlnet{1};

dlnetwork オブジェクトを使用してイメージを分類するには、dlarray に対して関数 predict を使用します。

dlYPredScores = predict(dlnetFinal,dlarray(XTest,'SSCB'));

関数 max を使用し、予測スコアからスコアが最も高いクラスを見つけます。これを行う前に、関数 extractdata を使用して dlarray からデータを抽出します。

[~,idx] = max(extractdata(dlYPredScores),[],1);
YPred = classes(idx);

モデルの分類精度を取得するには、テスト セットにおける予測を真のラベルと比較します。

accuracy = mean(YPred==YTest)
accuracy = 0.9990

補助関数の定義

ネットワークの学習可能なパラメーターについて損失の勾配を計算する関数 modelGradients を定義します。この関数は、forwardsoftmax を使用してミニバッチ X に対するネットワークの出力を計算します。これには真の出力を指定し、交差エントロピーを使用します。dlfeval と共にこの関数を呼び出すと、自動微分が有効になり、dlgradient は学習可能なパラメーターについての損失の勾配を自動的に計算できます。

function [dlgradients,dlloss,state] = modelGradients(dlnet,dlX,dlY)
    [dlYPred,state] = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    dlloss = crossentropy(dlYPred,dlY);
    dlgradients = dlgradient(dlloss,dlnet.Learnables);
end

ワーカーから送信される学習の進行状況の情報を表示する関数を定義します。この例では、ワーカーがデータを送信するたびに DataQueue によってこの関数が呼び出されます。

function displayTrainingProgress (data)
    disp("Epoch: " + data(1) + ", Loss: " + data(2));
end

すべてのワーカーの勾配を加算して集計する関数を定義します。gplus は、ワーカー上ですべての勾配を加算して複製します。加算する前に、ミニバッチ全体のうちワーカーが処理している割合を表す係数を勾配に乗算し、それらを正規化します。dlarray の内容を取得するには、extractdata を使用します。

function gradients = aggregateGradients(dlgradients,factor)
    gradients = extractdata(dlgradients);
    gradients = gplus(factor*gradients);
end

すべてのワーカーでネットワークの状態を集計する関数を定義します。このネットワークの状態には、データセットのバッチ正規化統計量が含まれます。これは、学習の反復全体における平均と分散の加重平均として計算されます。各ワーカーが処理するのはミニバッチの一部のみなので、すべてのデータの統計を表すように、ネットワークの状態を集計します。

function state = aggregateState(state, factor)
    state = gplus(factor*state);
end

参考

| | | | | | | | |

関連するトピック