How to add new classes to a neural network?
3 views (last 30 days)
I made myself a network for flower recognition. It's pretty much a copy of AlexNet, but with some layers deleted. I trained it with 5 classes, but now I want to add more. How can I do that without retraining it from 0?
allImages = imageDatastore('D:\stuff machine learning\flowers', ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[trainingImages, testImages] = splitEachLabel(allImages, 0.8, 'randomize');
conv1 = convolution2dLayer(11,96,'Stride',4,'Padding',0); % 290.5k neurons
conv2 = convolution2dLayer(5,256,'Stride',1,'Padding',2); % 7 million neurons
conv3 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv4 = convolution2dLayer(3,384,'Stride',1,'Padding',1);
conv5 = convolution2dLayer(3,256,'Stride',1,'Padding',1);
layers = [ ...
    imageInputLayer([227 227 3]);
    conv1;
    reluLayer('Name','relu1');
    maxPooling2dLayer(3,'Name','pool1','Stride',2);
    conv2;
    reluLayer('Name','relu2');
    maxPooling2dLayer(3,'Name','pool2','Stride',2);
    conv3;
    reluLayer('Name','relu3');
    conv4;
    reluLayer('Name','relu4');
    conv5;
    reluLayer('Name','relu5');
    maxPooling2dLayer(3,'Name','pool5','Stride',2);
    fullyConnectedLayer(4096,'Name','fc6');
    reluLayer('Name','relu6');
    dropoutLayer('Name','drop6');
    fullyConnectedLayer(4096,'Name','fc7');
    reluLayer('Name','relu7');
    dropoutLayer('Name','drop7');
    fullyConnectedLayer(5,'Name','fc8');
    softmaxLayer('Name','prob');
    classificationLayer('Name','output')];
opts = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 10, ...
    'L2Regularization', 0.008, ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 40, ...
    'ValidationData', testImages, ...
    'Verbose', true, ...
    'Plots', 'training-progress'); % note: the option is 'Plots', not 'Plot'
testImages.ReadFcn = @readFunctionTrain1;
trainingImages.ReadFcn = @readFunctionTrain1;
% train the network
myNet = trainNetwork(trainingImages, layers, opts);
[YPred, probs] = classify(myNet, testImages);
accuracy = mean(YPred == testImages.Labels)
idx = randperm(numel(testImages.Files), 4);
figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(testImages, idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end
This is the network.
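A common way to extend a trained network to new classes without starting over is transfer learning: keep all the trained layers, replace only the final fully connected layer (here 'fc8') with one sized for the new class count, and fine-tune briefly on data from all classes. The sketch below is an assumption-laden illustration, not a tested solution: it assumes myNet is the 5-class network trained above and allImages6 is a hypothetical imageDatastore that now contains six flower folders.

```matlab
% Sketch: reuse the trained layers, swap only the classifier head.
% 'allImages6' is a hypothetical datastore covering all 6 classes.
layersTransfer = myNet.Layers(1:end-3);           % everything up to fc8
newLayers = [ ...
    layersTransfer;
    fullyConnectedLayer(6, 'Name', 'fc8', ...      % one output per class
        'WeightLearnRateFactor', 10, ...           % learn the new head faster
        'BiasLearnRateFactor', 10);
    softmaxLayer('Name', 'prob');
    classificationLayer('Name', 'output')];
[trainNew, testNew] = splitEachLabel(allImages6, 0.8, 'randomize');
optsFT = trainingOptions('sgdm', ...
    'InitialLearnRate', 1e-4, ...                  % small rate: fine-tune, don't relearn
    'MaxEpochs', 5, ...
    'MiniBatchSize', 40, ...
    'ValidationData', testNew);
myNet6 = trainNetwork(trainNew, newLayers, optsFT);
```

This still requires a short training pass over data from all six classes, but the convolutional features already learned are reused, so it converges far faster than training from zero.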
1 comment
Balakrishnan Rajan
16 Oct 2018
I am trying to do the same thing. In theory this means changing the dimensions of the Weights matrix and Bias vector, the OutputSize of the fully connected layer, the OutputSize of the classoutput layer, and adding the new category label to the Classes property. However, these properties are read-only.
Peter Gadfort provided a solution in this thread. However, I can't change the OutputSize, as it is still a read-only property. If you do find a solution, please post it.
The code I am trying is this:
% Adding new classes to a trained net
%% Create an editable struct from the net object
load('BestNet.mat')
TempNet = net.saveobj;
%% Edit the properties of the fully connected layer
FCLayer = TempNet.Layers(142,1);
FCOutputSize = FCLayer.OutputSize;
FCLayer.OutputSize = FCOutputSize + 1;
FCWeights = FCLayer.Weights;
FCWsize = size(FCWeights);
% Append one randomly initialized row for the new class
FCLayer.Weights = rand(FCWsize(1)+1, FCWsize(2));
FCLayer.Weights(1:FCWsize(1),:) = FCWeights;
FCBias = FCLayer.Bias;
FCLayer.Bias = rand(numel(FCBias)+1, 1);   % size(FCBias)+1 would give the wrong shape
FCLayer.Bias(1:numel(FCBias)) = FCBias;
%% Edit the properties of the output layer
OutputLayer = TempNet.Layers(144,1);
OLOutputSize = OutputLayer.OutputSize;
OutputLayer.OutputSize = OLOutputSize + 1;
OLClasses = OutputLayer.Classes;
OLClasses(end+1) = 'Obstructed';
OutputLayer.Classes = OLClasses;
%% Write the edited layers back and rebuild the net
TempNet.Layers(142,1) = FCLayer;
TempNet.Layers(144,1) = OutputLayer;
net = DAGNetwork.loadobj(TempNet);   % GoogLeNet-based nets are DAGNetwork objects
The pretrained net that I am using is a GoogLeNet derivative with the last three layers replaced by a fully connected layer, a softmax layer, and a cross-entropy classification layer. I am adding a new class called "Obstructed". Alphabetically sorted, this is the last class, which is why I append the new elements after the existing ones.
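An alternative to editing the saved struct, if the goal is simply a network whose head has one more output class, is to rebuild the head with layerGraph and replaceLayer (available since R2018b). This is only a sketch under assumptions: the layer names 'fc' and 'classoutput' below are guesses for a GoogLeNet-style network; check net.Layers for the real names in your model.

```matlab
% Sketch: swap the classifier head of a GoogLeNet-style DAG network.
% Layer names 'fc' and 'classoutput' are assumptions; inspect net.Layers.
lgraph = layerGraph(net);
numClasses = 6;   % the old classes plus the new 'Obstructed' class
lgraph = replaceLayer(lgraph, 'fc', ...
    fullyConnectedLayer(numClasses, 'Name', 'fc'));
lgraph = replaceLayer(lgraph, 'classoutput', ...
    classificationLayer('Name', 'classoutput'));
% A brief fine-tuning pass on data from all classes is still needed:
% net2 = trainNetwork(trainingData, lgraph, opts);
```

The replaced layers get fresh weights, so the old-class knowledge in the head is discarded, but the rest of the network keeps its trained weights, which is what makes the fine-tuning pass short.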
Answers (0)