Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

連合学習を使用したネットワークの学習

この例では、連合学習を使用してネットワークに学習させる方法を示します。連合学習は、分散型つまり非集中型の方法でネットワークに学習させるようにする手法です [1]。

連合学習を使用すると、個々のデータ ソースとデータ セットの全体的な分布が一致しない場合でも、データを 1 か所に移動させることなく、さまざまなソースからのデータを使用してモデルに学習させることができます。これは、独立同分布でない (non-IID な) データと呼ばれます。連合学習は、学習データが大規模な場合や、学習データの転送にプライバシー上の懸念がある場合に特に役立ちます。

連合学習の手法では、データを分散させるのではなく、各モデルがデータ ソースと同じ場所にある状態で複数のモデルに学習させます。ローカルで学習させたモデルの学習可能パラメーターを周期的に収集して組み合わせることで、すべてのデータ ソースから学習したグローバル モデルを作成できます。このようにして、学習データを 1 か所で処理するのではなく、グローバル モデルに学習させることができます。

この例では、連合学習を使用し、高度に non-IID なデータセットを使用して分類モデルに並列学習させます。モデルは、0 ~ 9 までの数字から成る手書きイメージ 10000 個で構成される数字データ セットを使用して学習します。この例では、10 個のワーカーを使用して並列実行し、それぞれのワーカーが単一の数字のイメージを処理します。学習の各ラウンド後にネットワークの学習可能パラメーターを平均化することで、各ワーカーのモデルは、他のクラスのデータを処理することなく、すべてのクラスでパフォーマンスを向上させます。

データのプライバシーは連合学習の有用性の 1 つですが、この例では、データのプライバシーとセキュリティの維持についての詳細を扱いません。この例では、基本的な連合学習アルゴリズムを示します。

並列環境の設定

データ セット内のクラスと同じ数のワーカーを使用して並列プールを作成します。この例では、10 個のワーカーをもつプロセス ベースのローカル並列プールを使用します。

cluster = parcluster("Processes");
cluster.NumWorkers = 10;
pool = parpool(cluster);
Starting parallel pool (parpool) using the 'Processes' profile ...
Connected to parallel pool with 10 workers.
numWorkers = pool.NumWorkers;

データセットの読み込み

この例で使用されるすべてのデータは、1 か所にまず保存されます。このデータを高度な non-IID にするためには、クラスに従ってワーカー間でデータを分散する必要があります。検証とテストのデータ セットを作成するために、データの一部をワーカーからクライアントに転送します。ワーカー上にある個々のクラスの学習データと、クライアント上にあるすべてのクラスのテストデータと検証データが正しく設定された後は、学習中にさらにデータが転送されることはありません。

イメージ データが格納されているフォルダーを指定します。

digitDatasetPath = fullfile(matlabroot,"toolbox","nnet","nndemos", ...
    "nndatasets","DigitDataset");

ワーカー間でデータを分散します。ワーカー 1 が数字 0 のイメージをすべて受け取り、ワーカー 2 が数字 1 のイメージを受け取るというように、各ワーカーが 1 つの数字のイメージのみを受け取ります。

各数字のイメージは、その数字の名前が付いた別個のフォルダーに保存されます。各ワーカーで、関数 fullfile を使用して特定のクラスのフォルダーに対するパスを指定します。次に、その数字のイメージすべてを格納する imageDatastore を作成します。次に、関数 splitEachLabel を使用して、検証とテストに使用するためにデータの 30% をランダムに分離します。最後に、学習データを格納する augmentedImageDatastore を作成します。

inputSize = [28 28 1];
spmd   
    digitDatasetPath = fullfile(digitDatasetPath,num2str(spmdIndex - 1));
    imds = imageDatastore(digitDatasetPath, ...
        IncludeSubfolders=true, ...
        LabelSource="foldernames");
    [imdsTrain,imdsTestVal] = splitEachLabel(imds,0.7,"randomized");
    
    augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
end

結合されたグローバル モデルのパフォーマンスを学習中および学習後にテストするには、すべてのクラスのイメージを含むテスト データセットと検証データセットを作成します。各ワーカーからのテスト データと検証データを単一のデータストアに結合します。次に、このデータストアを 2 つのデータストアに分割します。それぞれに全体のデータの 15% を含めます。1 つは学習中のネットワークの検証用、もう 1 つは学習後のネットワークのテスト用です。

fileList = [];
labelList = [];

for i = 1:numWorkers
    tmp = imdsTestVal{i};
    
    fileList = cat(1,fileList,tmp.Files);
    labelList = cat(1,labelList,tmp.Labels);    
