Main Content

イメージ分類用のビジョン トランスフォーマー ネットワークの学習

この例では、事前学習済みのビジョン トランスフォーマー (ViT) ニューラル ネットワークを微調整して、新しいイメージ コレクションの分類を実行する方法を示します。JAB

ViT [1] は、トランスフォーマー アーキテクチャを使用してイメージ入力を特徴ベクトルに符号化するニューラル ネットワーク モデルです。ネットワークは、バックボーンおよびヘッドから成る 2 つの主要コンポーネントで構成されます。バックボーンは、ネットワークの符号化ステップを担当します。バックボーンは入力イメージを受け取り、特徴のベクトルを出力します。ヘッドは予測を担当します。ヘッドは符号化された特徴ベクトルを予測スコアにマッピングします。

この例の事前学習済み ViT ネットワークは、イメージの強い特徴表現を学習しています。転移学習を使用して、特定のタスクに合わせてこのモデルを微調整できます。この特徴表現を転移させ、新しいデータ セットに合わせて微調整するには、ネットワークのヘッドをタスク データ分類用の新しいヘッドに置き換えた後、新しいデータ セットでネットワークを微調整します。

この図は、K 個のクラスの予測を行う ViT ネットワークのアーキテクチャ、およびこのネットワークを編集して K* 個のクラスを含む新しいデータ セットの転移学習を有効にする方法の概要を示したものです。

この例では、パッチ サイズが 16 である基本サイズの ViT モデル (8,680 万パラメーター) を微調整します。これは、解像度 384×384 で ImageNet 2012 データ セットを使用して微調整されます。

事前学習済みの ViT ネットワークの読み込み

関数visionTransformerを使用して、事前学習済みの ViT ネットワークを読み込みます。この関数には、Deep Learning Toolbox™ ライセンスと Computer Vision Toolbox™ Model for Vision Transformer Network サポート パッケージが必要です。このサポート パッケージは、アドオン エクスプローラーからダウンロードできます。サポート パッケージがインストールされていない場合は、関数によってダウンロード リンクが表示されます。

net = visionTransformer
net = 
  dlnetwork with properties:

         Layers: [143×1 nnet.cnn.layer.Layer]
    Connections: [167×2 table]
     Learnables: [200×3 table]
          State: [0×3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 1

  View summary with summary.

ネットワークの入力サイズを表示します。

inputSize = net.Layers(1).InputSize
inputSize = 1×3

   384   384     3

ViT ネットワークを微調整するには、通常、注意層のみを微調整し、他の学習可能なパラメーターは凍結します [2]。この例にサポート ファイルとして添付されている関数 freezeNetwork を使用して、ネットワークの重みを凍結します。この関数にアクセスするには、例をライブ スクリプトとして開きます。

net = freezeNetwork(net,LayersToIgnore="SelfAttentionLayer");

学習データの読み込み

Flowers データ セット [3] をダウンロードし、解凍します。データ セットのサイズは約 218 MB で、5 つのクラス ("デイジー""タンポポ""バラ""ヒマワリ"、および "チューリップ") に属する 3670 個の花のイメージが格納されています。

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~datasetExists(imageFolder)
    disp("Downloading Flowers data set (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

イメージを格納するイメージ データストアを作成します。

imds = imageDatastore(imageFolder,IncludeSubfolders=true,LabelSource="foldernames");

クラス数を表示します。

classNames = categories(imds.Labels);
numClasses = numel(categories(imds.Labels))
numClasses = 5

関数 splitEachLabel を使用して、データストアを学習、検証、テストの各パーティションに分割します。イメージの 80% を学習用に使用し、10% を検証用、10% をテスト用に確保します。

[imdsTrain,imdsValidation,imdsTest] = splitEachLabel(imds,0.8,0.1);

学習を改善するには、ランダムな回転、スケーリング、水平方向の反転を含むように学習データを拡張します。ネットワークの入力サイズと一致するようにイメージのサイズを変更します。

augmenter = imageDataAugmenter( ...
    RandXReflection=true, ...
    RandRotation=[-90 90], ...
    RandScale=[1 2]);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain,DataAugmentation=augmenter);

ネットワーク入力サイズと一致するように検証イメージとテスト イメージをサイズ変更する拡張イメージ データストアを作成します。検証データおよびテスト データにその他の拡張は適用しません。

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
augimdsTest = augmentedImageDatastore(inputSize(1:2),imdsTest);

ネットワーク分類ヘッドの置き換え

ViT ネットワークには 2 つの主要なコンポーネントがあります。ネットワーク本体は入力イメージから特徴を抽出します。分類ヘッドは、抽出された特徴を各クラスの予測スコアを表す確率ベクトルにマッピングします。ニューラル ネットワークに学習させて新しいクラスのセットでイメージを分類するには、分類ヘッドを、抽出された特徴を新しいクラスのセットの予測スコアにマッピングする新しい分類ヘッドに置き換えます。

関数 analyzeNetwork を使用してネットワーク アーキテクチャを表示します。抽出された特徴を予測スコアのベクトルにマッピングする層をネットワークの最後に配置します。この場合、"head" という名前の全結合層は、抽出された特徴を長さ 1000 (予測のためにネットワークに学習させるクラス数) のベクトルにマッピングします。"softmax" という名前のソフトマックス層は、それらのベクトルを確率ベクトルにマッピングします。

analyzeNetwork(net)

学習データ内のクラス数と一致する出力サイズをもつ新しい全結合層を、次のようにして作成します。

  • 出力サイズを学習データのクラス数に設定する。

  • 層名を "head" に設定する。

layer = fullyConnectedLayer(numClasses,Name="head");

関数replaceLayer (Deep Learning Toolbox)を使用して、全結合層を新しい層に置き換えます。ソフトマックス層には学習可能なパラメーターがないため、置き換える必要はありません。

net = replaceLayer(net,"head",layer);

学習オプションの指定

学習オプションを指定します。オプションの中から選択するには、経験的解析が必要です。実験を実行してさまざまな学習オプションの構成を調べるには、Experiment Managerアプリを使用できます。

  • Adam オプティマイザーを使用して学習させます。

  • 微調整するには、学習率を 0.0001 に下げます。

  • 学習を 4 エポック行います。

  • ミニバッチ サイズとして 12 を使用します。ViT ネットワークの学習には通常、大量のメモリが必要です。メモリが不足する場合は、より小さいミニバッチ サイズの使用を試します。または、関数 visionTransformer でモデル名として "tiny-16-imagenet-384" を指定して、極小サイズの ViT モデル (570 万パラメーター) などのより小さなモデルの仕様を試します。

  • エポックごとに 1 回、検証データを使用してネットワークを検証します。

  • 検証損失が最も低くなるネットワークを出力します。

  • 学習の進行状況をプロットで監視し、精度メトリクスを監視します。

  • 詳細出力を無効にします。

miniBatchSize = 12;

numObservationsTrain = numel(augimdsTrain.Files);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);

