このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。
連合学習を使用したネットワークの学習
この例では、連合学習を使用してネットワークに学習させる方法を示します。連合学習は、分散型つまり非集中型の方法でネットワークに学習させるようにする手法です [1]。
連合学習を使用すると、個々のデータ ソースとデータ セットの全体的な分布が一致しない場合でも、データを 1 か所に移動させることなく、さまざまなソースからのデータを使用してモデルに学習させることができます。これは、独立同分布でない (non-IID な) データと呼ばれます。連合学習は、学習データが大規模な場合や、学習データの転送にプライバシー上の懸念がある場合に特に役立ちます。
連合学習の手法では、データを分散させるのではなく、各モデルがデータ ソースと同じ場所にある状態で複数のモデルに学習させます。ローカルで学習させたモデルの学習可能パラメーターを周期的に収集して組み合わせることで、すべてのデータ ソースから学習したグローバル モデルを作成できます。このようにして、学習データを 1 か所で処理するのではなく、グローバル モデルに学習させることができます。
この例では、連合学習を使用し、高度に non-IID なデータセットを使用して分類モデルに並列学習させます。モデルは、0 ~ 9 までの数字から成る手書きイメージ 10000 個で構成される数字データ セットを使用して学習します。この例では、10 個のワーカーを使用して並列実行し、それぞれのワーカーが単一の数字のイメージを処理します。学習の各ラウンド後にネットワークの学習可能パラメーターを平均化することで、各ワーカーのモデルは、他のクラスのデータを処理することなく、すべてのクラスでパフォーマンスを向上させます。
データのプライバシーは連合学習の有用性の 1 つですが、この例では、データのプライバシーとセキュリティの維持についての詳細を扱いません。この例では、基本的な連合学習アルゴリズムを示します。
並列環境の設定
データ セット内のクラスと同じ数のワーカーを使用して並列プールを作成します。この例では、10 個のワーカーをもつプロセス ベースのローカル並列プールを使用します。
cluster = parcluster("Processes");
cluster.NumWorkers = 10;
pool = parpool(cluster);
Starting parallel pool (parpool) using the 'Processes' profile ... Connected to parallel pool with 10 workers.
numWorkers = pool.NumWorkers;
データセットの読み込み
この例で使用されるすべてのデータは、1 か所にまず保存されます。このデータを高度な non-IID にするためには、クラスに従ってワーカー間でデータを分散する必要があります。検証とテストのデータ セットを作成するために、データの一部をワーカーからクライアントに転送します。ワーカー上にある個々のクラスの学習データと、クライアント上にあるすべてのクラスのテストデータと検証データが正しく設定された後は、学習中にさらにデータが転送されることはありません。
イメージ データが格納されているフォルダーを指定します。
digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ... "nndatasets","DigitDataset");
ワーカー間でデータを分散します。ワーカー 1 が数字 0 のイメージをすべて受け取り、ワーカー 2 が数字 1 のイメージを受け取るというように、各ワーカーが 1 つの数字のイメージのみを受け取ります。
各数字のイメージは、その数字の名前が付いた別個のフォルダーに保存されます。各ワーカーで、関数 fullfile
を使用して特定のクラスのフォルダーに対するパスを指定します。次に、その数字のイメージすべてを格納する imageDatastore
を作成します。次に、関数 splitEachLabel
を使用して、検証とテストに使用するためにデータの 30% をランダムに分離します。最後に、学習データを格納する augmentedImageDatastore
を作成します。
inputSize = [28 28 1]; spmd digitDatasetPath = fullfile(digitDatasetPath,num2str(spmdIndex - 1)); imds = imageDatastore(digitDatasetPath, ... IncludeSubfolders=true, ... LabelSource="foldernames"); [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized"); augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain); end
結合されたグローバル モデルのパフォーマンスを学習中および学習後にテストするには、すべてのクラスのイメージを含むテスト データセットと検証データセットを作成します。各ワーカーからのテスト データと検証データを単一のデータストアに結合します。次に、このデータストアを 2 つのデータストアに分割します。それぞれに全体のデータの 15% を含めます。1 つは学習中のネットワークの検証用、もう 1 つは学習後のネットワークのテスト用です。
fileList = []; labelList = []; for i = 1:numWorkers tmp = imdsTestVal{i}; fileList = cat(1,fileList,tmp.Files); labelList = cat(1,labelList,tmp.Labels); end imdsGlobalTestVal = imageDatastore(fileList); imdsGlobalTestVal.Labels = labelList; [imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized"); augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest); augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);
これで、各ワーカーが 1 つのクラスからのデータを学習し、クライアントがすべてのクラスからの検証データおよびテスト データを保持するようにデータが配置されました。
ネットワークの定義
データ セット内のクラスの数を判定します。
classes = categories(imdsGlobalTest.Labels); numClasses = numel(classes);
ネットワーク アーキテクチャを定義します。
layers = [
imageInputLayer(inputSize,Normalization="none")
convolution2dLayer(5,32)
reluLayer
maxPooling2dLayer(2)
convolution2dLayer(5,64)
reluLayer
maxPooling2dLayer(2)
fullyConnectedLayer(numClasses)
softmaxLayer];
層から dlnetwork
オブジェクトを作成します。
net = dlnetwork(layers)
net = dlnetwork with properties: Layers: [9×1 nnet.cnn.layer.Layer] Connections: [8×2 table] Learnables: [6×3 table] State: [0×3 table] InputNames: {'imageinput'} OutputNames: {'softmax'} Initialized: 1 View summary with summary.
モデル損失関数の定義
この例のモデル損失関数セクションにリストされている関数 modelLoss
を作成します。この関数は、dlnetwork
オブジェクト、入力データのミニバッチとそれに対応するラベルを受け取り、ネットワーク内の学習可能なパラメーターについての損失およびその損失の勾配を返します。
連合平均化関数の定義
この例の連合平均化関数セクションにリストされている関数 federatedAveraging
を作成します。この関数は、各ワーカーのネットワークの学習可能なパラメーターと各ワーカーに対する正規化係数を受け取り、すべてのネットワークで平均化された学習可能なパラメーターを返します。学習可能なパラメーターの平均を使用して、グローバル ネットワークと各ワーカーのネットワークを更新します。
精度計算関数の定義
この例の精度計算関数セクションにリストされている関数 computeAccuracy
を作成します。この関数は、dlnetwork
オブジェクト、minibatchqueue
オブジェクト内のデータ セット、およびクラスのリストを受け取り、データ セット内の観測値全体で予測の精度を返します。
学習オプションの指定
学習中、ワーカーはネットワークの学習可能なパラメーターをクライアントに周期的に伝達し、クライアントがグローバル モデルを更新できるようにします。学習はラウンドに分かれています。学習の各ラウンドの最後に、学習可能なパラメーターが平均化され、グローバル モデルが更新されます。次に、ワーカー モデルが新しいグローバル モデルに置き換えられ、ワーカーで学習が続行されます。
1 ラウンドあたり 5 エポックで、学習を 300 ラウンド行います。ラウンドごとに少数のエポックの学習を行うことで、ワーカーのネットワークが平均化される前に発散しすぎることを確実に防ぎます。
numRounds = 300; numEpochsperRound = 5; miniBatchSize = 100;
SGDM 最適化のオプションを指定します。初期学習率 0.001 とモーメンタム 0 を指定します。
learnRate = 0.001; momentum = 0;
モデルの学習
カスタム ミニバッチ前処理関数 preprocessMiniBatch
(この例のミニバッチ前処理関数セクションで定義) への関数ハンドルを作成します。
各ワーカーについて、そのワーカーでローカルに処理された学習観測値の総数を求めます。各通信ラウンド後に学習可能パラメーターの平均を求めるときに、この数値を使用して各ワーカーの学習可能なパラメーターを正規化します。これは、各ワーカーのデータ量に差がある場合に平均のバランスを取るのに役立ちます。
各ワーカーで、学習中にイメージのミニバッチを処理および管理する minibatchqueue
オブジェクトを作成します。各ミニバッチで次を行います。
カスタム ミニバッチ前処理関数
preprocessMiniBatch
を使用してデータを前処理し、ラベルを one-hot 符号化変数に変換します。イメージ データを次元ラベル
'SSCB'
(spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue
オブジェクトは、基となる型がsingle
のdlarray
オブジェクトにデータを変換します。書式をクラス ラベルに追加しないでください。GPU が利用できる場合、GPU で学習を行います。既定では、
minibatchqueue
オブジェクトは、GPU が利用可能な場合、各出力をgpuArray
に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。
preProcess = @(x,y)preprocessMiniBatch(x,y,classes); spmd sizeOfLocalDataset = augimdsTrain.NumObservations; mbq = minibatchqueue(augimdsTrain, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=preProcess, ... MiniBatchFormat=["SSCB",""]); end
学習中に使用する検証データを管理する minibatchqueue
オブジェクトを作成します。各ワーカーで minibatchqueue
と同じ設定を使用します。
mbqGlobalVal = minibatchqueue(augimdsGlobalVal, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=preProcess, ... MiniBatchFormat=["SSCB",""]);
trainingProgressMonitor
オブジェクトを初期化します。モニターを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。
monitor = trainingProgressMonitor( ... Metrics="GlobalAccuracy", ... Info="CommunicationRound", ... XLabel="Communication Round");
SGDM ソルバーの速度パラメーターを初期化します。
velocity = [];
グローバル モデルを初期化します。まず、グローバル モデルには、各ワーカーの未学習のネットワークと同じ初期パラメーターがあります。
globalModel = net;
カスタム学習ループを使用してモデルに学習させます。それぞれの通信ラウンドで次を行います。
ワーカーのネットワークを最新のグローバル ネットワークで更新します。
ワーカーのネットワークに 5 エポック学習させます。
関数
federatedAveraging
を使用して、すべてのネットワークの平均パラメーターを求めます。グローバル ネットワーク パラメーターを平均値に置換します。
検証データを使用して、更新されたグローバル ネットワークの精度を計算します。
学習の進行状況モニターでグローバル精度を更新します。
Stop
プロパティがtrue
の場合は停止します。[停止] ボタンをクリックしたときにTrainingProgressMonitor
オブジェクトのStop
プロパティ値がtrue
に変更します。
各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各ミニバッチで次を行います。
関数
dlfeval
およびmodelLoss
を使用してモデルの損失と勾配を評価します。関数
sgdmupdate
を使用してローカル ネットワーク パラメーターを更新します。
round = 0; while round < numRounds && ~monitor.Stop round = round + 1; spmd % Send global updated parameters to each worker. net.Learnables.Value = globalModel.Learnables.Value; % Loop over epochs. for epoch = 1:numEpochsperRound % Shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) % Read mini-batch of data. [X,T] = next(mbq); % Evaluate the model loss and gradients using dlfeval and the % modelLoss function. [loss,gradients] = dlfeval(@modelLoss,net,X,T); % Update the network parameters using the SGDM optimizer. [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum); end end % Collect updated learnable parameters on each worker. workerLearnables = net.Learnables.Value; end % Find normalization factors for each worker based on ratio of data % processed on that worker. sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]); normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets; % Update the global model with new learnable parameters, normalized and % averaged across all workers. globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor); % Calculate the accuracy of the global model. accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes); % Update the training progress monitor. recordMetrics(monitor,round,GlobalAccuracy=accuracy); updateInfo(monitor,CommunicationRound=round + " of " + numRounds); monitor.Progress = 100*round/numRounds; end
学習の最終ラウンドの後、各ワーカーのネットワークを学習可能なパラメーターの最終平均で更新します。これは、ワーカーのネットワークを引き続き使用する場合、またはそのネットワークに学習させる場合に重要です。
spmd net.Learnables.Value = globalModel.Learnables.Value; end
モデルのテスト
真のラベルをもつテスト セットで予測を比較し、モデルの分類精度をテストします。
テスト データを管理する minibatchqueue
オブジェクトを作成します。学習中および検証中に使用される minibatchqueue
オブジェクトと同じ設定を使用します。
mbqGlobalTest = minibatchqueue(augimdsGlobalTest, ... MiniBatchSize=miniBatchSize, ... MiniBatchFcn=preProcess, ... MiniBatchFormat="SSCB");
関数 computeAccuracy
を使用して、予測されたクラスを計算し、すべてのテスト データの予測精度を計算します。
accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
0.9827
計算が完了したら、並列プールを削除できます。関数 gcp
が現在の並列プール オブジェクトを返すため、それでプールを削除できます。
delete(gcp("nocreate"));
モデル損失関数
関数 modelLoss
は、dlnetwork
オブジェクト net
、入力データ X
のミニバッチとそれに対応するラベル T
を受け取り、net
内の学習可能なパラメーターについての損失とその損失の勾配を返します。勾配を自動的に計算するには、関数 dlgradient
を使用します。学習中にネットワークの予測を計算するために、関数 forward
を使用します。
function [loss,gradients] = modelLoss(net,X,T) YPred = forward(net,X); loss = crossentropy(YPred,T); gradients = dlgradient(loss,net.Learnables); end
精度計算関数
関数 computeAccuracy
は、dlnetwork
オブジェクト net
、minibatchqueue
オブジェクト mbq
、およびクラスのリストを受け取り、用意されているデータ セットですべての予測の精度を返します。検証中または学習終了後にネットワークの予測を計算するために、関数 predict
を使用します。
function accuracy = computeAccuracy(net,mbq,classes) correctPredictions = []; shuffle(mbq); while hasdata(mbq) [XTest,TTest] = next(mbq); TTest = onehotdecode(TTest,classes,1)'; YPred = predict(net,XTest); YPred = onehotdecode(YPred,classes,1)'; correctPredictions = [correctPredictions; YPred == TTest]; end predSum = sum(correctPredictions); accuracy = single(predSum./size(correctPredictions,1)); end
ミニバッチ前処理関数
関数 preprocessMiniBatch
は、次の手順でデータを前処理します。
入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。
入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結させます。
カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。
function [X,Y] = preprocessMiniBatch(XCell,YCell,classes) % Concatenate. X = cat(4,XCell{1:end}); % Extract label data from cell and concatenate. Y = cat(2,YCell{1:end}); % One-hot encode labels. Y = onehotencode(Y,1,ClassNames=classes); end
連合平均化関数
関数 federatedAveraging
は、各ワーカーのネットワークの学習可能なパラメーターと各ワーカーに対する正規化係数を受け取り、すべてのネットワークで平均化された学習可能なパラメーターを返します。学習可能なパラメーターの平均を使用して、グローバル ネットワークと各ワーカーのネットワークを更新します。
function learnables = federatedAveraging(workerLearnables,normalizationFactor) numWorkers = size(normalizationFactor,2); % Initialize container for averaged learnables with same size as existing % learnables. Use learnables of first worker network as an example. exampleLearnables = workerLearnables{1}; learnables = cell(height(exampleLearnables),1); for i = 1:height(learnables) learnables{i} = zeros(size(exampleLearnables{i}),"like",(exampleLearnables{i})); end % Add the normalized learnable parameters of all workers to % calculate average values. for i = 1:numWorkers tmp = workerLearnables{i}; for values = 1:numel(learnables) learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values}; end end end
参考文献
参考
dlarray
| dlnetwork
| sgdmupdate
| dlupdate
| dlfeval
| dlgradient
| minibatchqueue