Main Content

このページの翻訳は最新ではありません。ここをクリックして、英語の最新版を参照してください。

転移学習用の事前学習済みのネットワークを複数試用

この例では、転移学習のためにさまざまな事前学習済みのネットワークの層を置き換える実験を構成する方法を説明します。転移学習は、深層学習アプリケーションでよく使用されています。事前学習済みのネットワークを取得して、新しいタスクの学習の開始点として使用できます。通常は、転移学習によってネットワークを微調整する方が、ランダムに初期化された重みでゼロからネットワークに学習させるよりもはるかに簡単で時間がかかりません。少ない数の学習イメージを使用して、新しいタスクに学習済みの特徴を高速に転移できます。

Deep Learning Toolbox™ には、利用できる事前学習済みのネットワークが多数あります。これらの事前学習済みのネットワークには、問題に適用するネットワークを選択する際に重要になるさまざまな特性があります。最も重要な特性は、ネットワークの精度、速度、およびサイズです。ネットワークの選択には、通常、これらの特性の間のトレードオフが生じます。目的のタスクに向けてさまざまな事前学習済みのネットワークの性能を比較するには、この実験を編集し、どの事前学習済みのネットワークを使用するかを指定します。

この実験には、Deep Learning Toolbox Model for GoogLeNet Network サポート パッケージと、Deep Learning Toolbox Model for ResNet-18 Network サポート パッケージが必要です。実験を実行する前に、関数googlenetresnet18を呼び出し、ダウンロード リンクをクリックして、これらのサポート パッケージをインストールしてください。アドオン エクスプローラーからダウンロードできるその他の事前学習済みネットワークの詳細については、事前学習済みの深層ニューラル ネットワークを参照してください。

実験を開く

まず、例を開きます。実験マネージャーによって、検証と実行が可能な事前構成済みの実験を含むプロジェクトが読み込まれます。実験を開くには、[実験ブラウザー] ペインで、実験の名前 (TransferLearningExperiment) をダブルクリックします。

組み込みの学習実験は、説明、ハイパーパラメーターのテーブル、セットアップ関数、および実験の結果を評価するためのメトリクス関数の集合で構成されます。詳細については、組み込み学習実験の構成を参照してください。

[説明] フィールドには、実験を説明するテキストが表示されます。この例の説明は次のようになります。

Perform transfer learning by replacing layers in a pretrained network.

[ハイパーパラメーター] セクションでは、実験で使用する手法 (Exhaustive Sweep) とハイパーパラメーター値を指定します。実験を実行すると、実験マネージャーは、ハイパーパラメーター テーブルで指定されたハイパーパラメーター値のすべての組み合わせを使用してネットワークに学習させます。この例では、ハイパーパラメーター NetworkName によって、学習を行うネットワークと学習オプション miniBatchSize の値が指定されます。

[セットアップ関数] は、実験用の学習データ、ネットワーク アーキテクチャ、および学習オプションを構成します。セットアップ関数への入力は、ハイパーパラメーター テーブルのフィールドをもつ構造体です。このセットアップ関数は、イメージ分類問題用のネットワークに学習させるために使用する 3 つの出力を返します。この例のセットアップ関数では以下を行います。

  • ハイパーパラメーター NetworkName に対応する事前学習済みのネットワークの読み込み。

networkName = params.NetworkName;
switch networkName
    case "squeezenet"
        net = squeezenet;
        miniBatchSize = 128;
    case "googlenet"
        net = googlenet;
        miniBatchSize = 128;
    case "resnet18"
        net = resnet18;
        miniBatchSize = 128;
    case "mobilenetv2"
        net = mobilenetv2;
        miniBatchSize = 128;
    case "resnet50"
        net = resnet50;
        miniBatchSize = 128;
    case "resnet101"
        net = resnet101;
        miniBatchSize = 64;
    case "inceptionv3"
        net = inceptionv3;
        miniBatchSize = 64;
    case "inceptionresnetv2"
        net = inceptionresnetv2;
        miniBatchSize = 64;
    otherwise
        error("Undefined network selection.");
