Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

転移学習用の重み初期化子を使用した実験

R2020a 以降

この例では、学習用のさまざまな重み初期化子を使用して畳み込み層と全結合層の重みを初期化する実験を構成する方法を説明します。目的のタスクに向けてさまざまな重み初期化子の性能を比較するには、この例を参考にして実験を作成してください。

深層学習ネットワークに学習させる際、層の重みとバイアスの初期化は、ネットワークの学習成果に大きな影響を与える可能性があります。バッチ正規化層をもたないネットワークの場合、初期化子の選択はさらに大きな影響を与えます。重み初期化子の詳細については、層の重み初期化子の比較を参照してください。

実験を開く

まず、例を開きます。実験マネージャーによって、検証と実行が可能な事前構成済みの実験を含むプロジェクトが読み込まれます。実験を開くには、[実験ブラウザー] ペインで、WeightInitializerExperiment をダブルクリックします。

組み込みの学習実験は、説明、ハイパーパラメーターのテーブル、セットアップ関数、および実験の結果を評価するためのメトリクス関数の集合で構成されます。詳細については、組み込みの学習実験の構成を参照してください。

[説明] フィールドには、実験を説明するテキストが表示されます。この例の説明は次のようになります。

Perform transfer learning by initializing the weights of convolution and
fully connected layers in a pretrained network.

[ハイパーパラメーター] セクションでは、実験で使用する手法とハイパーパラメーター値を指定します。実験を実行すると、実験マネージャーは、ハイパーパラメーター テーブルで指定されたハイパーパラメーター値のすべての組み合わせを使用してネットワークに学習させます。この例では、ハイパーパラメーター WeightsInitializer および BiasInitializer を使用し、事前学習済みのネットワークにおける畳み込み層および全結合層の重み初期化子とバイアス初期化子を指定します。これらの初期化子の詳細については、WeightsInitializerおよびBiasInitializerを参照してください。

[セットアップ関数] セクションでは、実験用の学習データ、ネットワーク アーキテクチャ、および学習オプションを構成する関数を指定します。この関数を MATLAB® エディターで開くには、[編集] をクリックします。この関数のコードは、セットアップ関数にも示されています。セットアップ関数への入力は、ハイパーパラメーター テーブルのフィールドをもつ構造体です。この関数は、イメージ分類問題用のネットワークに学習させるために使用する 3 つの出力を返します。この例のセットアップ関数では以下を行います。

  • 事前学習済みの GoogLeNet ネットワークの読み込み。

lgraph = googlenet(Weights="none");

  • 花のデータセット (約 218 MB) のダウンロードと抽出。このデータセットの詳細については、イメージ データセットを参照してください。

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flower Dataset (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

imds = imageDatastore(imageFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
inputSize = lgraph.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);

  • 畳み込み層と全結合層における入力の重みの初期化。これには、ハイパーパラメーター テーブルで指定された初期化子が使用されます。補助関数 findLayersToReplace は、ネットワーク アーキテクチャのどの層を転移学習用に変更できるかを判定します。この関数のコードを見るには、置き換える層の検索を参照してください。

numClasses = numel(categories(imdsTrain.Labels));
weightsInitializer = params.WeightsInitializer;
biasInitializer = params.BiasInitializer;

learnableLayer = findLayersToReplace(lgraph);
newLearnableLayer = fullyConnectedLayer(numClasses,Name="new_fc");
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

for i = 1:numel(lgraph.Layers)
    layer = lgraph.Layers(i);
    
    if class(layer) == "nnet.cnn.layer.Convolution2DLayer" || ...
            class(layer) == "nnet.cnn.layer.FullyConnectedLayer"
        layerName = layer.Name;
        newLayer = layer;
        
        newLayer.WeightsInitializer = weightsInitializer;
        newLayer.BiasInitializer = biasInitializer;
        
        lgraph = replaceLayer(lgraph,layerName,newLayer);
    end
end

  • 実験用にtrainingOptionsオブジェクトを定義します。この例では、ネットワークの学習を 10 エポック行います。ミニバッチ サイズは 128 とし、5 エポックごとにネットワークを検証します。

miniBatchSize = 128;
validationFrequencyEpochs = 5;

numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);

[メトリクス] セクションは、実験結果を評価するオプションの関数を指定します。この例では、カスタムのメトリクス関数は含まれていません。

実験の実行

