R-CNN オブジェクト検出ネットワークの作成
この例では、事前学習済みの ResNet-50 ネットワークを R-CNN オブジェクト検出ネットワークに変更する方法を説明します。この例で作成するネットワークは、trainRCNNObjectDetector
を使用して学習させることができます。
% Load pretrained ResNet-50. net = resnet50(); % Convert network into a layer graph object to manipulate the layers. lgraph = layerGraph(net);
ネットワークを R-CNN ネットワークに変換する手順は、イメージ分類の転移学習のワークフローと同じです。最後の 3 つの分類層を、検出するオブジェクト クラスの数と背景クラスをサポートできる新しい層で置き換えます。
ResNet-50 では、最後の 3 つの層は、fc1000、fc1000_softmax、および ClassificationLayer_fc1000 という名前です。ネットワークを表示して、変更するネットワークのセクションを拡大します。
figure plot(lgraph) ylim([-5 16])
% Remove the last 3 layers. layersToRemove = { 'fc1000' 'fc1000_softmax' 'ClassificationLayer_fc1000' }; lgraph = removeLayers(lgraph, layersToRemove); % Display the results after removing the layers. figure plot(lgraph) ylim([-5 16])
ネットワークに新しい分類層を追加します。層は、ネットワークが検出する必要のあるオブジェクトの数を分類し、追加の背景クラスも分類するように設定します。検出中、ネットワークでは、トリミングされたイメージ領域を処理して、イメージ領域がオブジェクト クラスの 1 つまたは背景に属するものとして分類します。
% Specify the number of classes the network should classify. numClassesPlusBackground = 2 + 1; % Define new classification layers newLayers = [ fullyConnectedLayer(numClassesPlusBackground, 'Name', 'rcnnFC') softmaxLayer('Name', 'rcnnSoftmax') classificationLayer('Name', 'rcnnClassification') ]; % Add new layers lgraph = addLayers(lgraph, newLayers); % Connect the new layers to the network. lgraph = connectLayers(lgraph, 'avg_pool', 'rcnnFC'); % Display the final R-CNN network. This can be trained using trainRCNNObjectDetector. figure plot(lgraph) ylim([-5 16])