カスタム学習ループを使用したネットワークの並列学習
この例では、ネットワークに並列学習させるためのカスタム学習ループを設定する方法を説明します。
深層ニューラル ネットワークの学習には計算コストがかかり、計算に何時間もかかる可能性があります。特に複数の GPU がある場合、学習を高速化するために、ネットワークの並列学習を行うことができます。
この例では、並列ワーカーによりミニバッチの一部で学習が実行されます。GPU がある場合、GPU 上で学習が行われます。学習中、DataQueue オブジェクトによって、学習の進行状況の情報が MATLAB クライアントに送り返されます。
データ セットの読み込み
Flowers データ セット [1] をダウンロードし、解凍します。Flowers データ セットには、5 つのクラス ("デイジー"、"タンポポ"、"バラ"、"ヒマワリ"、および "チューリップ") に属する 3670 個の花のイメージが格納されています。
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); dataFolder = fullfile(downloadFolder,"flower_photos"); if ~exist(dataFolder,"dir") fprintf("Downloading Flowers data set (218 MB)... ") websave(filename,url); untar(filename,downloadFolder) fprintf("Done.\n") end
Downloading Flowers data set (218 MB)...
データ セットのイメージ データストアを作成します。データストアを学習データストアとテスト データストアにランダムに分割します。
imds = imageDatastore(dataFolder, ... IncludeSubfolders=true, ... LabelSource="foldernames"); [imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");
データ拡張は、ネットワークで過適合が発生したり、学習イメージの正確な詳細が記憶されたりすることを防止するのに役立ちます。imageDataAugmenter オブジェクトを使用して、学習用イメージのサイズ変更と拡張を行います。
縦軸を軸としてイメージをランダムに反転させます。
垂直方向および水平方向にイメージを最大 10 ピクセル、ランダムに平行移動します。
時計回りおよび反時計回りにイメージを最大 45 度、ランダムに回転させます。
垂直方向および水平方向にイメージを最大 10%、ランダムに拡大します。
pixelRange = [-10 10]; scaleRange = [0.9 1.1]; imageAugmenter = imageDataAugmenter( ... RandXReflection=true, ... RandXTranslation=pixelRange, ... RandYTranslation=pixelRange, ... RandRotation=[-45 45], ... RandXScale=scaleRange, ... RandYScale=scaleRange); inputSize = [100 100 3]; augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=imageAugmenter);
関数 shuffle を使用してデータをシャッフルします。
augimdsTrain = shuffle(augimdsTrain);
学習データ セットに含まれる異なるクラスを判別します。
classes = categories(imdsTrain.Labels)
classes = 5×1 cell
{'daisy' }
{'dandelion' }
{'roses' }
{'sunflowers'}
{'tulips' }
numClasses = numel(classes)
numClasses = 5
ネットワークの定義
2 次元残差ネットワークを作成します。このネットワーク アーキテクチャには、バッチ正規化層が含まれています。この層は、データ セットの統計量である平均と分散を追跡します。並列学習の場合、各反復ステップの最後にすべてのワーカーからの統計量を結合して、ネットワークの状態が必ずミニバッチ全体を反映するようにします。そうでない場合、ネットワークの状態がワーカー間で異なる可能性があります。
net = resnetNetwork(inputSize,numClasses)
net =
dlnetwork with properties:
Layers: [176×1 nnet.cnn.layer.Layer]
Connections: [191×2 table]
Learnables: [214×3 table]
State: [106×3 table]
InputNames: {'input'}
OutputNames: {'softmax'}
Initialized: 1
View summary with summary.
代わりに再帰型ニューラル ネットワーク (RNN)、たとえば LSTM 層や GRU 層を含むネットワークを使用する場合、再帰層は学習中に変化する状態プロパティをもつことになります。したがって、カスタム学習ループを使用してネットワークに並列学習させる場合は、それらの状態を管理するように注意しなければなりません。ステートレスに (つまり、学習反復間で再帰層の状態を維持することなく) RNN に学習させるには、各学習反復の最後で resetState を呼び出します。ワーカー全体でバッチ正規化の統計値を集計するためにコードを編集する必要はありません。
並列環境の設定
関数 canUseGPU を使用し、MATLAB で GPU が使用可能かどうかを判定します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスの詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
使用できる GPU がある場合、GPU 上で学習を実行。GPU と同じ数のワーカーを使用して並列プールを作成。
使用できる GPU がない場合、CPU 上で学習を実行。既定の数のワーカーを使用して並列プールを作成。
if canUseGPU executionEnvironment = "gpu"; numberOfGPUs = gpuDeviceCount("available"); gpuDeviceTable pool = parpool(numberOfGPUs); else executionEnvironment = "cpu"; pool = parpool; end
ans=4×5 table
Index Name ComputeCapability DeviceAvailable DeviceSelected
_____ __________________ _________________ _______________ ______________
1 "NVIDIA RTX A5000" "8.6" true true
2 "NVIDIA RTX A5000" "8.6" true false
3 "NVIDIA RTX A5000" "8.6" true false
4 "NVIDIA RTX A5000" "8.6" true false
Starting parallel pool (parpool) using the 'Processes' profile ... 22-Oct-2024 05:36:23: Job Queued. Waiting for parallel pool job with ID 1 to start ... 22-Oct-2024 05:37:24: Job Queued. Waiting for parallel pool job with ID 1 to start ... Connected to parallel pool with 4 workers.
並列プール内のワーカー数を取得します。
numWorkers = pool.NumWorkers;
モデルの学習
学習オプションを指定します。
numEpochs = 100; miniBatchSize = 128; velocity = [];
GPU を使用した学習では、GPU の数でミニバッチ サイズを線形にスケールアップし、各 GPU における作業負荷を一定に保つことを推奨します。関連するアドバイスの詳細については、MATLAB による複数の GPU での深層学習を参照してください。
if executionEnvironment == "gpu" miniBatchSize = miniBatchSize .* numWorkers end
miniBatchSize = 512
ミニバッチ全体のサイズをワーカー間で均等に配分し、各ワーカーのミニバッチ サイズを計算します。余りは最初のワーカー間で分配します。
workerMiniBatchSize = floor(miniBatchSize ./ repmat(numWorkers,1,numWorkers)); remainder = miniBatchSize - sum(workerMiniBatchSize); workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,numWorkers-remainder)]
workerMiniBatchSize = 1×4
128 128 128 128
このネットワークには、ネットワークに学習させているデータの平均と分散を追跡するバッチ正規化層が含まれています。各ワーカーは各反復中に各ミニバッチの一部を処理するため、平均と分散はすべてのワーカーにわたって集計しなければなりません。ネットワーク内のすべてのバッチ正規化層の名前を検索します。
layers = net.Layers; batchNormLayersNames = string.empty; for idx = 1:numel(layers) currentLayer = layers(idx); if isa(currentLayer,"nnet.cnn.layer.BatchNormalizationLayer") batchNormLayersNames(end+1) = currentLayer.Name; end end
ネットワークの状態プロパティで、バッチ正規化層の平均と分散の状態パラメーターについて、インデックスを検索します。
state = net.State; isBatchNormalizationStateMean = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedMean"; isBatchNormalizationStateVariance = ismember(state.Layer,batchNormLayersNames) & state.Parameter == "TrainedVariance";
TrainingProgressMonitor オブジェクトを初期化します。監視オブジェクトを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor( ... Metrics=["TrainingLoss" "TrainingAccuracy"], ... Info=["Epoch" "Workers"], ... XLabel="Iteration");