end
  • 花のデータセット (約 218 MB) のダウンロードと抽出。このデータセットの詳細については、イメージ データセットを参照してください。

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");
imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flower Dataset (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end
imds = imageDatastore(imageFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
inputSize = net.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);
  • 事前学習済みネットワークの学習可能な層を置き換えて、転移学習を実行します。補助関数 findLayersToReplace (この例の最後の付録 2 に掲載) は、転移学習用に置き換えるネットワーク アーキテクチャの層を判定します。使用可能な事前学習済みのネットワークの詳細については、事前学習済みの深層ニューラル ネットワークを参照してください。

lgraph = layerGraph(net);
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));
if isa(learnableLayer,"nnet.cnn.layer.FullyConnectedLayer")
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
elseif isa(learnableLayer,"nnet.cnn.layer.Convolution2DLayer")
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        Name="new_conv", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
end
lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);
newClassLayer = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
  • 実験用にtrainingOptionsオブジェクトを定義します。この例では、ネットワークの学習を 10 エポック行います。初期学習率は 0.0003 とし、5 エポックごとにネットワークを検証します。

validationFrequencyEpochs = 5;
numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;
options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=3e-4, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);

セットアップ関数を検査するには、[セットアップ関数][編集] をクリックします。MATLAB® エディターでセットアップ関数が開きます。また、セットアップ関数のコードは、この例の最後の付録 1 にあります。

[メトリクス] セクションは、実験結果を評価するオプションの関数を指定します。この例では、カスタムのメトリクス関数は含まれていません。

実験の実行

実験を実行すると、実験マネージャーはセットアップ関数で定義されたネットワークに 6 回学習させます。試行ごとに、ハイパーパラメーター値の異なる組み合わせが使用されます。既定では、実験マネージャーは一度に 1 つの試行を実行します。Parallel Computing Toolbox™ を使用している場合は、複数の試行を同時に実行したり、クラスター内のバッチ ジョブとして実験をオフロードしたりできます。

  • 一度に 1 つの実験を実行するには、実験マネージャー ツールストリップの [モード][Sequential] を選択し、[実行] をクリックします。

  • 複数の試行を同時に実行するには、[モード][Simultaneous] を選択し、[実行] をクリックします。現在の並列プールがない場合、実験マネージャーは既定のクラスター プロファイルを使用して並列プールを起動します。次に、実験マネージャーは、使用可能な並列ワーカーの数に応じて、複数の同時試行を実行します。最良の結果を得るには、実験を実行する前に、GPU と同じ数のワーカーで並列プールを起動します。詳細については、実験マネージャーを使用したネットワークの並列学習リリース別の GPU サポート (Parallel Computing Toolbox)を参照してください。

  • 実験をバッチ ジョブとしてオフロードするには、[モード][Batch Sequential] または [Batch Simultaneous] を選択し、[クラスター][Pool Size] を指定して、[実行] をクリックします。詳細については、Offload Experiments as Batch Jobs to Clusterを参照してください。

結果テーブルに、各試行の精度と損失が表示されます。実験の実行中に [学習プロット] をクリックすると、学習プロットが表示され、各試行の進行状況を追跡できます。[混同行列] をクリックし、完了した各試行の検証データの混同行列を表示します。

実験の完了後、列ごとに結果テーブルを並べ替えること、[フィルター] ペインを使用して試行をフィルター処理すること、注釈を追加して観測値を記録することができます。詳細については、実験結果の並べ替え、フィルター処理、および注釈追加を参照してください。

各試行の性能をテストするには、試行で使用した学習済みネットワークまたは試行の学習情報をエクスポートします。[実験マネージャー] ツールストリップで、[エクスポート][学習済みネットワーク] を選択するか、[エクスポート][学習情報] を選択します。詳細は、netinfoを参照してください。結果テーブルの内容を MATLAB ワークスペースにtable配列として保存するには、[エクスポート][結果テーブル] を選択します。

