Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

深層学習を使用したイメージ カテゴリの分類

この例では、事前学習済みの畳み込みニューラル ネットワーク (CNN) を特徴抽出器として使用して、イメージ カテゴリ分類器に学習させる方法を説明します。

概要

畳み込みニューラル ネットワーク (CNN) は、深層学習の分野の強力な機械学習手法です。CNN はさまざまなイメージの大規模なコレクションを使用して学習します。CNN は、これらの大規模なコレクションから広範囲のイメージに対する豊富な特徴表現を学習します。これらの特徴表現は、多くの場合、HOG、LBP または SURF などの手作業で作成した特徴より性能が優れています。学習に時間や手間をかけずに CNN の能力を活用する簡単な方法は、事前学習済みの CNN を特徴抽出器として使用することです。

この例では、Flowers Dataset[5] からのイメージを、そのイメージから抽出した CNN の特徴量で学習されたマルチクラスの線形 SVM でカテゴリに分類します。このイメージ カテゴリの分類のアプローチは、イメージから特徴抽出した市販の分類器を学習する標準的な手法に従っています。たとえば、bag of features を使用したイメージ カテゴリの分類の例では、マルチクラス SVM に学習させる bag of features のフレームワーク内で SURF 特徴量を使用しています。ここでは HOG や SURF などのイメージ特徴を使用する代わりに、CNN を使って特徴量を抽出する点が異なります。

メモ: この例には、Deep Learning Toolbox™、Statistics and Machine Learning Toolbox™ および Deep Learning Toolbox™ Model for ResNet-50 Network が必要です。

この例を実行するには、CUDA 対応 NVIDIA™ GPU の使用が強く推奨されます。GPU を使用するには Parallel Computing Toolbox™ が必要です。サポートされる Compute Capability の詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

データの読み込み

カテゴリ分類器は Flowers Dataset [5] からのイメージで学習を行います。

% Location of the compressed data set
url = 'http://download.tensorflow.org/example_images/flower_photos.tgz';

% Store the output in a temporary folder
downloadFolder = tempdir;
filename = fullfile(downloadFolder,'flower_dataset.tgz');

メモ: データのダウンロードにかかる時間はインターネット接続の速度によって異なります。次の一連のコマンドは MATLAB を使用してデータをダウンロードし、MATLAB をブロックします。別の方法として、Web ブラウザーを使用して、データセットをローカル ディスクにまずダウンロードしておくことができます。Web からダウンロードしたファイルを使用するには、上記の変数 'outputFolder' の値を、ダウンロードしたファイルの場所に変更します。

% Uncompressed data set
imageFolder = fullfile(downloadFolder,'flower_photos');

if ~exist(imageFolder,'dir') % download only once
    disp('Downloading Flower Dataset (218 MB)...');
    websave(filename,url);
    untar(filename,downloadFolder)
end

データを管理しやすいよう ImageDatastore を使用してデータセットを読み込みます。ImageDatastore はイメージ ファイルの場所で動作するため、イメージを読み取るまでメモリに読み込まれません。したがって、大規模なイメージ コレクションを効率的に使用できます。

imds = imageDatastore(imageFolder, 'LabelSource', 'foldernames', 'IncludeSubfolders',true);

下記では、データセットに含まれる 1 つのカテゴリからのイメージ例を見ることができます。表示されるイメージは、Mario によるものです。

% Find the first instance of an image for each category
daisy = find(imds.Labels == 'daisy', 1);

figure
imshow(readimage(imds,daisy))

ここで、変数 imds には、イメージとそれぞれのイメージに関連付けられたカテゴリ ラベルが含められます。ラベルはイメージ ファイルのフォルダー名から自動的に割り当てられます。countEachLabel を使用して、カテゴリごとのイメージの数を集計します。

tbl = countEachLabel(imds)
tbl=5×2 table
      Label       Count
    __________    _____

    daisy          633 
    dandelion      898 
    roses          641 
    sunflowers     699 
    tulips         799 

上記の imds ではカテゴリごとに含まれるイメージの数が等しくないため、最初に調整することで、学習セット内のイメージ数のバランスを取ります。

% Determine the smallest amount of images in a category
minSetCount = min(tbl{:,2}); 

% Limit the number of images to reduce the time it takes
% run this example.
maxNumImages = 100;
minSetCount = min(maxNumImages,minSetCount);

% Use splitEachLabel method to trim the set.
imds = splitEachLabel(imds, minSetCount, 'randomize');

% Notice that each set now has exactly the same number of images.
countEachLabel(imds)
ans=5×2 table
      Label       Count
    __________    _____

    daisy          100 
    dandelion      100 
    roses          100 
    sunflowers     100 
    tulips         100 

事前学習済みのネットワークの読み込み

