Main Content

このページは前リリースの情報です。該当の英語のページはこのリリースで削除されています。

カスタム トレーニング ループを使用した並列でのネットワークの学習

この例では、カスタム トレーニング ループを設定して並列でネットワークに学習させる方法を説明します。この例では、並列ワーカーはミニバッチ全体の各部分で学習させます。GPU がある場合は、学習は GPU で実行されます。学習中、DataQueue オブジェクトが学習の進行状況の情報を MATLAB クライアントに送り返します。

データセットの読み込み

数字データセットを読み込み、そのデータセットのイメージ データ ストアを作成します。このデータ ストアを学習データ ストアとテスト データ ストアにランダムに分割します。

digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

学習セットの各種のクラスを決定します。

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

ネットワークの定義

ネットワーク アーキテクチャを定義し、関数 layerGraph を使用してこれを層グラフに変換します。このネットワーク アーキテクチャには、データセットの平均と分散の統計値を追跡するバッチ正規化層が含まれます。並列で学習させる場合は、各反復ステップの最後にすべてのワーカーの統計値を組み合わせて、ネットワークの状態がミニバッチ全体を確実に反映するようにします。そうしない場合、ネットワークの状態がワーカー間で分かれることがあります。たとえば、小さいシーケンスに分割したシーケンス データを使用して LSTM 層または GRU 層を含むネットワークに学習させる場合など、ステートフル再帰型ニューラル ネットワーク (RNN) に学習させる場合、ワーカー間の状態も管理しなければなりません。

layers = [
    imageInputLayer([28 28 1],'Name','input','Normalization','none')
    convolution2dLayer(5,20,'Name','conv1')
    batchNormalizationLayer('Name','bn1')
    reluLayer('Name','relu1')
    convolution2dLayer(3,20,'Padding',1,'Name','conv2')
    batchNormalizationLayer('Name','bn2')
    reluLayer('Name','relu2')
    convolution2dLayer(3,20,'Padding',1,'Name','conv3')
    batchNormalizationLayer('Name','bn3')
    reluLayer('Name','relu3')
    fullyConnectedLayer(numClasses,'Name','fc')];

lgraph = layerGraph(layers);

層グラフから dlnetwork オブジェクトを作成します。dlnetwork オブジェクトを使用すると、カスタム ループによる学習が可能になります。

dlnet = dlnetwork(lgraph)
dlnet = 
  dlnetwork with properties:

         Layers: [11×1 nnet.cnn.layer.Layer]
    Connections: [10×2 table]
     Learnables: [14×3 table]
          State: [6×3 table]
     InputNames: {'input'}
    OutputNames: {'fc'}
    Initialized: 1

並列環境の設定

MATLAB で GPU が使用可能であるかどうかを、関数 canUseGPU で判別します。

  • 使用可能な GPU がある場合は、その GPU で学習します。GPU と同じ数のワーカーをもつ並列プールを作成します。

  • 使用可能な GPU がない場合は、CPU で学習します。既定の数のワーカーをもつ並列プールを作成します。

if canUseGPU
    executionEnvironment = "gpu";
    numberOfGPUs = gpuDeviceCount("available");
    pool = parpool(numberOfGPUs);
else
    executionEnvironment = "cpu";
    pool = parpool;
end

並列プール内のワーカー数を取得します。この例の後の方で、この数に従って作業負荷を分割します。

N = pool.NumWorkers;

モデルを学習させる

学習オプションを指定します。

numEpochs = 20;
miniBatchSize = 128;
velocity = [];

GPU の学習の場合に推奨される方法は、ミニバッチのサイズを GPU 数に比例してスケール アップし、各 GPU の作業負荷を一定に保つことです。関連する情報については、複数の GPU による学習 (Deep Learning Toolbox)を参照してください。

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N
end
miniBatchSize = 512

ミニバッチ全体のサイズをワーカー間で均等に分割して、ワーカーごとのミニバッチのサイズを計算します。剰余は最初のワーカー間で分割します。

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
workerMiniBatchSize = 1×4

   128   128   128   128

学習の進行状況プロットを初期化します。

% Set up the training plot
figure
lineLossTrain = animatedline('Color',[0.85 0.325 0.098]);
ylim([0 inf])
xlabel("Iteration")
ylabel("Loss")
grid on

学習中にデータをワーカーから送り返すには、DataQueue オブジェクトを作成します。afterEach を使用して、ワーカーがデータを送信するたびに関数 displayTrainingProgress を呼び出すよう設定します。displayTrainingProgress は補助関数で、この例の最後に定義します。この関数はワーカーから送信された学習の進行状況の情報を表示します。

Q = parallel.pool.DataQueue;
displayFcn = @(x) displayTrainingProgress(x,lineLossTrain);
afterEach(Q,displayFcn);

