カスタム学習ループを使用したネットワークの並列学習
この例では、ネットワークに並列学習させるためのカスタム学習ループを設定する方法を説明します。この例では、並列ワーカーによりミニバッチの一部で学習が実行されます。GPU がある場合、GPU 上で学習が行われます。学習中、DataQueue
オブジェクトによって、学習の進行状況の情報が MATLAB クライアントに送り返されます。
データセットの読み込み
数字のデータセットを読み込み、このデータセットのイメージ データストアを作成します。データストアを学習データストアとテスト データストアにランダムに分割します。学習データを格納する augmentedImageDatastore
を作成します。
digitDatasetPath = fullfile(toolboxdir("nnet"),"nndemos", ... "nndatasets","DigitDataset"); imds = imageDatastore(digitDatasetPath, ... IncludeSubfolders=true, ... LabelSource="foldernames"); [imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized"); inputSize = [28 28 1]; augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
学習セットに含まれる異なるクラスを判別します。
classes = categories(imdsTrain.Labels); numClasses = numel(classes);
ネットワークの定義
ネットワーク アーキテクチャを定義します。このネットワーク アーキテクチャには、バッチ正規化層が含まれています。この層は、データセットの統計量である平均と分散を追跡します。並列学習の場合、各反復ステップの最後にすべてのワーカーからの統計量を結合して、ネットワークの状態が必ずミニバッチ全体を反映するようにします。そうでない場合、ネットワークの状態がワーカー間で異なる可能性があります。たとえば、ステートフル再帰型ニューラル ネットワーク (RNN) の学習において、小さいシーケンスに分割されたシーケンス データを使用して LSTM 層または GRU 層を含むネットワークに学習させる場合、ワーカー間の状態の管理もしなければなりません。
layers = [
imageInputLayer(inputSize,Normalization="none")
convolution2dLayer(5,20)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,20,Padding=1)
batchNormalizationLayer
reluLayer
convolution2dLayer(3,20,Padding=1)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
層配列から dlnetwork
オブジェクトを作成します。dlnetwork
オブジェクトにより、カスタム ループを使用した学習が可能になります。
net = dlnetwork(layers)
net = dlnetwork with properties: Layers: [12×1 nnet.cnn.layer.Layer] Connections: [11×2 table] Learnables: [14×3 table] State: [6×3 table] InputNames: {'imageinput'} OutputNames: {'softmax'} Initialized: 1 View summary with summary.
並列環境の設定
関数 canUseGPU
を使用し、MATLAB で GPU が使用可能かどうかを判定します。
使用できる GPU がある場合、GPU 上で学習を実行。GPU と同じ数のワーカーを使用して並列プールを作成。
使用できる GPU がない場合、CPU 上で学習を実行。既定の数のワーカーを使用して並列プールを作成。
if canUseGPU executionEnvironment = "gpu"; numberOfGPUs = gpuDeviceCount("available"); pool = parpool(numberOfGPUs); else executionEnvironment = "cpu"; pool = parpool; end
Starting parallel pool (parpool) using the 'Processes' profile ... Connected to the parallel pool (number of workers: 4).
並列プール内のワーカー数を取得します。この例では後ほど、この数に基づいて作業負荷を分割します。
numWorkers = pool.NumWorkers;
モデルの学習
学習オプションを指定します。
numEpochs = 20; miniBatchSize = 128; velocity = [];
GPU を使用した学習では、GPU の数でミニバッチ サイズを線形にスケールアップし、各 GPU における作業負荷を一定に保つことを推奨します。関連するアドバイスの詳細については、MATLAB による複数の GPU での深層学習を参照してください。
if executionEnvironment == "gpu" miniBatchSize = miniBatchSize .* numWorkers end
miniBatchSize = 512
ミニバッチ全体のサイズをワーカー間で均等に配分し、各ワーカーのミニバッチ サイズを計算します。余りは最初のワーカー間で分配します。
workerMiniBatchSize = floor(miniBatchSize ./ repmat(numWorkers,1,numWorkers)); remainder = miniBatchSize - sum(workerMiniBatchSize); workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,numWorkers-remainder)]
workerMiniBatchSize = 1×4
128 128 128 128
このネットワークには、ネットワークに学習させているデータの平均と分散を追跡するバッチ正規化層が含まれています。各ワーカーは各反復中に各ミニバッチの一部を処理するため、平均と分散はすべてのワーカーにわたって集計しなければなりません。ネットワークの状態プロパティで、バッチ正規化層の平均と分散の状態パラメーターについて、インデックスを検索します。
batchNormLayers = arrayfun(@(l)isa(l,"nnet.cnn.layer.BatchNormalizationLayer"),net.Layers); batchNormLayersNames = string({net.Layers(batchNormLayers).Name}); state = net.State; isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean"; isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";
TrainingProgressMonitor
オブジェクトを初期化します。monitor オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor( ... Metrics="TrainingLoss", ... Info=["Epoch" "Workers"], ... XLabel="Iteration");
ワーカーで Dataqueue
オブジェクトを作成して、[停止] ボタンが押されたときに学習を停止するためのフラグを送信します。
spmd stopTrainingEventQueue = parallel.pool.DataQueue; end stopTrainingQueue = stopTrainingEventQueue{1};
学習中にワーカーからデータを返すため、DataQueue
オブジェクトを作成します。afterEach
を使用して関数 displayTrainingProgress
を設定し、ワーカーがデータを送信するたびに呼び出されるようにします。displayTrainingProgress
は、TrainingProgressMonitor
オブジェクトを更新して表示し、ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンが押された場合にワーカーにフラグを送信するサポート関数 (この例の最後で定義) です。
dataQueue = parallel.pool.DataQueue; displayFcn = @(x) displayTrainingProgress(x,numEpochs,numWorkers,monitor,stopTrainingQueue); afterEach(dataQueue,displayFcn)
次の手順で詳しく説明するように、カスタム並列学習ループを使用してモデルに学習させます。すべてのワーカーで同時にコードを実行するには、spmd
ブロックを使用します。spmd
ブロック内の spmdIndex
により、現在コードを実行しているワーカーのインデックスが与えられます。
学習前に、関数 partition
を使用して各ワーカーのデータストアを分割します。分割されたデータストアを使用して、各ワーカーに minibatchqueue
を作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
(この例の最後に定義) を使用して、データを正規化し、ラベルを one-hot 符号化変数に変換し、ミニバッチ内の観測値の数を判定します。イメージ データを次元ラベル
"SSCB"
(spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。クラス ラベルや観測数には書式を追加しないでください。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、各出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox) (Parallel Computing Toolbox) を参照してください。
関数 reset
と関数 shuffle
を使用し、エポックごとにデータストアをリセットしてシャッフルします。エポック内のそれぞれの反復で次を行います。
データの並列処理を開始する前に、
spmdreduce
を使用してグローバルなand
演算を関数hasdata
の結果に対して実行し、すべてのワーカーに利用可能なデータがあることを確認。関数
next
を使用して、minibatchqueue
からミニバッチを読み取ります。関数
modelLoss
でdlfeval
を呼び出すことによって、各ワーカーのネットワークの損失と勾配を計算。関数dlfeval
は、自動微分を有効にして補助関数modelLoss
を評価し、modelLoss
が損失の勾配を自動的に計算できるようにします。modelLoss
(この例の最後で定義) は、ネットワーク、データのミニバッチ、真のラベルを受け取り、損失と勾配を返します。全体的な損失を取得するために、すべてのワーカーの損失を集計。この例では、損失関数の交差エントロピーを使用します。集計された損失はすべての損失の合計です。集計する前に、ミニバッチ全体のうちワーカーが処理している割合で乗算し、各損失を正規化します。
spmdPlus
を使用してすべての損失を加算し、ワーカー全体にその結果を複製します。すべてのワーカーの勾配を集計および更新するために、関数
aggregateGradients
で関数dlupdate
を使用。aggregateGradients
はこの例の終わりで定義するサポート関数です。この関数は、ミニバッチ全体のうち各ワーカーが処理している割合に基づいて正規化した後、spmdPlus
を使用し、勾配を加算してワーカー全体に複製します。関数
aggregateState
を使用して、すべてのワーカーのネットワークの状態を集約。aggregateState
は、この例の最後で定義されているサポート関数です。ネットワークのバッチ正規化層がデータの平均と分散を追跡します。ミニバッチ全体が複数のワーカーに分散されているため、各反復の後にネットワークの状態を集計し、ミニバッチ全体の平均と分散を計算します。最終勾配を計算した後、関数
sgdmupdate
を使用し、ネットワークの学習可能なパラメーターを更新。
各エポックの後、[停止] ボタンが押されたかどうかを確認し、Dataqueue
オブジェクトと関数 send
を使用して学習の進行状況の情報をクライアントに送り返します。すべてのワーカーが同じ損失情報をもっているため、必要なのは 1 つのワーカーを使用してデータを送り返すことだけです。データが確実に CPU 上にあり、GPU を搭載していないクライアント マシンがデータにアクセスできるようにするには、データをクライアントに送信する前に、dlarray
に対して gather
を使用します。各エポックの後にワーカー間の通信が発生するため、[停止] をクリックして、現在のエポックの最後で学習を停止します。各反復の最後に [停止] ボタンで学習を停止させる場合、[停止] ボタンが押されたかどうかを確認し、反復ごとに学習の進行状況の情報をクライアントに送り返せますが、通信オーバーヘッドが増加します。
spmd % Reset and shuffle the datastore. reset(augimdsTrain); augimdsTrain = shuffle(augimdsTrain); % Partition datastore. workerImds = partition(augimdsTrain,numWorkers,spmdIndex); % Create minibatchqueue using partitioned datastore on each worker workerMbq = minibatchqueue(workerImds,3,... MiniBatchSize=workerMiniBatchSize(spmdIndex),... MiniBatchFcn=@preprocessMiniBatch,... MiniBatchFormat=["SSCB" "" ""]); workerVelocity = velocity; epoch = 0; iteration = 0; stopRequest = false; while epoch < numEpochs && ~stopRequest epoch = epoch + 1; shuffle(workerMbq); % Loop over mini-batches while spmdReduce(@and,hasdata(workerMbq)) && ~stopRequest iteration = iteration + 1; % Read a mini-batch of data [workerX,workerT,workerNumObservations] = next(workerMbq); % Evaluate the model loss and gradients on the worker [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerT); % Aggregate the losses on all workers workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize; loss = spmdPlus(workerNormalizationFactor*extractdata(workerLoss)); % Aggregate the network state on all workers net.State = aggregateState(workerState,workerNormalizationFactor,... isBatchNormalizationStateMean,isBatchNormalizationStateVariance); % Aggregate the gradients on all workers workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor}); % Update the network parameters using the SGDM optimizer [net,workerVelocity] = sgdmupdate(net,workerGradients,workerVelocity); end % Stop training if the Stop button has been clicked stopRequest = spmdPlus(stopTrainingEventQueue.QueueLength); % Send training progress information to the client if spmdIndex == 1 data = [epoch loss iteration]; send(dataQueue,gather(data)); end end end
モデルのテスト
ネットワークに学習させた後、その精度をテストできます。
readall
を使用してテスト データストアにあるテスト イメージをメモリに読み込み、それらを連結して正規化します。
XTest = readall(imdsTest); XTest = cat(4,XTest{:}); XTest = single(XTest) ./ 255; TTest = imdsTest.Labels;
学習が完了すると、各ワーカーがもつ完全な学習済みネットワークはすべて同じになります。それらのいずれかを取得します。
netFinal = net{1};
dlnetwork
オブジェクトを使用してイメージを分類するには、dlarray
に対して関数 predict
を使用します。
YTest = predict(netFinal,dlarray(XTest,"SSCB"));
関数 max
を使用し、予測スコアからスコアが最も高いクラスを見つけます。これを行う前に、関数 extractdata
を使用して dlarray
からデータを抽出します。
[~,idx] = max(extractdata(YTest),[],1); YTest = classes(idx);
モデルの分類精度を取得するには、テスト セットにおける予測を真のラベルと比較します。
accuracy = mean(YTest==TTest)
accuracy = 0.9440
ミニ バッチ前処理関数
関数 preprocessMiniBatch
は、次の手順を使用して予測子とラベルのミニバッチを前処理します。
ミニバッチ内の観測数を判定します。
関数
preprocessMiniBatchPredictors
を使用してイメージを前処理します。入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結します。
カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。
function [X,Y,numObs] = preprocessMiniBatch(XCell,YCell) numObs = numel(YCell); % Preprocess predictors. X = preprocessMiniBatchPredictors(XCell); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1); end
ミニバッチ予測子前処理関数
関数 preprocessMiniBatchPredictors
は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。グレースケール入力では、4 番目の次元で連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されることになります。その後、データが正規化されます。
function X = preprocessMiniBatchPredictors(XCell) % Concatenate. X = cat(4,XCell{1:end}); % Normalize. X = X ./ 255; end
モデル損失関数
ネットワークの学習可能なパラメーターについて損失の勾配を計算する関数 modelLoss
を定義します。この関数は、forward
を使用してミニバッチ X
に対するネットワークの出力を計算し、クロス エントロピーを使用して、本来の出力が与えられたときの損失を計算します。dlfeval
と共にこの関数を呼び出すと、自動微分が有効になり、dlgradient
は学習可能なパラメーターについての損失の勾配を自動的に計算できます。
function [loss,gradients,state] = modelLoss(net,X,T) [Y,state] = forward(net,X); loss = crossentropy(Y,T); gradients = dlgradient(loss,net.Learnables); end
学習の進行状況を表示する関数
ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンがクリックされたかどうかを確認する関数を定義します。[停止] ボタンがクリックされると、学習を停止する必要があることを示すフラグがワーカーに送信されます。この例では、ワーカーがデータを送信するたびに DataQueue
によってこの関数が呼び出されます。
function displayTrainingProgress(data,numEpochs,numWorkers,monitor,stopTrainingQueue) epoch = data(1); loss = data(2); iteration = data(3); recordMetrics(monitor,iteration,TrainingLoss=loss); updateInfo(monitor,Epoch=epoch + " of " + numEpochs, Workers= numWorkers); monitor.Progress = 100 * epoch/numEpochs; if monitor.Stop send(stopTrainingQueue,true); end end
勾配集計関数
すべてのワーカーの勾配を加算して集計する関数を定義します。spmdPlus
は、ワーカー上ですべての勾配を加算して複製します。加算する前に、ミニバッチ全体のうちワーカーが処理している割合を表す係数を勾配に乗算し、それらを正規化します。dlarray
の内容を取得するには、extractdata
を使用します。
function gradients = aggregateGradients(gradients,factor) gradients = extractdata(gradients); gradients = spmdPlus(factor*gradients); end
状態集計関数
すべてのワーカーでネットワークの状態を集計する関数を定義します。このネットワークの状態には、データ セットの学習済みバッチ正規化統計量が含まれます。各ワーカーが処理するのはミニバッチの一部のみなので、すべてのデータの統計を表すように、ネットワークの状態を集計します。ミニバッチごとに、統合平均が、各反復のワーカー全体の平均に対する加重平均として計算されます。統合分散は、次の式に従って計算されます。
ここで、 はワーカーの合計数、 はミニバッチの観測値の合計数、 は 番目のワーカーで処理された観測値の数、 と はそのワーカーで計算された平均と分散の統計、 はすべてのワーカー全体の統合平均です。
function state = aggregateState(state,factor,... isBatchNormalizationStateMean,isBatchNormalizationStateVariance) stateMeans = state.Value(isBatchNormalizationStateMean); stateVariances = state.Value(isBatchNormalizationStateVariance); for j = 1:numel(stateMeans) meanVal = stateMeans{j}; varVal = stateVariances{j}; % Calculate combined mean combinedMean = spmdPlus(factor*meanVal); % Calculate combined variance terms to sum varTerm = factor.*(varVal + (meanVal - combinedMean).^2); % Update state stateMeans{j} = combinedMean; stateVariances{j} = spmdPlus(varTerm); end state.Value(isBatchNormalizationStateMean) = stateMeans; state.Value(isBatchNormalizationStateVariance) = stateVariances; end
参考
dlarray
| dlnetwork
| sgdmupdate
| dlupdate
| dlfeval
| dlgradient
| crossentropy
| softmax
| forward
| predict