実験を閉じる

[実験ブラウザー] ペインでプロジェクトの名前を右クリックし、[プロジェクトを閉じる] を選択します。実験マネージャーによって、プロジェクトに含まれるすべての実験と結果が閉じられます。

付録 1: セットアップ関数

この関数は、実験用の学習データ、ネットワーク アーキテクチャ、および学習オプションを構成します。

入力

  • params は、実験マネージャーのハイパーパラメーター テーブルのフィールドをもつ構造体です。

出力

  • augimdsTrain は、学習データ用の拡張イメージ データストアです。

  • lgraph は、ニューラル ネットワーク アーキテクチャを定義する層グラフです。

  • optionstrainingOptions オブジェクトです。

function [augimdsTrain,lgraph,options] = TransferLearningExperiment_setup1(params)

networkName = params.NetworkName;

switch networkName
    case "squeezenet"
        net = squeezenet;
        miniBatchSize = 128;
    case "googlenet"
        net = googlenet;
        miniBatchSize = 128;
    case "resnet18"
        net = resnet18;
        miniBatchSize = 128;
    case "mobilenetv2"
        net = mobilenetv2;
        miniBatchSize = 128;
    case "resnet50"
        net = resnet50;
        miniBatchSize = 128;
    case "resnet101"
        net = resnet101;
        miniBatchSize = 64;
    case "inceptionv3"
        net = inceptionv3;
        miniBatchSize = 64;
    case "inceptionresnetv2"
        net = inceptionresnetv2;
        miniBatchSize = 64;
    otherwise
        error("Undefined network selection.");
end

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~exist(imageFolder,"dir")
    disp("Downloading Flower Dataset (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

imds = imageDatastore(imageFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");

[imdsTrain,imdsValidation] = splitEachLabel(imds,0.9);
inputSize = net.Layers(1).InputSize;
augimdsTrain = augmentedImageDatastore(inputSize,imdsTrain);
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);

lgraph = layerGraph(net);
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,"nnet.cnn.layer.FullyConnectedLayer")
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        Name="new_fc", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
elseif isa(learnableLayer,"nnet.cnn.layer.Convolution2DLayer")
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        Name="new_conv", ...
        WeightLearnRateFactor=10, ...
        BiasLearnRateFactor=10);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer(Name="new_classoutput");
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);

validationFrequencyEpochs = 5;

numObservations = augimdsTrain.NumObservations;
numIterationsPerEpoch = floor(numObservations/miniBatchSize);
validationFrequency = validationFrequencyEpochs * numIterationsPerEpoch;

options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=3e-4, ...
    Shuffle="every-epoch", ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=validationFrequency, ...
    Verbose=false);

end

付録 2: 置き換える層の検索

この関数は、層グラフ lgraph の単一の分類層とその前の学習可能な (全結合または畳み込みの) 層を検索します。

function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

if ~isa(lgraph,"nnet.cnn.LayerGraph")
    error("Argument must be a LayerGraph object.")
end

src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

isClassificationLayer = arrayfun(@(l) ...
    (isa(l,"nnet.cnn.layer.ClassificationOutputLayer")|isa(l,"nnet.layer.ClassificationLayer")), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error("Layer graph must have a single classification layer.")
end
classLayer = lgraph.Layers(isClassificationLayer);

currentLayerIdx = find(isClassificationLayer);
while true
    
    if numel(currentLayerIdx) ~= 1
        error("Layer graph must have a single learnable layer preceding the classification layer.")
    end
    
    currentLayerType = class(lgraph.Layers(currentLayerIdx));
    isLearnableLayer = ismember(currentLayerType, ...
        ["nnet.cnn.layer.FullyConnectedLayer","nnet.cnn.layer.Convolution2DLayer"]);
    
    if isLearnableLayer
        learnableLayer =  lgraph.Layers(currentLayerIdx);
        return
    end
    
    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
end
end


参考

アプリ

関数

関連するトピック