カスタム並列学習ループを使用してモデルを学習させます。詳細はこの後の手順で説明します。すべてのワーカーでコードを同時に実行するには、spmd ブロックを使用します。spmd ブロック内で、labindex は現在コードを実行しているワーカーのインデックスを示します。

学習前に、関数 partition を使用して各ワーカーのデータ ストアを分割し、ReadSize をワーカーのミニバッチのサイズに設定します。

エポックごとに、関数 reset および shuffle を使用して、データ ストアをリセットおよびシャッフルします。エポック内の反復ごとに、次を行います。

  • 関数 hasdata の結果に対してグローバルな and 演算 (gop) を実行し、データの並列処理を開始する前にすべてのワーカーでデータが使用可能であることを確認します。

  • 関数 read を使用してデータ ストアからミニバッチを読み取り、取得したイメージを連結してイメージの 4 次元配列を生成します。イメージを正規化して、ピクセルの値が 01 の間になるようにします。

  • ラベルをダミー変数の行列に変換します。これにより、観測値に対してラベルが設定されます。ダミー変数には、観測値のラベルについては 1、それ以外については 0 が格納されます。

  • データのミニバッチを、基となる型が single の dlarray オブジェクトに変換し、次元ラベル 'SSCB' (空間、空間、チャネル、バッチ) を指定します。GPU 学習の場合、データを gpuArray に変換します。

  • 関数 modelGradientsdlfeval を呼び出し、各ワーカーの勾配とネットワークの損失を計算します。関数 dlfeval は、自動微分を有効にして補助関数 modelGradients を評価します。これにより、modelGradients は損失に対する勾配を自動的に計算します。modelGradients はこの例の最後で定義され、指定されたネットワーク、データのミニバッチ、真のラベルに基づいて損失と勾配を返します。

  • 全体の損失を取得するには、すべてのワーカーの損失を集計します。この例では、損失関数のクロス エントロピーを使用しています。損失の集計値は、すべての損失の合計です。集計の前に、ワーカーが作業しているミニバッチ全体の比率を乗じて、各損失を正規化します。gplus を使用して、すべての損失を加算し、その結果をワーカー間で複製します。

  • すべてのワーカーの勾配を集計して更新するには、関数 aggregateGradients と関数 dlupdate を使用します。aggregateGradients はこの例の最後で定義されるサポート関数です。この関数は、各ワーカーが作業しているミニバッチ全体の比率に従って正規化してから、gplus を使用してワーカー間で勾配を加算および複製します。

  • 関数 aggregateState を使用して、すべてのワーカーについてネットワークの状態を集計します。aggregateState は、この例の最後で定義されているサポート関数です。ネットワークのバッチ正規化層は、データの平均と分散を追跡します。ミニバッチの全体が複数のワーカー間に分散されているため、各反復後にネットワークの状態を集計して、ミニバッチ全体の平均と分散を計算します。

  • 最終的な勾配を計算した後、関数 sgdmupdate によりネットワークの学習可能なパラメーターを更新します。

  • DataQueue を指定して関数 send を使用し、学習の進行状況の情報をクライアントに送り返します。データの送信には 1 つのワーカーのみを使用してください。すべてのワーカーで損失情報は同じです。GPU がないクライアント マシンでもデータにアクセスできるようにするために、データが CPU にあることを確認するには、データを送信する前に dlarraygather を使用します。

start = tic;
spmd
    % Partition datastore.
    workerImds = partition(imdsTrain,N,labindex);
    workerImds.ReadSize = workerMiniBatchSize(labindex);
    
    workerVelocity = velocity;
   
    iteration = 0;
    
    for epoch = 1:numEpochs
        % Reset and shuffle the datastore.
        reset(workerImds);
        workerImds = shuffle(workerImds);
        
        % Loop over mini-batches.
        while gop(@and,hasdata(workerImds))
            iteration = iteration + 1;
            
            % Read a mini-batch of data.
            [workerXBatch,workerTBatch] = read(workerImds);
            workerXBatch = cat(4,workerXBatch{:});
            workerNumObservations = numel(workerTBatch.Label);

            % Normalize the images.
            workerXBatch =  single(workerXBatch) ./ 255;
            
            % Convert the labels to dummy variables.
            workerY = zeros(numClasses,workerNumObservations,'single');
            for c = 1:numClasses
                workerY(c,workerTBatch.Label==classes(c)) = 1;
            end
            
            % Convert the mini-batch of data to dlarray.
            dlworkerX = dlarray(workerXBatch,'SSCB');
            
            % If training on GPU, then convert data to gpuArray.
            if executionEnvironment == "gpu"
                dlworkerX = gpuArray(dlworkerX);
            end
            
            % Evaluate the model gradients and loss on the worker.
            [workerGradients,dlworkerLoss,workerState] = dlfeval(@modelGradients,dlnet,dlworkerX,workerY);
            
            % Aggregate the losses on all workers.
            workerNormalizationFactor = workerMiniBatchSize(labindex)./miniBatchSize;
            loss = gplus(workerNormalizationFactor*extractdata(dlworkerLoss));
            
            % Aggregate the network state on all workers
            dlnet.State = aggregateState(workerState,workerNormalizationFactor);
            
            % Aggregate the gradients on all workers.
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
            
            % Update the network parameters using the SGDM optimizer.
            [dlnet.Learnables,workerVelocity] = sgdmupdate(dlnet.Learnables,workerGradients,workerVelocity);
        end
        
       % Display training progress information.
       if labindex == 1
           data = [epoch loss iteration toc(start)];
           send(Q,gather(data)); 
       end
    end
