DQN Agent with 512 discrete actions not learning

Raja Suryadevara, 3 May 2021
I am using a DQN agent to train my network, which takes three continuous observations: the error, the derivative of the error, and the power output. The actions activate switches (1 for 'on', 0 for 'off'). There are 9 switches in total, giving 2^9 = 512 discrete combinations. I get no errors, and my model is a Simulink environment. However, the episode Q0 values grow exponentially large. Please let me know where I might be going wrong. Below is my full code, and attached is the Simulink model I am using.
% Enumerate all 2^N on/off combinations of the N switches (one row of T per action)
N = 9;
L = 2^N;
T = zeros(L,N);
for i=1:N
temp = [zeros(L/2^i,1); ones(L/2^i,1)];
T(:,i) = repmat(temp,2^(i-1),1);
end
% Store each row of T as a 9x1 column vector: one cell per discrete action
b = cell(L,1);
for i = 1:L
b{i} = T(i,:)';
end
mdl = 'InitRLModel';
open_system(mdl)
obsInfo = rlNumericSpec([3 1]);
actInfo = rlFiniteSetSpec(b);
env = rlSimulinkEnv('InitRLModel','InitRLModel/RLAgent',obsInfo,actInfo);
env.UseFastRestart = 'off';
Ts = 0.1;
env.ResetFcn = @(in)localResetFcn(in);
rng(0)
dnn = [
featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
fullyConnectedLayer(24,'Name','CriticStateFC1')
reluLayer('Name','CriticRelu1')
fullyConnectedLayer(24, 'Name','CriticStateFC2')
reluLayer('Name','CriticCommonRelu')
fullyConnectedLayer(length(actInfo.Elements),'Name','output')];
criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);
agentOpts = rlDQNAgentOptions(...
'UseDoubleDQN',false, ...
'TargetSmoothFactor',1, ...
'TargetUpdateFrequency',4, ...
'ExperienceBufferLength',100000, ...
'DiscountFactor',0.99, ...
'MiniBatchSize',256);
agent = rlDQNAgent(critic,agentOpts);
trainOpts = rlTrainingOptions(...
'MaxEpisodes',5000, ...
'MaxStepsPerEpisode',512, ...
'Verbose',false, ...
'Plots','training-progress',...
'StopTrainingCriteria','AverageReward',...
'StopTrainingValue',30000);
doTraining = true;
if doTraining
% Train the agent.
trainingStats = train(agent,env,trainOpts);
else
% Load the pretrained agent for the example.
load('agent7.mat','agent')
end
function in = localResetFcn(in)
% Randomize the two step times in (0, 100) and the NIM2 active power in (0, 1000),
% resampling each value until it falls inside its range.
blk = 'InitRLModel/Microgrid Environment/Step1';
t1 = 50*randn;
while t1 <= 0 || t1 >= 100
t1 = 50*randn;
end
in = setBlockParameter(in,blk,'time',num2str(t1));
blk = 'InitRLModel/Microgrid Environment/Step2';
t2 = 50*randn;
while t2 <= 0 || t2 >= 100
t2 = 50*randn;
end
in = setBlockParameter(in,blk,'time',num2str(t2));
blk = 'InitRLModel/Microgrid Environment/NIM2';
pow = 100*randn + 100;
while pow <= 0 || pow >= 1000
pow = 100*randn + 100; % redraw from the same distribution as the first draw
end
in = setBlockParameter(in,blk,'Activepower',num2str(pow));
end
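As a side note, the 512-element action set built by the two loops above can also be generated directly with `dec2bin`. This is a minimal sketch. It assumes the same ordering as the loops: switch 1 is the most significant bit, so row k of `T` is the binary pattern of k-1.

```matlab
% Enumerate all on/off patterns of N switches with dec2bin
% (switch 1 = most significant bit, matching the loop-based construction).
N = 9;                              % number of switches
T = dec2bin(0:2^N-1, N) - '0';      % 512x9 double matrix of 0/1 values
b = num2cell(T', 1)';               % 512x1 cell array, each element a 9x1 column
```

The resulting `b` can be passed to `rlFiniteSetSpec` exactly as in the question.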

Answers (1)

Emmanouil Tzorakoleftherakis, 5 May 2021
I would first revisit the critic architecture, for two reasons:
1) The network seems a little simple for a 3->512 mapping
2) This is somewhat confirmed by the abnormal Q0 behavior you are seeing.
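One rough way to judge whether Q0 is "abnormal" is the geometric-series bound on the discounted return. This sketch rests on an assumption: the per-step reward is bounded in magnitude by some r_max. The reward signal itself is not shown in the question. No true action value can exceed

```latex
|Q(s,a)| \;\le\; \sum_{k=0}^{\infty} \gamma^{k}\, r_{\max} \;=\; \frac{r_{\max}}{1-\gamma},
\qquad \gamma = 0.99 \;\Rightarrow\; |Q(s,a)| \le 100\, r_{\max}.
```

Here γ is the `DiscountFactor` of 0.99 used in the question. A critic whose Q0 estimates climb far past 100 r_max is diverging, rather than predicting genuinely large returns.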
Of course, there could be many other reasons for not converging:
1) The reward may need tweaking
2) You may need to train for more time
3) You may need to increase exploration (for DQN, specifically the minimum epsilon and the epsilon decay rate). I would actually do that either way
4) You may need to change some of the agent's hyperparameters (e.g. mini-batch size)
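To see why item 3 may matter with 512 actions, you can estimate how long epsilon-greedy exploration lasts. This sketch assumes that the agent applies epsilon <- epsilon*(1 - EpsilonDecay) after each training step until it reaches EpsilonMin. That is how I understand the toolbox documentation. The starting values below are assumed defaults, so read the actual ones from your own `agentOpts`. With these values the formula gives 919 steps, which is under two 512-step episodes out of the 5000 planned.

```matlab
% Estimate how many training steps epsilon-greedy exploration lasts, assuming
% epsilon <- epsilon*(1 - EpsilonDecay) after every step until EpsilonMin.
% The three values below are assumed defaults; check your own agentOpts.
eps0   = 1;      % EpsilonGreedyExploration.Epsilon (initial value)
epsMin = 0.01;   % EpsilonGreedyExploration.EpsilonMin
decay  = 0.005;  % EpsilonGreedyExploration.EpsilonDecay
stepsToMin = ceil(log(epsMin/eps0) / log(1 - decay));
fprintf('epsilon reaches its floor after about %d steps (%.1f episodes of 512 steps)\n', ...
    stepsToMin, stepsToMin/512);

% Illustrative, untuned settings that keep the agent exploring much longer
agentOpts = rlDQNAgentOptions('DiscountFactor',0.99,'MiniBatchSize',256);
agentOpts.EpsilonGreedyExploration.EpsilonMin   = 0.05;  % higher exploration floor
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-4;  % slower decay
```

With the illustrative floor of 0.05 and decay of 1e-4, the same formula gives roughly 30,000 steps of decaying exploration. These numbers are only a starting point for tuning, not recommended values.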
Hope this helps
2 Comments
Emmanouil Tzorakoleftherakis, 6 May 2021
Using a scalingLayer would help on the surface, but it won't change the fact that some of the internal weights of the neural network are blowing up.
We don't have any examples in the toolbox for such large action spaces. I would first try increasing the number of neurons per hidden layer from 24 to 128 or more. The other option would be to add another fully connected + ReLU layer to make the network deeper.
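A minimal sketch of both suggestions applied to the critic from the question. It widens each hidden layer to 128 units and adds a third fully connected + ReLU pair. The width and depth are illustrative starting points, not tuned values. The observation and action specs are rebuilt here so that the block runs on its own.

```matlab
% Specs matching the question: 3 observations, 512 switch patterns as actions
obsInfo = rlNumericSpec([3 1]);
actions = num2cell(dec2bin(0:511, 9)' - '0', 1)';   % 512x1 cell of 9x1 on/off vectors
actInfo = rlFiniteSetSpec(actions);

% Wider (128 units instead of 24) and deeper (three hidden layers instead of two)
nHidden = 128;
dnn = [
    featureInputLayer(obsInfo.Dimension(1),'Normalization','none','Name','state')
    fullyConnectedLayer(nHidden,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(nHidden,'Name','CriticStateFC2')
    reluLayer('Name','CriticRelu2')
    fullyConnectedLayer(nHidden,'Name','CriticStateFC3')   % extra hidden layer
    reluLayer('Name','CriticRelu3')
    fullyConnectedLayer(numel(actInfo.Elements),'Name','output')];  % one Q-value per action

% Same representation options and constructor call as in the question
criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);
```

If this critic still shows runaway Q0 values, the reward scale and exploration settings from the answer above are the next things to check.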
