MATLAB Answers

How can I transfer the model parameters of a well-trained NN to another one?

Wanli Wen on 24 Nov 2019
Answered: Divya Gaddipati on 5 Dec 2019
I have two NNs, net_1 and net_2, where net_1 is untrained and net_2 has been well trained. I want to transfer the knowledge of net_2 to net_1 so that net_1 performs as well as net_2, so I wrote the code below. However, after setting the weights and biases of net_1 to those of net_2, net_1 behaves very badly: for example, net_2(-2) = 3.999 but net_1(-2) = 32.249, whereas I expect net_1 to output a value very close to that of net_2. Could anyone tell me whether there is anything wrong with my code? Thanks.
(Please note that I do not want to use the operation net_1 = net_2 to achieve this purpose.)
clear all
% Task: fit the non-linear function f(x) = x.^2
layers_neurons = 10; % hypothetical hidden-layer size; its value is not shown in the post
D = 1e4; % no. of training samples
%% Net 1: untrained network
net_1 = feedforwardnet(layers_neurons);
[data1, target1] = gen_data_sample(10);
net_1 = configure(net_1, data1, target1);
%% Net 2: well-trained network
[data2, target2] = gen_data_sample(D);
net_2 = feedforwardnet(layers_neurons); % doc feedforwardnet for more details
net_2 = configure(net_2, data2, target2);
net_2 = train(net_2, data2, target2); % , 'useGPU', 'yes', 'useparallel', 'yes'
%% Transfer the knowledge of Net 2 to Net 1
net_1.IW = net_2.IW; % input weights
net_1.LW = net_2.LW; % layer weights
net_1.b = net_2.b;   % biases
%% Test and Compare Net 1 and Net 2
fprintf('net_2(-2) = %.3f, net_1(-2) = %.3f\n', net_2(-2), net_1(-2));

function [input, output] = gen_data_sample(D)
% Draw D inputs uniformly from [-20, 20] and return their squares as targets
input = -20 + (20 - (-20))*rand(1, D);
output = input.^2;
end



Divya Gaddipati on 5 Dec 2019
Before you assign the weights of net_2 to net_1, initialize net_1 from net_2 using the init function:
net_1 = init(net_2);
This would resolve your issue.
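A minimal end-to-end sketch of this approach follows, assuming the Deep Learning Toolbox is installed. The hidden-layer size of 10 is a hypothetical choice (the post does not show the value of layers_neurons), and rng(0) is only there for reproducibility. After the copy, net_1(-2) and net_2(-2) should agree.

```matlab
% Sketch assuming the Deep Learning Toolbox; layers_neurons = 10 is a
% hypothetical hidden-layer size (the post does not show its value).
rng(0);                              % reproducible data and initialization
layers_neurons = 10;
x = -20 + 40*rand(1, 1e4);           % inputs drawn uniformly from [-20, 20]
t = x.^2;                            % targets for f(x) = x.^2

net_2 = feedforwardnet(layers_neurons);
net_2 = configure(net_2, x, t);
net_2.trainParam.showWindow = false; % suppress the training window
net_2 = train(net_2, x, t);

% init returns a copy of net_2 (architecture and data-dependent
% configuration) whose weights and biases have been re-initialized.
net_1 = init(net_2);

% Overwrite the re-initialized values with the trained ones.
net_1.IW = net_2.IW;
net_1.LW = net_2.LW;
net_1.b  = net_2.b;

fprintf('net_2(-2) = %.4f, net_1(-2) = %.4f\n', net_2(-2), net_1(-2));
```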
Additionally, you can remove the configure call on net_1 in the "Net 1" section of your code. It is unnecessary if you use init, because net_1 = init(net_2) replaces net_1 entirely.
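Why the weights alone are not enough: by default, feedforwardnet scales its inputs and targets with mapminmax, and configure computes the scaling ranges from whatever data it is given. In the question, net_1 was configured on only 10 random samples, so even with identical weights it maps inputs and outputs using different ranges from net_2. The base-MATLAB sketch below (no toolbox needed) applies the reverse mapminmax formula to the same normalized output using two target ranges. The ranges are made-up illustrative values, not figures from the post.

```matlab
% Reverse of mapminmax with its default normalized range [-1, 1]:
% x = (y - ymin) * (xmax - xmin) / (ymax - ymin) + xmin, with ymin = -1, ymax = 1.
unscale = @(y, xmin, xmax) (y + 1) .* (xmax - xmin) ./ 2 + xmin;

y_norm = -0.5;                      % same raw output layer value (identical weights)
out_a = unscale(y_norm, 0, 400);    % target range from a large training set
out_b = unscale(y_norm, 4, 324);    % hypothetical narrower range from 10 samples
fprintf('%.1f vs %.1f\n', out_a, out_b);   % prints "100.0 vs 84.0"
```

The input side is scaled the same way, so in the real networks both mappings differ and the decoded outputs can drift even further apart.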
Hope this helps!
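If you prefer to build net_1 with its own feedforwardnet call instead of deriving it from net_2 with init, one alternative (not part of the answer above) is to configure net_1 on the same data that net_2 was trained on, so that configure derives identical processing settings, and then copy all weights and biases in one step with getwb/setwb. This is a sketch that assumes the Deep Learning Toolbox and uses a hypothetical hidden-layer size of 10.

```matlab
% Sketch assuming the Deep Learning Toolbox; layers_neurons = 10 is hypothetical.
rng(0);                              % reproducible data and initialization
layers_neurons = 10;
x = -20 + 40*rand(1, 1e4);           % inputs drawn uniformly from [-20, 20]
t = x.^2;                            % targets for f(x) = x.^2

net_2 = feedforwardnet(layers_neurons);
net_2 = configure(net_2, x, t);
net_2.trainParam.showWindow = false; % suppress the training window
net_2 = train(net_2, x, t);

% Configure net_1 on the SAME data, so its input/output scaling
% settings match those computed for net_2.
net_1 = feedforwardnet(layers_neurons);
net_1 = configure(net_1, x, t);

% getwb returns all weights and biases as one vector; setwb writes them back.
net_1 = setwb(net_1, getwb(net_2));

fprintf('net_2(-2) = %.4f, net_1(-2) = %.4f\n', net_2(-2), net_1(-2));
```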