よく使われる事前学習済みネットワークはいくつかあります。これらの大半は ImageNet データセットで学習されています。このデータセットには 1000 個のオブジェクトのカテゴリと 120 万個の学習用イメージが含まれています [1]。"ResNet-50" はそうしたモデルの 1 つであり、Neural Network Toolbox™ の関数 resnet50 を使用して読み込むことができます。resnet50 を使用するには、まず resnet50 (Deep Learning Toolbox) をインストールする必要があります。

% Load pretrained network
net = resnet50();

ImageNet で学習されたその他のよく使用されるネットワークには AlexNet、GoogLeNet、VGG-16 および VGG-19 [3] があり、Deep Learning Toolbox™ の alexnetgooglenetvgg16vgg19 を使用して読み込むことができます。

ネットワークの可視化には、plot を使用します。これは非常に大規模なネットワークであるため、最初のセクションだけが表示されるように表示ウィンドウを調整します。

% Visualize the first section of the network. 
figure
plot(net)
title('First section of ResNet-50')
set(gca,'YLim',[150 170]);

最初の層は入力の次元を定義します。それぞれの CNN は入力サイズの要件が異なります。この例で使用される CNN には 224 x 224 x 3 のイメージ入力が必要です。

% Inspect the first layer
net.Layers(1)
ans = 
  ImageInputLayer with properties:

                      Name: 'input_1'
                 InputSize: [224 224 3]

   Hyperparameters
          DataAugmentation: 'none'
             Normalization: 'zerocenter'
    NormalizationDimension: 'auto'
                      Mean: [224×224×3 single]

中間層は CNN の大半を占めています。ここには、一連の畳み込み層とその間に正規化線形ユニット (ReLU) と最大プーリング層が不規則に配置されています [2]。これらの層に続いて 3 つの全結合層があります。

最後の層は分類層で、その特性は分類タスクに依存します。この例では、読み込まれた CNN モデルは 1000 とおりの分類問題を解決するよう学習されています。したがって、分類層には ImageNet データセットからの 1000 個のクラスがあります。

% Inspect the last layer
net.Layers(end)
ans = 
  ClassificationOutputLayer with properties:

            Name: 'ClassificationLayer_fc1000'
         Classes: [1000×1 categorical]
      OutputSize: 1000

   Hyperparameters
    LossFunction: 'crossentropyex'

% Number of class names for ImageNet classification task
numel(net.Layers(end).ClassNames)
ans = 1000

学習用データの準備

この CNN モデルは、元の分類タスクでは使用できないことに注意してください。これは Flowers Dataset 上の別の分類タスクを解決することを目的としているためです。

セットを学習データと検証データに分割します。各セットからイメージの 30% を学習データに選択し、残る 70% を検証データとします。結果が偏らないようにランダムな方法で分割します。学習セットとテスト セットは CNN モデルによって処理されます。

[trainingSet, testSet] = splitEachLabel(imds, 0.3, 'randomize');

前述のとおり、net は 224 行 224 列の RGB イメージのみ処理できます。すべてのイメージをこの形式で保存し直すのを避けるために、augmentedImageDatastore を使用してグレースケール イメージのサイズを変更して RGB に随時変換します。augmentedImageDatastore は、ネットワークの学習に使用する場合は、追加のデータ拡張にも使用できます。

% Create augmentedImageDatastore from training and test sets to resize
% images in imds to the size required by the network.
imageSize = net.Layers(1).InputSize;
augmentedTrainingSet = augmentedImageDatastore(imageSize, trainingSet, 'ColorPreprocessing', 'gray2rgb');
augmentedTestSet = augmentedImageDatastore(imageSize, testSet, 'ColorPreprocessing', 'gray2rgb');

CNN を使用した学習用特徴の抽出

CNN の各層は入力イメージに対する応答またはアクティベーションを生成します。ただし、CNN 内でイメージの特徴抽出に適している層は数層しかありません。ネットワークの始まりにある層が、エッジやブロブのようなイメージの基本的特徴を捉えます。これを確認するには、最初の畳み込み層からネットワーク フィルターの重みを可視化します。これにより、CNN から抽出された特徴がイメージの認識タスクでよく機能することが直感的に捉えられるようになります。深層の重みの特徴を可視化するには、Deep Learning Toolbox™ の deepDreamImage を使用します。

% Get the network weights for the second convolutional layer
w1 = net.Layers(2).Weights;

% Scale and resize the weights for visualization
w1 = mat2gray(w1);
w1 = imresize(w1,5); 

% Display a montage of network weights. There are 96 individual sets of
% weights in the first layer.
figure
montage(w1)
title('First convolutional layer weights')