実験を実行すると、実験マネージャーはセットアップ関数で定義されたネットワークに複数回学習させます。試行ごとに、ハイパーパラメーター値の異なる組み合わせが使用されます。既定では、実験マネージャーは一度に 1 つの試行を実行します。Parallel Computing Toolbox™ を使用している場合は、複数の試行を同時に実行したり、クラスター内のバッチ ジョブとして実験をオフロードしたりできます。

  • 一度に 1 つの実験を実行するには、"実験マネージャー" ツールストリップの [モード][Sequential] を選択し、[実行] をクリックします。

  • 複数の試行を同時に実行するには、[モード][Simultaneous] を選択し、[実行] をクリックします。現在の並列プールがない場合、実験マネージャーは既定のクラスター プロファイルを使用して並列プールを起動します。その後、実験マネージャーは、並列プールにあるワーカーと同じ数の同時試行を実行します。最良の結果を得るには、実験を実行する前に、GPU と同じ数のワーカーで並列プールを起動します。詳細については、実験マネージャーを使用したネットワークの並列学習GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

  • 実験をバッチ ジョブとしてオフロードするには、[モード][Batch Sequential] または [Batch Simultaneous] を選択し、[クラスター][Pool Size] を指定して、[実行] をクリックします。詳細については、Offload Deep Learning Experiments as Batch Jobs to a Clusterを参照してください。

結果テーブルに、各試行の精度と損失が表示されます。

He 重み初期化子を使用する試行では、学習と検証の損失が数回の反復後に未定義になるため、実験マネージャーが学習を中断することに注意してください。これらの試行について学習を継続しても、有益な結果は得られません。結果テーブルの [ステータス] 列に、これらの試行が停止した理由 (Training loss is NaN) が示されます。

実験の完了後、列ごとに結果テーブルを並べ替えること、[フィルター] ペインを使用して試行をフィルター処理すること、注釈を追加して観測値を記録することができます。詳細については、実験結果の並べ替え、フィルター処理、および注釈追加を参照してください。

各試行の性能をテストするには、試行で使用した学習済みネットワークまたは試行の学習情報をエクスポートします。[実験マネージャー] ツールストリップで、[エクスポート][学習済みネットワーク] を選択するか、[エクスポート][学習情報] を選択します。詳細は、netinfoを参照してください。結果テーブルの内容を MATLAB ワークスペースにtable配列として保存するには、[エクスポート][結果テーブル] を選択します。

実験を閉じる

[実験ブラウザー] ペインでプロジェクトの名前を右クリックし、[プロジェクトを閉じる] を選択します。実験マネージャーによって、プロジェクトに含まれるすべての実験と結果が閉じられます。

セットアップ関数

この関数は、実験用の学習データ、ネットワーク アーキテクチャ、および学習オプションを構成します。この関数への入力は、ハイパーパラメーター テーブルのフィールドをもつ構造体です。この関数は、イメージ分類問題用のネットワークに学習させるために使用する 3 つの出力を返します。

function [augimdsTrain,lgraph,options] = WeightInitializerExperiment_setup(params)

事前学習済みのネットワークの読み込み

lgraph = googlenet(Weights="none");

学習データの読み込み

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flower Dataset (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

imds = imageDatastore(imageFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
inputSize = lgraph.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);

ネットワーク アーキテクチャの定義

numClasses = numel(categories(imdsTrain.Labels));
weightsInitializer = params.WeightsInitializer;
biasInitializer = params.BiasInitializer;

learnableLayer = findLayersToReplace(lgraph);
newLearnableLayer = fullyConnectedLayer(numClasses,Name="new_fc");
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

for i = 1:numel(lgraph.Layers)
    layer = lgraph.Layers(i);
    
    if class(layer) == "nnet.cnn.layer.Convolution2DLayer" || ...
            class(layer) == "nnet.cnn.layer.FullyConnectedLayer"
        layerName = layer.Name;
        newLayer = layer;
        
        newLayer.WeightsInitializer = weightsInitializer;
        newLayer.BiasInitializer = biasInitializer;
        
        lgraph = replaceLayer(lgraph,layerName,newLayer);
    end
end

学習オプションの指定

miniBatchSize = 128;
validationFrequencyEpochs = 5;

numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);

end

置き換える層の検索

この関数は、層グラフ lgraph の単一の分類層とその前の学習可能な (全結合または畳み込みの) 層を検索します。

function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

if ~isa(lgraph,"nnet.cnn.LayerGraph")
    error("Argument must be a LayerGraph object.")
end

src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

isClassificationLayer = arrayfun(@(l) ...
    (isa(l,"nnet.cnn.layer.ClassificationOutputLayer")|isa(l,"nnet.layer.ClassificationLayer")), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error("Layer graph must have a single classification layer.")
end
classLayer = lgraph.Layers(isClassificationLayer);

currentLayerIdx = find(isClassificationLayer);
while true
    
    if numel(currentLayerIdx) ~= 1
        error("Layer graph must have a single learnable layer preceding the classification layer.")
    end
    
    currentLayerType = class(lgraph.Layers(currentLayerIdx));
    isLearnableLayer = ismember(currentLayerType, ...
        ["nnet.cnn.layer.FullyConnectedLayer","nnet.cnn.layer.Convolution2DLayer"]);
    
    if isLearnableLayer
        learnableLayer =  lgraph.Layers(currentLayerIdx);
        return
    end
    
    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
end
end


参考

アプリ

関数

関連するトピック