[停止] ボタンがクリックされたときに学習を停止するフラグを送信するための PollableDataqueue オブジェクトを作成します。
stopTrainingQueue = parallel.pool.PollableDataQueue(Destination="any");学習中にワーカーからデータを返すため、DataQueue オブジェクトを作成します。afterEach を使用して関数 displayTrainingProgress を設定し、ワーカーがデータを送信するたびに呼び出されるようにします。displayTrainingProgress は、TrainingProgressMonitor オブジェクトを更新して表示し、ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンがクリックされた場合にワーカーにフラグを送信するサポート関数 (この例の最後で定義) です。
dataQueue = parallel.pool.DataQueue; displayFcn = @(x) displayTrainingProgress(x,numEpochs,numWorkers,monitor,stopTrainingQueue); afterEach(dataQueue,displayFcn)
次の手順で詳しく説明するように、カスタム並列学習ループを使用してモデルに学習させます。すべてのワーカーで同時にコードを実行するには、spmd ブロックを使用します。spmd ブロック内の spmdIndex により、現在コードを実行しているワーカーのインデックスが与えられます。
学習前に、関数 partition を使用して各ワーカーのデータストアを分割します。分割されたデータストアを使用して、各ワーカーに minibatchqueue を作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch(この例の最後に定義) を使用して、データを正規化し、ターゲット クラスを one-hot 符号化変数に変換し、ミニバッチ内の観測値の数を判定します。イメージ データの形式を次元ラベル
"SSCB"(空間、空間、チャネル、バッチ) で整え、ターゲット クラスの形式を次元ラベル "CB" (チャネル、バッチ) で整えます。既定では、minibatchqueueオブジェクトは、基となる型がsingleのdlarrayオブジェクトにデータを変換します。観測数に形式を追加しないでください。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueueオブジェクトは、GPU が利用可能な場合、各出力をgpuArrayに変換します。関数
dlaccelerateを使用して、モデル損失関数を高速化します。深層学習関数の高速化の詳細については、Deep Learning Function Acceleration for Custom Training Loopsを参照してください。
関数 shuffle を使用し、エポックごとにデータストアをシャッフルします。エポック内のそれぞれの反復で次を行います。
並列処理を開始する前に、すべてのワーカーに利用可能なデータがあることを確認します。関数
continueEpochは、この例の最後で定義されているサポート関数であり、関数spmdCatを使用してワーカー間で関数hasdataの結果を連結することによって、すべてのワーカーが利用可能なデータを確実にもつようにします。[停止] ボタンがクリックされたかどうかをチェックします。
continueEpoch関数は、stopTrainingQueueの長さを監視することでこれをチェックします。関数
nextを使用して、minibatchqueueからミニバッチを読み取ります。関数
modelLossでdlfevalを呼び出すことによって、各ワーカーのネットワークの損失と勾配を計算します。関数dlfevalは、自動微分を有効にして補助関数modelLossを評価し、modelLossが損失の勾配を自動的に計算できるようにします。modelLoss(この例の最後で定義) は、ネットワーク、データのミニバッチ、およびターゲットを受け取り、損失と勾配を返します。全体的な損失を取得するために、すべてのワーカーの損失を集計します。この例では、損失関数のクロスエントロピーを使用します。集計された損失はすべての損失の合計です。集計する前に、ミニバッチ全体のうちワーカーが処理している割合で乗算し、各損失を正規化します。
spmdPlusを使用してすべての損失を加算し、ワーカー全体にその結果を複製します。すべてのワーカーの勾配を集計および更新するために、関数
aggregateAllGradientsを使用します。aggregateAllGradientsはこの例の最後で定義されているサポート関数です。この関数は、ミニバッチ全体のうち各ワーカーが処理している割合に基づいて正規化した後、spmdPlusを使用し、勾配を加算してワーカー全体に複製します。関数
aggregateStateを使用して、すべてのワーカーのネットワークの状態を集約します。aggregateStateは、この例の最後で定義されているサポート関数です。ネットワークのバッチ正規化層がデータの平均と分散を追跡します。ミニバッチ全体が複数のワーカーに分散されているため、各反復の後にネットワークの状態を集計し、ミニバッチ全体の平均と分散を計算します。最終勾配を計算した後、関数
sgdmupdateを使用し、ネットワークの学習可能なパラメーターを更新します。Dataqueueオブジェクトと関数sendを使用して学習の進行状況の情報をクライアントに送り返します。すべてのワーカーが同じ損失情報をもっているため、必要なのは 1 つのワーカーを使用してデータを送り返すことだけです。データが確実に CPU 上にあり、GPU を搭載していないクライアント マシンがデータにアクセスできるようにするには、データをクライアントに送信する前に、dlarrayに対してgatherを使用します。
spmd % Partition the datastore. workerImds = partition(augimdsTrain,numWorkers,spmdIndex); % Create minibatchqueue using partitioned datastore on each worker. workerMbq = minibatchqueue(workerImds,3,... MiniBatchSize=workerMiniBatchSize(spmdIndex),... MiniBatchFcn=@preprocessMiniBatch,... MiniBatchFormat=["SSCB" "CB" ""]); % Use dlaccelerate on the modelLoss accModelLoss = dlaccelerate(@modelLoss); workerVelocity = velocity; epoch = 0; iteration = 0; while epoch < numEpochs epoch = epoch + 1; shuffle(workerMbq); % Loop over mini-batches. while continueEpoch(workerMbq,stopTrainingQueue) iteration = iteration + 1; % Read a mini-batch of data. [workerX,workerT,workerNumObservations] = next(workerMbq); % Evaluate the model loss and gradients on the worker. [workerLoss,workerGradients,workerState] = dlfeval(accModelLoss,net,workerX,workerT); % Aggregate the losses on all workers. workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize; loss = spmdPlus(workerNormalizationFactor*extractdata(workerLoss)); % Aggregate the network state on all workers. net.State = aggregateState(workerState,workerNormalizationFactor,... isBatchNormalizationStateMean,isBatchNormalizationStateVariance); % Aggregate the gradients on all workers. workerGradients.Value = aggregateAllGradients(workerGradients.Value,workerNormalizationFactor); % Update the network parameters using the SGDM optimizer. [net,workerVelocity] = sgdmupdate(net,workerGradients,workerVelocity); % Calculate the training accuracy and send training progress information to the client. if spmdIndex == 1 accuracy = testnet(net,workerX,workerT,"accuracy",MiniBatchSize=workerMiniBatchSize(spmdIndex)); data = [epoch loss accuracy iteration]; send(dataQueue,gather(data)); end end end end
モデルのテスト
学習が完了すると、各ワーカーがもつ完全な学習済みネットワークはすべて同じになります。それらのいずれかを取得します。
netFinal = net{1};ネットワークに学習させた後、その精度をテストします。
テスト イメージを分類します。複数の観測値を使用して予測を行うには、関数 minibatchpredict を使用します。予測スコアをラベルに変換するには、関数 scores2label を使用します。関数 minibatchpredict は利用可能な GPU がある場合に自動的にそれを使用します。そうでない場合、関数は CPU を使用します。
labels = imdsTest.Labels;
imdsTestResized = transform(imdsTest,@(X) {imresize(X,inputSize(1:2))});
X = readall(imdsTestResized);
X = cat(4,X{:});
X = single(X) ./ 255;ネットワークの精度を計算します。
accuracy = testnet(netFinal,X,labels,"accuracy")accuracy = 74.3869
ミニ バッチ前処理関数
関数 preprocessMiniBatch は、次の手順を使用して予測子とターゲット クラスのミニバッチを前処理します。
ミニバッチ内の観測数を判定します。
関数
preprocessMiniBatchPredictorsを使用してイメージを前処理します。入力 cell 配列からターゲット クラスのデータを抽出し、2 番目の次元に沿って categorical 配列に連結します。
カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。
function [X,Y,numObs] = preprocessMiniBatch(XCell,YCell) numObs = numel(YCell); % Preprocess predictors. X = preprocessMiniBatchPredictors(XCell); % Extract class data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode classes. Y = onehotencode(Y,1); end
ミニバッチ予測子前処理関数
関数 preprocessMiniBatchPredictors は、入力 cell 配列からイメージ データを抽出することで予測子のミニバッチを前処理し、数値配列に連結します。その後、イメージが正規化されます。
function X = preprocessMiniBatchPredictors(XCell) X = cat(4,XCell{:}); X = single(X) ./ 255; end
モデル損失関数
関数 modelLoss は、ネットワークの学習可能なパラメーターに関する損失の勾配を計算します。この関数は、forward を使用してミニバッチ X に対するネットワークの出力を計算し、クロス エントロピーを使用して、ターゲット T が与えられたときの損失を計算します。dlfeval と共にこの関数を呼び出すと、自動微分が有効になり、dlgradient は学習可能なパラメーターについての損失の勾配を自動的に計算できます。
function [loss,gradients,state] = modelLoss(net,X,T) [Y,state] = forward(net,X); loss = crossentropy(Y,T); gradients = dlgradient(loss,net.Learnables); end
学習の進行状況を表示する関数
関数 displayTrainingProgress は、ワーカーから送信される学習の進行状況の情報を表示し、[停止] ボタンがクリックされたかどうかをチェックします。[停止] ボタンがクリックされると、学習を停止する必要があることを示すフラグがワーカーに送信されます。この例では、ワーカーがデータを送信するたびに DataQueue によってこの関数が呼び出されます。
function displayTrainingProgress(data,numEpochs,numWorkers,monitor,stopTrainingQueue) % Extract training information from array. epoch = data(1); loss = data(2); accuracy = data(3); iteration = data(4); % Update training progress monitor. recordMetrics(monitor,iteration,TrainingLoss=loss,TrainingAccuracy=accuracy); updateInfo(monitor,Epoch=epoch + " of " + numEpochs, Workers= numWorkers); monitor.Progress = 100 * epoch/numEpochs; % Send flag if the Stop button is clicked. if monitor.Stop send(stopTrainingQueue,true); end end
勾配集計関数
関数 aggregateAllGradients は、すべてのワーカーの勾配を加算して集計します。spmdPlus は、ワーカー上ですべての勾配を加算して複製します。加算する前に、ミニバッチ全体のうちワーカーが処理している割合を表す係数を勾配に乗算し、それらを正規化します。dlarray の内容を取得するには、extractdata を使用します。
function gradients = aggregateAllGradients(gradients,normalizationFactor) % Inspect array of gradients and create cell array for storing aggregated % data. numArrays = numel(gradients); aggregationData = cell(numArrays); arraySizes = cell(numArrays,1); numElements = zeros(numArrays,1); % Extract the data from all the arrays for idxArray = 1:numArrays data = gradients{idxArray}; gradients{idxArray} = []; % Extract data from dlarray. data = extractdata(data); % Save the size of the array. arraySizes{idxArray} = size(data); numElements(idxArray) = numel(data); % Flatten the array to prepare for concatenation. aggregationData{idxArray} = data(:); end % Concatenate all arrays. aggregationData = cat(1,aggregationData{:}); % Aggregate the data from the workers. aggregationData = spmdPlus(normalizationFactor.*aggregationData); % Reconstruct the gradient arrays. i = 1; for idxArray = 1:numArrays n = numElements(idxArray); if n > 0 % Reshape the flattened data to the original size. data = reshape(aggregationData(i:(i+n-1)),arraySizes{idxArray}); % Reinsert the aggregated data as a dlarray. gradients{idxArray} = dlarray(data); i = i + n; end end end
状態集計関数
関数 aggregateState は、すべてのワーカーでネットワークの状態を集計します。このネットワークの状態には、データ セットの学習済みバッチ正規化統計量が含まれます。各ワーカーが処理するのはミニバッチの一部のみなので、すべてのデータの統計を表すように、ネットワークの状態を集計します。ミニバッチごとに、統合平均が、各反復のワーカー全体の平均に対する加重平均として計算されます。統合分散は、次の式に従って計算されます。
ここで、 はワーカーの合計数、 はミニバッチの観測値の合計数、 は 番目のワーカーで処理された観測値の数、 と はそのワーカーで計算された平均と分散の統計、 はすべてのワーカー全体の統合平均です。
function state = aggregateState(state,normalizationFactor,... isBatchNormalizationStateMean,isBatchNormalizationStateVariance) stateMeans = state.Value(isBatchNormalizationStateMean); stateVariances = state.Value(isBatchNormalizationStateVariance); for j = 1:numel(stateMeans) meanVal = stateMeans{j}; varVal = stateVariances{j}; % Calculate combined mean. combinedMean = spmdPlus(normalizationFactor*meanVal); % Calculate combined variance terms to sum. varTerm = normalizationFactor.*(varVal + (meanVal - combinedMean).^2); % Update state. stateMeans{j} = combinedMean; stateVariances{j} = spmdPlus(varTerm); end state.Value(isBatchNormalizationStateMean) = stateMeans; state.Value(isBatchNormalizationStateVariance) = stateVariances; end
エポック継続関数
関数 continueEpoch は、各ワーカーのミニバッチ キューに残っているデータがあるかどうかをチェックし、[停止] ボタンが押されたかどうかをチェックします。
function tf = continueEpoch(workerMbq,stopTrainingQueue) % Create a struct that will be concatenated across the workers. info.HasData = hasdata(workerMbq); info.StopRequested = stopTrainingQueue.QueueLength > 0; % Use spmdCat to aggregate the info from all the workers. info = spmdCat(info); % Continue training if all the workers have data, and if we were not asked to stop. stopRequest = any([info.StopRequested]); tf = ~stopRequest && all([info.HasData]); end
参考文献
The TensorFlow Team.Flowers http://download.tensorflow.org/example_images/flower_photos.tgz
参考
dlarray | dlnetwork | sgdmupdate | dlupdate | dlfeval | dlgradient | crossentropy | softmax | forward | predict