options = trainingOptions("adam", ...
    MaxEpochs=4, ...
    InitialLearnRate=0.0001, ...
    MiniBatchSize=miniBatchSize, ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=numIterationsPerEpoch, ...
    OutputNetwork="best-validation", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

ニューラル ネットワークの学習

関数trainnet (Deep Learning Toolbox)を使用してニューラル ネットワークに学習させます。分類には、クロスエントロピー損失を使用します。既定では、関数 trainnet は利用可能な GPU がある場合にそれを使用します。GPU での学習には、Parallel Computing Toolbox™ ライセンスとサポートされている GPU デバイスが必要です。サポートされているデバイスについては、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。そうでない場合、関数 trainnet は CPU を使用します。実行環境を指定するには、ExecutionEnvironment 学習オプションを使用します。

この例では、24 GB の RAM を搭載した NVIDIA Titan RTX GPU を使用してネットワークに学習させます。学習の実行には約 37 分かかります。

net = trainnet(augimdsTrain,net,"crossentropy",options);

ニューラル ネットワークのテスト

テスト データを使用してネットワークの精度を評価します。

テスト データを使用して、予測を実行します。予測スコアをクラス ラベルに変換するには、関数onehotdecode (Deep Learning Toolbox)を使用します。

YTest = minibatchpredict(net,augimdsTest);
YTest = onehotdecode(YTest,classNames,2);

テスト分類結果を混同行列に表示します。

figure
TTest = imdsTest.Labels;
confusionchart(TTest,YTest)

テストの精度を評価します。

accuracy = mean(YTest == TTest)
accuracy = 0.9564

新しいデータを使用した予測

学習済みのニューラル ネットワークを使用し、テスト データの最初のイメージを使用して予測を行います。

テスト データの最初のファイルからイメージを読み取ります。

idx = 1;
testData = readByIndex(augimdsTest,idx);
I = testData.input{1};

イメージを使用して予測します。

Y = minibatchpredict(net,single(I));

関数 onehotdecode を使用して、確率が最も高いラベルを取得します。

label = onehotdecode(Y,classNames,2);

イメージと予測ラベルを表示します。

imshow(I)
title(label)

fprintf("Image Credit: %s\n",flowerCredit(augimdsValidation.Files(idx)))
Image Credit: CC-BY by mikeyskatie - https://www.flickr.com/photos/mikeyskatie/5948835387/

参考文献

  1. Dosovitskiy, Alexey, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." Preprint, submitted June 3, 2021. https://doi.org/10.48550/arXiv.2010.11929

  2. Touvron, Hugo, Matthieu Cord, Alaaeldin El-Nouby, Jakob Verbeek, and Hervé Jégou. "Three things everyone should know about vision transformers." In Computer Vision–ECCV 2022, edited by Shai Avidan, Gabriel Brostow, Moustapha Cissé, Giovanni Maria Farinella, and Tal Hassner, 13684: 497-515. Cham: Springer Nature Switzerland, 2022. https://doi.org/10.1007/978-3-031-20053-3_29.

  3. TensorFlow. “Tf_flowers | TensorFlow Datasets.” Accessed June 16, 2023. https://www.tensorflow.org/datasets/catalog/tf_flowers.

参考

| | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox)

関連するトピック