end

モデルのテスト

ネットワークを学習させた後、その精度をテストできます。

テスト データ ストアで readall を使用してテスト イメージをメモリに読み込み、連結してから、正規化します。

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
YTest = imdsTest.Labels;

学習が完了すると、すべてのワーカーが同一の完全な学習済みネットワークをもつことになります。そのいずれかを取得します。

dlnetFinal = dlnet{1};

dlnetwork オブジェクトを使用してイメージを分類するには、dlarray で関数 predict を使用します。

dlYPredScores = predict(dlnetFinal,dlarray(XTest,'SSCB'));

関数 max を使用して、予測されたスコアの中から最高のスコアをもつクラスを検索します。その前に、関数 extractdata を使用して、dlarray からデータを抽出します。

[~,idx] = max(extractdata(dlYPredScores),[],1);
YPred = classes(idx);

モデルの分類精度を取得するには、テスト セットでの予測と真のラベルを比較します。

accuracy = mean(YPred==YTest)
accuracy = 0.9910

勾配のモデル化関数

ネットワークの学習可能なパラメーターに対する損失の勾配を計算する関数 modelGradients を定義します。この関数は、forwardsoftmax を使用してミニバッチ X のネットワーク出力を計算し、指定した真の出力に基づいて、クロス エントロピーを使用して損失を計算します。この関数を dlfeval で呼び出すと、自動微分が有効になり、dlgradient は学習可能に対する損失の勾配を自動的に計算できます。

function [dlgradients,dlloss,state] = modelGradients(dlnet,dlX,dlY)
    [dlYPred,state] = forward(dlnet,dlX);
    dlYPred = softmax(dlYPred);
    
    dlloss = crossentropy(dlYPred,dlY);
    dlgradients = dlgradient(dlloss,dlnet.Learnables);
end

学習の進行状況の表示関数

ワーカーから送信された学習の進行状況の情報を表示する関数を定義します。この例の DataQueue は、ワーカーがデータを送信するたびにこの関数を呼び出します。

function displayTrainingProgress (data,line)
     addpoints(line,double(data(3)),double(data(2)))
     D = duration(0,0,data(4),'Format','hh:mm:ss');
     title("Epoch: " + data(1) + ", Elapsed: " + string(D))
     drawnow
end

勾配の集計関数

すべてのワーカーの勾配を加算して集計する関数を定義します。gplus は、各ワーカーのすべての勾配を加算して複製します。加算する前に、ワーカーが作業しているミニバッチ全体の割合を表す係数を乗じて正規化します。dlarray の内容を取得するには、extractdata を使用します。

function gradients = aggregateGradients(dlgradients,factor)
    gradients = extractdata(dlgradients);
    gradients = gplus(factor*gradients);
end

状態の集計関数

すべてのワーカーにおけるネットワークの状態を集計する関数を定義します。ネットワークの状態には、学習済みデータセットのバッチ正規化統計値が含まれています。各ワーカーはミニバッチの一部のみを認識するため、ネットワークの状態を集計することにより、その統計値がデータ全体の統計を表すようにします。ミニバッチごとに、各反復のワーカーをまたがった平均値の荷重平均として、複合平均値が計算されます。複合分散は以下の式に従って計算されます。

sc2=1Mj=1Nmj[sj2+(xj-xc)2]

ここで、N は合計ワーカー数、M はミニバッチ内の合計観測数、mjj 番目のワーカー上で処理された観測の数、xjsj2 はそのワーカー上で計算された平均と分散の統計、xc はすべてのワーカーにまたがる複合平均値です。

function state = aggregateState(state,factor)

    numrows = size(state,1);
    
    for j = 1:numrows
        isBatchNormalizationState = state.Parameter(j) =="TrainedMean"...
            && state.Parameter(j+1) =="TrainedVariance"...
            && state.Layer(j) == state.Layer(j+1);
        
        if isBatchNormalizationState
            meanVal = state.Value{j};
            varVal = state.Value{j+1};
            
            % Calculate combined mean
            combinedMean = gplus(factor*meanVal);
                   
            % Caclulate combined variance terms to sum
            combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);        
            
            % Update state
            state.Value(j) = {combinedMean};
            state.Value(j+1) = {gplus(combinedVarTerm)};
           
        end
    end
end

参考

| | | | | | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

関連する例

詳細