How to save pretrained DQN agent and extract the weights inside the network?

Kuan Yi Li on 28 Aug 2024
Edited: praguna manvi on 29 Aug 2024
The following is part of the program. I want to know how to extract the weight values from the trained DQN network.
DQNnet = [
imageInputLayer([1 520 1],"Name","ImageFeatureInput","Normalization","none")
fullyConnectedLayer(1024,"Name","fc1")
reluLayer("Name","relu1")
% fullyConnectedLayer(512,"Name","fc2")
% reluLayer("Name","relu2")
fullyConnectedLayer(14,"Name","fc3")
softmaxLayer("Name","softmax")
classificationLayer("Name","ActionOutput")];
ObsInfo = getObservationInfo(env);
ActInfo = getActionInfo(env);
DQNOpts = rlRepresentationOptions('LearnRate',0.0001,'GradientThreshold',1,'UseDevice','gpu');
DQNagent = rlQValueRepresentation(DQNnet,ObsInfo,ActInfo,'Observation',{'ImageFeatureInput'},'ActionInputNames',{'BoundingBox Actions'},DQNOpts);
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',true, ...
    'MiniBatchSize',256);
agentOpts.EpsilonGreedyExploration.Epsilon = 1;
agent = rlDQNAgent(DQNagent,agentOpts);
%% Agent Training
% Training options
trainOpts = rlTrainingOptions(...
'MaxEpisodes', 100, ...
'MaxStepsPerEpisode', 100, ...
'Verbose', true, ...
'Plots','training-progress',...
'ScoreAveragingWindowLength',400,...
'StopTrainingCriteria','AverageSteps',...
'StopTrainingValue',1000000000,...
'SaveAgentDirectory', pwd + "\agents\");
% Agent training
trainingStats = train(agent,env,trainOpts);

Accepted Answer

praguna manvi on 28 Aug 2024
Edited: praguna manvi on 29 Aug 2024
To save and load a pretrained DQN agent, you can use the save and load functions, for example:
doTraining = false;
if doTraining
    % Train the agent and save it to a MAT-file.
    trainingStats = train(agent,env,trainOpts);
    save myagent.mat agent
else
    % Load the pretrained agent for the example.
    load("MATLABCartpoleDQNMulti.mat","agent")
end
Refer to the example illustrated here:
To extract the weights from the saved agent, you can use the getLearnableParameters function (see its documentation page).
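As a minimal sketch of the whole round trip (save, load, extract), the snippet below uses a small default DQN agent as a stand-in for the trained one, so it runs without an environment. It assumes a reasonably recent Reinforcement Learning Toolbox release in which rlDQNAgent accepts just the observation and action specifications; the file names and the [520 1] / 14-action specifications are illustrative choices that mirror the question, not values taken from a real training run.

```matlab
% Sketch only: a default DQN agent stands in for the trained agent.
% In practice, use the agent you passed to train(agent,env,trainOpts).
obsInfo = rlNumericSpec([520 1]);        % assumed observation size (from the question)
actInfo = rlFiniteSetSpec(1:14);         % assumed 14 discrete actions (from the question)
agent   = rlDQNAgent(obsInfo,actInfo);   % agent with a default critic network

% Save the agent to a MAT-file and load it back.
agentFile = fullfile(tempdir,"myagent.mat");   % hypothetical file location
save(agentFile,"agent");
loaded = load(agentFile,"agent");
agent  = loaded.agent;

% Get the critic (the Q-value network wrapper) from the agent,
% then read its learnable parameters (weights and biases).
critic = getCritic(agent);
params = getLearnableParameters(critic);   % returned as a cell array

% Print the size of each parameter array.
for k = 1:numel(params)
    fprintf("Parameter %d: %s\n", k, mat2str(size(params{k})));
end

% Optionally store the raw weights on their own.
save(fullfile(tempdir,"dqn_weights.mat"),"params");
```

If you need the parameters together with their layer names, getModel(critic) returns the underlying network object, whose exact type depends on the release. Note also that 'SaveAgentDirectory' in rlTrainingOptions only takes effect when 'SaveAgentCriteria' is set to something other than its default of "none", so saving the agent explicitly after train, as above, is the simpler route.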