ネットワークの最初の層が、ブロブとエッジの特徴を捉えるためにどのようにフィルターを学習するのかに注意してください。これらの「未熟な」特徴はネットワークのより深い層で処理され、初期の特徴と組み合わせてより高度なイメージ特徴を形成します。これらの高度な特徴は、すべての未熟な特徴をより豊富な 1 つのイメージ表現に組み合わせたものであるため、認識タスクにより適しています [4]。

activations メソッドを使用して、深層の 1 つから特徴を簡単に抽出できます。深層のうちどれを選択するかは設計上の選択ですが、通常は分類層の直前の層が適切な開始点となります。net ではこの層に 'fc1000' という名前が付けられています。この層を使用して学習用特徴を抽出します。

featureLayer = 'fc1000';
trainingFeatures = activations(net, augmentedTrainingSet, featureLayer, ...
    'MiniBatchSize', 32, 'OutputAs', 'columns');

アクティベーション関数では、GPU が利用可能な場合には自動的に GPU を使用して処理が行われ、GPU が利用できない場合には CPU が使用されます。

上記のコードでは、CNN およびイメージ データが必ず GPU メモリに収まるよう 'MiniBatchSize' は 32 に設定されます。GPU がメモリ不足となる場合は 'MiniBatchSize' の値を小さくする必要があります。また、アクティベーションの出力は列として並んでいます。これにより、その後のマルチクラス線形 SVM の学習が高速化されます。

CNN 特徴量を使用したマルチクラス SVM 分類器の学習

次に、CNN のイメージ特徴を使用してマルチクラス SVM 分類器に学習させます。関数 fitcecoc の 'Learners' パラメーターを 'Linear' に設定することで、高速の確率的勾配降下法ソルバーを学習に使用します。これにより、高次の CNN 特徴量のベクトルで作業する際に、学習を高速化できます。

% Get training labels from the trainingSet
trainingLabels = trainingSet.Labels;

% Train multiclass SVM classifier using a fast linear solver, and set
% 'ObservationsIn' to 'columns' to match the arrangement used for training
% features.
classifier = fitcecoc(trainingFeatures, trainingLabels, ...
    'Learners', 'Linear', 'Coding', 'onevsall', 'ObservationsIn', 'columns');

分類器の評価

ここまでに使用した手順を繰り返して、testSet からイメージの特徴を抽出します。その後、テスト用の特徴を分類器に渡し、学習済み分類器の精度を測定します。

% Extract test features using the CNN
testFeatures = activations(net, augmentedTestSet, featureLayer, ...
    'MiniBatchSize', 32, 'OutputAs', 'columns');

% Pass CNN image features to trained classifier
predictedLabels = predict(classifier, testFeatures, 'ObservationsIn', 'columns');

% Get the known labels
testLabels = testSet.Labels;

% Tabulate the results using a confusion matrix.
confMat = confusionmat(testLabels, predictedLabels);

% Convert confusion matrix into percentage form
confMat = bsxfun(@rdivide,confMat,sum(confMat,2))
confMat = 5×5

    0.8571    0.0286    0.0286    0.0714    0.0143
    0.0571    0.8286         0    0.0571    0.0571
    0.0143         0    0.7714    0.0714    0.1429
    0.0286    0.0571    0.0571    0.8000    0.0571
         0         0    0.2000    0.0286    0.7714

% Display the mean accuracy
mean(diag(confMat))
ans = 0.8057

イメージ内の花の分類

学習を行った分類器を適用して新しいイメージを分類します。「デイジー」テスト イメージの 1 つを読み込みます。

testImage = readimage(testSet,1);
testLabel = testSet.Labels(1)
testLabel = categorical
     daisy 

CNN を使用してイメージの特徴を抽出します。

% Create augmentedImageDatastore to automatically resize the image when
% image features are extracted using activations.
ds = augmentedImageDatastore(imageSize, testImage, 'ColorPreprocessing', 'gray2rgb');

% Extract image features using the CNN
imageFeatures = activations(net, ds, featureLayer, 'OutputAs', 'columns');

分類器を使用して予測を行います。

% Make a prediction using the classifier
predictedLabel = predict(classifier, imageFeatures, 'ObservationsIn', 'columns')
predictedLabel = categorical
     daisy 

参考文献

[1] Deng, Jia, et al. "Imagenet: A large-scale hierarchical image database." Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on. IEEE, 2009.

[2] Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "Imagenet classification with deep convolutional neural networks." Advances in neural information processing systems. 2012.

[3] Simonyan, Karen, and Andrew Zisserman."Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).

[4] Donahue, Jeff, et al. "Decaf: A deep convolutional activation feature for generic visual recognition." arXiv preprint arXiv:1310.1531 (2013).

[5] Tensorflow: How to Retrain an Image Classifier for New Categories.

参考

(Deep Learning Toolbox) | | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Deep Learning Toolbox) | (Statistics and Machine Learning Toolbox) | (Statistics and Machine Learning Toolbox)

関連するトピック