rlVectorQValueFunction with a custom network for a DQN agent
Delprat Sebastien
August 9, 2023
Commented: Delprat Sebastien, August 25, 2023
Question in short: how do I design a DQN agent that uses an rlVectorQValueFunction critic based on a custom dlnetwork, for an environment whose observations consist of an image and some numeric features?
The Reinforcement Learning Toolbox documentation needs many more examples that are even an epsilon more complex than the vanilla ones.
The following code triggers the error. My analysis so far is below.
Thanks for your help.
NB: I'm using R2022b, in case that matters.
% My problem is for a custom environment, but here is one line of code to get
% a working environment
env = rlPredefinedEnv('SimplePendulumWithImage-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
% For brevity, get a working network (instead of my own code; keep in mind that I do not want
% to use the network created automatically, but the network I designed myself)
% So here is a way to get a working network in 3 lines of code
initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);
agent = rlDQNAgent(obsInfo,actInfo,initOpts);
criticNet = getModel(getCritic(agent)); % Here I get a network example
analyzeNetwork(criticNet); % Network is fine and consistent with env
Now the fun part begins:
% This does not work
critic = rlVectorQValueFunction(criticNet,obsInfo,actInfo);
Error using rl.internal.validate.mapNetworkInput
Model input sizes must match the dimensions specified in the corresponding observation and action info specifications.
Error in rl.internal.validate.mapFunctionObservationInput (line 24)
inputOrder = rl.internal.validate.mapNetworkInput(specDim,networkInputSize);
Error in rlVectorQValueFunction (line 101)
modelInputMap = rl.internal.validate.mapFunctionObservationInput(model,observationInfo,nameValueArgs.ObservationInputNames);
Error in exe2 (line 19)
critic = rlVectorQValueFunction(criticNet,obsInfo,actInfo);
The network diagram shows that the input is 50x50x1x1, which is correct. Note that the image input layer adds one dimension for the channel.
Let me emphasize that the rlVectorQValueFunction documentation only provides an example with a network that has feature inputs only. It gives no information about the data dimensions, and specifically none about how to handle the channel dimension of image data:
"The network must have only the observation channels as inputs and a single output layer having as many elements as the number of possible discrete actions. Each element of the output vector approximates the value of executing the corresponding action starting from the currently observed state."
We can check the observation sizes that the network must match:
>> obsInfo(1)
ans =
rlNumericSpec with properties:
LowerLimit: 0
UpperLimit: 1
Name: "pendImage"
Description: [0×0 string]
Dimension: [50 50]
DataType: "double"
>> obsInfo(2)
ans =
rlNumericSpec with properties:
LowerLimit: -Inf
UpperLimit: Inf
Name: "angularRate"
Description: [0×0 string]
Dimension: [1 1]
DataType: "double"
Where is the problem?
I debugged rlVectorQValueFunction up to the point where it checks the size of the network input data (50x50x1) against the observation size (50x50). Because the image input layer always adds a channel dimension, the number of dimensions does not match.
The problem is in mapFunctionObservationInput (line 24):
inputOrder = rl.internal.validate.mapNetworkInput(specDim,networkInputSize);
specDim =
1×2 cell array
{[50 50]} {[1 1]}
K>> networkInputSize
networkInputSize =
1×2 cell array
{[50 50 1]} {[1]}
So the network input size clearly has a channel dimension that the validation does not expect.
Where does the documentation explain any of this? How can I fix this issue?
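A quick way to see this mismatch without stepping into the toolbox internals is to print the Dimension of each observation specification next to the InputSize of each network input layer. This is a minimal sketch that reuses the setup from the question; it only reads standard properties (obsInfo Dimension and Name, the dlnetwork InputNames and Layers, and the input layers' InputSize).

```matlab
% Recreate the setup from the question
env = rlPredefinedEnv('SimplePendulumWithImage-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);
agent = rlDQNAgent(obsInfo,actInfo,initOpts);
criticNet = getModel(getCritic(agent));

% Dimensions declared by the observation specifications
for k = 1:numel(obsInfo)
    fprintf("obsInfo(%d) ""%s"": Dimension = %s\n", k, obsInfo(k).Name, ...
        mat2str(obsInfo(k).Dimension));
end

% InputSize of each network input layer, looked up by layer name
layerNames = {criticNet.Layers.Name};
for k = 1:numel(criticNet.InputNames)
    inLayer = criticNet.Layers(strcmp(layerNames, criticNet.InputNames{k}));
    fprintf("Input layer ""%s"" (%s): InputSize = %s\n", ...
        criticNet.InputNames{k}, class(inLayer), mat2str(inLayer.InputSize));
end
```

For the image observation this should show a 2-element Dimension next to a 3-element InputSize (the trailing 1 is the channel), consistent with the specDim and networkInputSize values above.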
Accepted Answer
Ayush Aniket
August 25, 2023
Edited: Ayush Aniket, August 25, 2023
The error you are seeing arises because the environment has more than one observation channel. A workaround is to use the second syntax for creating the critic described on the following documentation page:
The code is as follows:
% This works
criticNet.summary %used for getting names of the input layers
critic = rlVectorQValueFunction(criticNet,obsInfo,actInfo,"ObservationInputNames",["input_1","input_2"]);
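Putting the workaround together with a quick check: the sketch below builds the critic with ObservationInputNames, evaluates it on one random observation to confirm that it returns one Q value per discrete action, and then creates the DQN agent from that critic. The names "input_1" and "input_2" are the ones reported by criticNet.summary in this answer; they may differ for other networks or releases, and the order of ObservationInputNames should follow the order of obsInfo.

```matlab
% Setup from the question (the default agent is only used to obtain a network)
env = rlPredefinedEnv('SimplePendulumWithImage-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
initOpts = rlAgentInitializationOptions(NumHiddenUnit=64);
tmpAgent = rlDQNAgent(obsInfo,actInfo,initOpts);
criticNet = getModel(getCritic(tmpAgent));

% Map network inputs to observation channels by name, in obsInfo order.
% These names come from criticNet.summary and may differ for other networks.
critic = rlVectorQValueFunction(criticNet,obsInfo,actInfo, ...
    ObservationInputNames=["input_1","input_2"]);

% Evaluate on one random observation: expect one Q value per action
obs = {rand(obsInfo(1).Dimension), rand(obsInfo(2).Dimension)};
q = getValue(critic,obs);
fprintf("%d actions, %d Q values\n", numel(actInfo.Elements), numel(q));

% Create the DQN agent from the critic
agent = rlDQNAgent(critic);
```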
You may find the following documentation page useful:
It shows how to construct a dlnetwork for image inputs. The example uses rlQValueFunction rather than rlVectorQValueFunction, but the only difference is the additional network input path for the action; the rest remains the same. Look at the code generated by Deep Network Designer.
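For a network you design yourself, rather than one taken from a default agent, the same approach should apply. The sketch below builds a small two-branch dlnetwork: the image branch uses imageInputLayer with an explicit trailing channel of 1 ([obsInfo(1).Dimension 1]), and the angular-rate branch uses featureInputLayer. The layer names ("img", "rate", and so on) and sizes (8 filters, 32 units) are hypothetical choices for illustration only. Because you name the input layers yourself, ObservationInputNames can be given directly. The fully connected layers reduce each branch to a feature vector, which is assumed here to allow concatenation along dimension 1.

```matlab
env = rlPredefinedEnv('SimplePendulumWithImage-Discrete');
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
nAct = numel(actInfo.Elements);   % one network output per discrete action

% Image branch: the spec is [50 50], the layer expects [H W C], so add C = 1
imgPath = [
    imageInputLayer([obsInfo(1).Dimension 1], Normalization="none", Name="img")
    convolution2dLayer(5, 8, Stride=2, Name="conv1")
    reluLayer(Name="relu1")
    fullyConnectedLayer(32, Name="fcImg")
    ];

% Scalar branch for the angular rate
ratePath = [
    featureInputLayer(prod(obsInfo(2).Dimension), Name="rate")
    fullyConnectedLayer(32, Name="fcRate")
    ];

% Merge both branches and output one Q value per action
commonPath = [
    concatenationLayer(1, 2, Name="concat")
    reluLayer(Name="relu2")
    fullyConnectedLayer(nAct, Name="q")
    ];

lg = layerGraph(imgPath);
lg = addLayers(lg, ratePath);
lg = addLayers(lg, commonPath);
lg = connectLayers(lg, "fcImg", "concat/in1");
lg = connectLayers(lg, "fcRate", "concat/in2");
net = dlnetwork(lg);

% "img" is tied to obsInfo(1) and "rate" to obsInfo(2) by this ordering
critic = rlVectorQValueFunction(net, obsInfo, actInfo, ...
    ObservationInputNames=["img","rate"]);

% Quick check on one random observation
q = getValue(critic, {rand(obsInfo(1).Dimension), rand(obsInfo(2).Dimension)});
agent = rlDQNAgent(critic);
```

Keep the entries of ObservationInputNames in the same order as obsInfo; that ordering is what associates each named input layer with its observation channel.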
Hope it helps.