end

imdsGlobalTestVal = imageDatastore(fileList);
imdsGlobalTestVal.Labels = labelList;

[imdsGlobalTest,imdsGlobalVal] = splitEachLabel(imdsGlobalTestVal,0.5,"randomized");

augimdsGlobalTest = augmentedImageDatastore(inputSize(1:2),imdsGlobalTest);
augimdsGlobalVal = augmentedImageDatastore(inputSize(1:2),imdsGlobalVal);

これで、各ワーカーが 1 つのクラスからのデータを学習し、クライアントがすべてのクラスからの検証データおよびテスト データを保持するようにデータが配置されました。

ネットワークの定義

データ セット内のクラスの数を判定します。

classes = categories(imdsGlobalTest.Labels);
numClasses = numel(classes);

ネットワーク アーキテクチャを定義します。

layers = [
    imageInputLayer(inputSize,Normalization="none")
    convolution2dLayer(5,32)
    reluLayer
    maxPooling2dLayer(2)
    convolution2dLayer(5,64)
    reluLayer
    maxPooling2dLayer(2)
    fullyConnectedLayer(numClasses)
    softmaxLayer];

層から dlnetwork オブジェクトを作成します。

net = dlnetwork(layers)
net = 
  dlnetwork with properties:

         Layers: [9×1 nnet.cnn.layer.Layer]
    Connections: [8×2 table]
     Learnables: [6×3 table]
          State: [0×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

モデル損失関数の定義

この例のモデル損失関数セクションにリストされている関数 modelLoss を作成します。この関数は、dlnetwork オブジェクト、入力データのミニバッチとそれに対応するラベルを受け取り、ネットワーク内の学習可能なパラメーターについての損失およびその損失の勾配を返します。

連合平均化関数の定義

この例の連合平均化関数セクションにリストされている関数 federatedAveraging を作成します。この関数は、各ワーカーのネットワークの学習可能なパラメーターと各ワーカーに対する正規化係数を受け取り、すべてのネットワークで平均化された学習可能なパラメーターを返します。学習可能なパラメーターの平均を使用して、グローバル ネットワークと各ワーカーのネットワークを更新します。

精度計算関数の定義

この例の精度計算関数セクションにリストされている関数 computeAccuracy を作成します。この関数は、dlnetwork オブジェクト、minibatchqueue オブジェクト内のデータ セット、およびクラスのリストを受け取り、データ セット内の観測値全体で予測の精度を返します。

学習オプションの指定

学習中、ワーカーはネットワークの学習可能なパラメーターをクライアントに周期的に伝達し、クライアントがグローバル モデルを更新できるようにします。学習はラウンドに分かれています。学習の各ラウンドの最後に、学習可能なパラメーターが平均化され、グローバル モデルが更新されます。次に、ワーカー モデルが新しいグローバル モデルに置き換えられ、ワーカーで学習が続行されます。

1 ラウンドあたり 5 エポックで、学習を 300 ラウンド行います。ラウンドごとに少数のエポックの学習を行うことで、ワーカーのネットワークが平均化される前に発散しすぎることを確実に防ぎます。

numRounds = 300;
numEpochsperRound = 5;
miniBatchSize = 100;

SGDM 最適化のオプションを指定します。初期学習率 0.001 とモーメンタム 0 を指定します。

learnRate = 0.001;
momentum = 0;

モデルの学習

カスタム ミニバッチ前処理関数 preprocessMiniBatch (この例のミニバッチ前処理関数セクションで定義) への関数ハンドルを作成します。

各ワーカーについて、そのワーカーでローカルに処理された学習観測値の総数を求めます。各通信ラウンド後に学習可能パラメーターの平均を求めるときに、この数値を使用して各ワーカーの学習可能なパラメーターを正規化します。これは、各ワーカーのデータ量に差がある場合に平均のバランスを取るのに役立ちます。

各ワーカーで、学習中にイメージのミニバッチを処理および管理する minibatchqueue オブジェクトを作成します。各ミニバッチで次を行います。

  • カスタム ミニバッチ前処理関数 preprocessMiniBatch を使用してデータを前処理し、ラベルを one-hot 符号化変数に変換します。

  • イメージ データを次元ラベル 'SSCB' (spatial、spatial、channel、batch) で書式設定します。既定では、minibatchqueue オブジェクトは、基となる型が singledlarray オブジェクトにデータを変換します。書式をクラス ラベルに追加しないでください。

  • GPU が利用できる場合、GPU で学習を行います。既定では、minibatchqueue オブジェクトは、GPU が利用可能な場合、各出力を gpuArray に変換します。GPU を使用するには、Parallel Computing Toolbox™ とサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

preProcess = @(x,y)preprocessMiniBatch(x,y,classes);

spmd
    sizeOfLocalDataset = augimdsTrain.NumObservations;
    
    mbq = minibatchqueue(augimdsTrain, ...
        MiniBatchSize=miniBatchSize, ...
        MiniBatchFcn=preProcess, ...
        MiniBatchFormat=["SSCB",""]);
end

学習中に使用する検証データを管理する minibatchqueue オブジェクトを作成します。各ワーカーで minibatchqueue と同じ設定を使用します。

mbqGlobalVal = minibatchqueue(augimdsGlobalVal, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat=["SSCB",""]);

trainingProgressMonitorオブジェクトを初期化します。モニターを作成するとタイマーが開始されるため、学習ループに近いところでオブジェクトを作成するようにしてください。

monitor = trainingProgressMonitor( ...
    Metrics="GlobalAccuracy", ...
    Info="CommunicationRound", ...
    XLabel="Communication Round");

SGDM ソルバーの速度パラメーターを初期化します。

velocity = [];

グローバル モデルを初期化します。まず、グローバル モデルには、各ワーカーの未学習のネットワークと同じ初期パラメーターがあります。

globalModel = net;

カスタム学習ループを使用してモデルに学習させます。それぞれの通信ラウンドで次を行います。

  • ワーカーのネットワークを最新のグローバル ネットワークで更新します。

  • ワーカーのネットワークに 5 エポック学習させます。

  • 関数 federatedAveraging を使用して、すべてのネットワークの平均パラメーターを求めます。

  • グローバル ネットワーク パラメーターを平均値に置換します。

  • 検証データを使用して、更新されたグローバル ネットワークの精度を計算します。

  • 学習の進行状況モニターでグローバル精度を更新します。

  • Stop プロパティが true の場合は停止します。[停止] ボタンをクリックしたときに TrainingProgressMonitor オブジェクトの Stop プロパティ値が true に変更します。

各エポックについて、データをシャッフルしてデータのミニバッチをループで回します。各ミニバッチで次を行います。

  • 関数 dlfeval および modelLoss を使用してモデルの損失と勾配を評価します。

  • 関数 sgdmupdate を使用してローカル ネットワーク パラメーターを更新します。

round = 0;
while round < numRounds && ~monitor.Stop

    round = round + 1;

    spmd
        % Send global updated parameters to each worker.
        net.Learnables.Value = globalModel.Learnables.Value;

        % Loop over epochs.
        for epoch = 1:numEpochsperRound
            % Shuffle data.
            shuffle(mbq);

            % Loop over mini-batches.
            while hasdata(mbq)

                % Read mini-batch of data.
                [X,T] = next(mbq);

                % Evaluate the model loss and gradients using dlfeval and the
                % modelLoss function.
                [loss,gradients] = dlfeval(@modelLoss,net,X,T);

                % Update the network parameters using the SGDM optimizer.
                [net,velocity] = sgdmupdate(net,gradients,velocity,learnRate,momentum);

            end
        end

        % Collect updated learnable parameters on each worker.
        workerLearnables = net.Learnables.Value;
    end

    % Find normalization factors for each worker based on ratio of data
    % processed on that worker.
    sizeOfAllDatasets = sum([sizeOfLocalDataset{:}]);
    normalizationFactor = [sizeOfLocalDataset{:}]/sizeOfAllDatasets;

    % Update the global model with new learnable parameters, normalized and
    % averaged across all workers.
    globalModel.Learnables.Value = federatedAveraging(workerLearnables,normalizationFactor);

    % Calculate the accuracy of the global model.
    accuracy = computeAccuracy(globalModel,mbqGlobalVal,classes);

    % Update the training progress monitor.
    recordMetrics(monitor,round,GlobalAccuracy=accuracy);
    updateInfo(monitor,CommunicationRound=round + " of " + numRounds);
    monitor.Progress = 100*round/numRounds;

end

学習の最終ラウンドの後、各ワーカーのネットワークを学習可能なパラメーターの最終平均で更新します。これは、ワーカーのネットワークを引き続き使用する場合、またはそのネットワークに学習させる場合に重要です。

spmd
    net.Learnables.Value = globalModel.Learnables.Value;
end

モデルのテスト

真のラベルをもつテスト セットで予測を比較し、モデルの分類精度をテストします。

テスト データを管理する minibatchqueue オブジェクトを作成します。学習中および検証中に使用される minibatchqueue オブジェクトと同じ設定を使用します。

mbqGlobalTest = minibatchqueue(augimdsGlobalTest, ...
    MiniBatchSize=miniBatchSize, ...
    MiniBatchFcn=preProcess, ...
    MiniBatchFormat="SSCB");

関数 computeAccuracy を使用して、予測されたクラスを計算し、すべてのテスト データの予測精度を計算します。

accuracy = computeAccuracy(globalModel,mbqGlobalTest,classes)
accuracy = single
    0.9827

計算が完了したら、並列プールを削除できます。関数 gcp が現在の並列プール オブジェクトを返すため、それでプールを削除できます。

delete(gcp("nocreate"));

モデル損失関数

関数 modelLoss は、dlnetwork オブジェクト net、入力データ X のミニバッチとそれに対応するラベル T を受け取り、net 内の学習可能なパラメーターについての損失とその損失の勾配を返します。勾配を自動的に計算するには、関数 dlgradient を使用します。学習中にネットワークの予測を計算するために、関数 forward を使用します。

function [loss,gradients] = modelLoss(net,X,T)

    YPred = forward(net,X);
    
    loss = crossentropy(YPred,T);
    gradients = dlgradient(loss,net.Learnables);

end

精度計算関数

関数 computeAccuracy は、dlnetwork オブジェクト netminibatchqueue オブジェクト mbq、およびクラスのリストを受け取り、用意されているデータ セットですべての予測の精度を返します。検証中または学習終了後にネットワークの予測を計算するために、関数 predict を使用します。

function accuracy = computeAccuracy(net,mbq,classes)

    correctPredictions = [];
    
    shuffle(mbq);
    while hasdata(mbq)
        
        [XTest,TTest] = next(mbq);
        
        TTest = onehotdecode(TTest,classes,1)';
        
        YPred = predict(net,XTest);
        YPred = onehotdecode(YPred,classes,1)';
        
        correctPredictions = [correctPredictions; YPred == TTest];
    end
    
    predSum = sum(correctPredictions);
    accuracy = single(predSum./size(correctPredictions,1));

end

ミニバッチ前処理関数

関数 preprocessMiniBatch は、次の手順でデータを前処理します。

  1. 入力 cell 配列からイメージ データを抽出して数値配列に連結します。4 番目の次元でイメージ データを連結することにより、3 番目の次元が各イメージに追加されます。この次元は、シングルトン チャネル次元として使用されます。

  2. 入力 cell 配列からラベル データを抽出し、2 番目の次元に沿って categorical 配列に連結させます。

  3. カテゴリカル ラベルを数値配列に one-hot 符号化します。最初の次元への符号化は、ネットワーク出力の形状と一致する符号化された配列を生成します。

function [X,Y] = preprocessMiniBatch(XCell,YCell,classes)

    % Concatenate.
    X = cat(4,XCell{1:end});
    
    % Extract label data from cell and concatenate.
    Y = cat(2,YCell{1:end});
    
    % One-hot encode labels.
    Y = onehotencode(Y,1,ClassNames=classes);

end

連合平均化関数

関数 federatedAveraging は、各ワーカーのネットワークの学習可能なパラメーターと各ワーカーに対する正規化係数を受け取り、すべてのネットワークで平均化された学習可能なパラメーターを返します。学習可能なパラメーターの平均を使用して、グローバル ネットワークと各ワーカーのネットワークを更新します。

function learnables = federatedAveraging(workerLearnables,normalizationFactor)

    numWorkers = size(normalizationFactor,2);
    
    % Initialize container for averaged learnables with same size as existing
    % learnables. Use learnables of first worker network as an example.
    exampleLearnables = workerLearnables{1};
    learnables = cell(height(exampleLearnables),1);
    
    for i = 1:height(learnables)   
        learnables{i} = zeros(size(exampleLearnables{i}),"like",(exampleLearnables{i}));
    end
    
    % Add the normalized learnable parameters of all workers to
    % calculate average values.
    for i = 1:numWorkers
        tmp = workerLearnables{i};
        for values = 1:numel(learnables)
            learnables{values} = learnables{values} + normalizationFactor(i).*tmp{values};
        end
    end
    
end

参考文献

[1] McMahan, H. Brendan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Agüera y Arcas. "Communication-Efficient Learning of Deep Networks from Decentralized Data." Preprint, submitted. February, 2017. https://arxiv.org/abs/1602.05629.

参考

| | | | | |

関連するトピック