
How do I get DQN to output the policy I want?

15 views (last 30 days)
zhou wen on 15 May 2024
Answered: praguna manvi on 17 Jul 2024 at 8:20
I'm solving a problem with DQN. The environment currently has 10 possible moves, 8 states, and 20 rounds per run. I want to keep my problem variables to a minimum. The optima…

Answers (1)

praguna manvi on 17 Jul 2024 at 8:20
Hi,
Here is sample code showing how you could train a DQN agent with the inputs above. I am assuming a random step function and reset function for a simplified example:
% Define the environment dimensions
numStates  = 8;
numActions = 10;
numRounds  = 20;    % rounds per run, per the question

% Define the observation and action spaces
obsInfo = rlNumericSpec([numStates 1]);
actInfo = rlFiniteSetSpec(1:numActions);

% Create the custom environment from the local functions below
env = rlFunctionEnv(obsInfo, actInfo, @myStepFunction, @myResetFunction);

% Define the critic network: one Q-value output per action
statePath = [
    featureInputLayer(numStates, 'Normalization', 'none', 'Name', 'state')
    fullyConnectedLayer(24, 'Name', 'fc1')
    reluLayer('Name', 'relu1')
    fullyConnectedLayer(24, 'Name', 'fc2')
    reluLayer('Name', 'relu2')
    fullyConnectedLayer(numActions, 'Name', 'fc3')];
criticNetwork = dlnetwork(statePath);

% Create the vector Q-value critic (the dlnetwork-based API)
critic = rlVectorQValueFunction(criticNetwork, obsInfo, actInfo, ...
    'ObservationInputNames', 'state');

% Configure the DQN agent
criticOpts = rlOptimizerOptions('LearnRate', 1e-3, 'GradientThreshold', 1);
agentOpts = rlDQNAgentOptions(...
    'SampleTime', 1, ...
    'DiscountFactor', 0.99, ...
    'ExperienceBufferLength', 10000, ...
    'MiniBatchSize', 256, ...
    'CriticOptimizerOptions', criticOpts);
agent = rlDQNAgent(critic, agentOpts);

% Train the agent; each episode runs for the 20 rounds per run
trainOpts = rlTrainingOptions(...
    'MaxEpisodes', 20, ...
    'MaxStepsPerEpisode', numRounds, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
trainingStats = train(agent, env, trainOpts);

% Step function: replace the body with your transition and reward logic
function [nextObs, reward, isDone, loggedSignals] = myStepFunction(action, loggedSignals)
    nextObs = randi([1, 8], [8, 1]);   % random next state for this example
    reward  = randi([-1, 1]);          % random reward in {-1, 0, 1}
    isDone  = false;                   % episodes end at MaxStepsPerEpisode
end

% Reset function: I have used a random initial state
function [initialObs, loggedSignals] = myResetFunction()
    initialObs = randi([1, 8], [8, 1]);
    loggedSignals = [];
end
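Once training completes, you can check what policy the agent has learned. As a minimal sketch (assuming the agent and env variables from the script above), you can query the greedy action for a given observation with getAction, roll out a full episode with sim, or generate a standalone policy function with generatePolicyFunction:

% Query the greedy action for a sample observation
obs = randi([1, 8], [8, 1]);        % example observation, same format as the step function
action = getAction(agent, {obs});   % returns a cell array containing the chosen action

% Simulate one 20-step episode with the trained agent
simOpts = rlSimulationOptions('MaxSteps', 20);
experience = sim(env, agent, simOpts);

% Generate a standalone evaluatePolicy.m for later use
generatePolicyFunction(agent);

If the learned policy is not the one you expect, the usual levers are the reward definition in the step function, the exploration settings (EpsilonGreedyExploration in rlDQNAgentOptions), and the training budget (MaxEpisodes in rlTrainingOptions).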
For a detailed example, please refer to the documentation on training a custom PG agent.

Release: R2023a
