Main Content

R-CNN オブジェクト検出ネットワークの作成

この例では、事前学習済みの ResNet-50 ネットワークを R-CNN オブジェクト検出ネットワークに変更する方法を説明します。この例で作成するネットワークは、trainRCNNObjectDetector を使用して学習させることができます。

% Load pretrained ResNet-50.
net = resnet50();

% Convert network into a layer graph object to manipulate the layers.
lgraph = layerGraph(net);

ネットワークを R-CNN ネットワークに変換する手順は、イメージ分類の転移学習のワークフローと同じです。最後の 3 つの分類層を、検出するオブジェクト クラスの数と背景クラスをサポートできる新しい層で置き換えます。

ResNet-50 では、最後の 3 つの層は、fc1000、fc1000_softmax、および ClassificationLayer_fc1000 という名前です。ネットワークを表示して、変更するネットワークのセクションを拡大します。

figure
plot(lgraph)
ylim([-5 16])

% Remove the last 3 layers. 
layersToRemove = {
    'fc1000'
    'fc1000_softmax'
    'ClassificationLayer_fc1000'
    };

lgraph = removeLayers(lgraph, layersToRemove);

% Display the results after removing the layers.
figure
plot(lgraph)
ylim([-5 16])

ネットワークに新しい分類層を追加します。層は、ネットワークが検出する必要のあるオブジェクトの数を分類し、追加の背景クラスも分類するように設定します。検出中、ネットワークでは、トリミングされたイメージ領域を処理して、イメージ領域がオブジェクト クラスの 1 つまたは背景に属するものとして分類します。

% Specify the number of classes the network should classify.
numClassesPlusBackground = 2 + 1;

% Define new classification layers
newLayers = [
    fullyConnectedLayer(numClassesPlusBackground, 'Name', 'rcnnFC')
    softmaxLayer('Name', 'rcnnSoftmax')
    classificationLayer('Name', 'rcnnClassification')
    ];

% Add new layers
lgraph = addLayers(lgraph, newLayers);

% Connect the new layers to the network. 
lgraph = connectLayers(lgraph, 'avg_pool', 'rcnnFC');

% Display the final R-CNN network. This can be trained using trainRCNNObjectDetector.
figure
plot(lgraph)
ylim([-5 16])