離散cartpole環境が正常に学習しない

33 ビュー (過去 30 日間)
ryuuzi
ryuuzi 2024 年 10 月 25 日 11:39
回答済み: Hiro Yoshino 2024 年 11 月 5 日 7:17
「create custom environment from class template」を参考に離散cartpole環境を作成して、強化学習デザイナーにインポートさせてみました。
しかし、学習が安定して収束しませんでした。試行錯誤してみましたが、対処法が思いつきませんでした。
教えてください
classdef matlab < rl.env.MATLABEnvironment
    % MATLAB  Discrete-action cart-pole environment for reinforcement learning.
    %
    %   The environment state is [x; dx; theta; dtheta]. At every step the
    %   agent picks one of two actions, -MaxForce or +MaxForce, which is
    %   applied to the cart as a force. The episode terminates when |x|
    %   exceeds XThreshold or |theta| exceeds ThetaThresholdRadians.
    %
    %   NOTE(review): the class name `matlab` may collide with the built-in
    %   `matlab.*` package namespace; consider renaming the class (and its
    %   file) once callers can be updated - TODO confirm.

    properties
        % Acceleration due to gravity in m/s^2
        Gravity = 9.8
        % Mass of the cart
        MassCart = 1.0
        % Mass of the pole
        MassPole = 0.1
        % Half the length of the pole
        Length = 0.5
        % Max force the input can apply (the two actions are -MaxForce and +MaxForce)
        MaxForce = 10
        % Sample time
        Ts = 0.02
        % Angle at which to fail the episode
        ThetaThresholdRadians = 12 * pi/180
        % Distance at which to fail the episode
        XThreshold = 2.4
        % Reward each time step the cart-pole is balanced
        RewardForNotFalling = 1
        % Penalty when the cart-pole fails to balance
        PenaltyForFalling = -5
    end

    properties
        % System state [x,dx,theta,dtheta]'
        State = zeros(4,1)
    end

    properties (Access = protected)
        % Termination flag of the most recent step; read by getReward
        IsDone = false
    end

    properties (Transient, Access = private)
        % Handle to the visualizer created by plot (empty until first use)
        Visualizer = []
    end

    methods
        function this = matlab()
            % Construct the environment with a 4x1 numeric observation and
            % a two-element finite action set.
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';

            % Placeholder elements; rescaled to +-MaxForce just below.
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';

            this = this@rl.env.MATLABEnvironment(ObservationInfo, ActionInfo);

            % Scale the action set by the current MaxForce.
            updateActionInfo(this);
        end

        function set.State(this,state)
            % Accept any finite real 4-element vector; store it as a double column.
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end

        function set.Length(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Length');
            this.Length = val;
            notifyEnvUpdated(this);
        end

        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end

        function set.MassCart(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassCart');
            this.MassCart = val;
        end

        function set.MassPole(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MassPole');
            this.MassPole = val;
        end

        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            % Keep the action set in sync with the new force magnitude.
            updateActionInfo(this);
        end

        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end

        function set.ThetaThresholdRadians(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','ThetaThresholdRadians');
            this.ThetaThresholdRadians = val;
            notifyEnvUpdated(this);
        end

        function set.XThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','XThreshold');
            this.XThreshold = val;
            notifyEnvUpdated(this);
        end

        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end

        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end

        function [observation,reward,isdone,loggedSignals] = step(this,action)
            % Advance the simulation by one sample time Ts.
            %
            %   action        - one of this.ActionInfo.Elements (validated by getForce)
            %   observation   - next state [x; dx; theta; dtheta]
            %   reward        - RewardForNotFalling, or PenaltyForFalling on termination
            %   isdone        - true when |x| or |theta| exceeds its threshold
            %   loggedSignals - always empty
            loggedSignals = [];

            % Get action
            force = getForce(this,action);

            % Unpack state vector (x itself does not enter the dynamics)
            state = this.State;
            x_dot = state(2);
            theta = state(3);
            theta_dot = state(4);

            % Apply motion equations
            costheta = cos(theta);
            sintheta = sin(theta);
            totalmass = this.MassCart + this.MassPole;
            polemasslength = this.MassPole*this.Length;
            temp = (force + polemasslength * theta_dot * theta_dot * sintheta) / totalmass;
            thetaacc = (this.Gravity * sintheta - costheta* temp) / (this.Length * (4.0/3.0 - this.MassPole * costheta * costheta / totalmass));
            xacc = temp - polemasslength * thetaacc * costheta / totalmass;

            % Euler integration
            observation = state + this.Ts.*[x_dot;xacc;theta_dot;thetaacc];
            this.State = observation;

            % Termination is evaluated on the updated state.
            x = observation(1);
            theta = observation(3);
            isdone = abs(x) > this.XThreshold || abs(theta) > this.ThetaThresholdRadians;
            % Must be stored before getReward, which reads this.IsDone.
            this.IsDone = isdone;

            % Get reward
            reward = getReward(this,x,force);
        end

        function initialState = reset(this)
            % Reset to the upright position with a small random pole angle.
            % Theta is drawn uniformly from (+- .05 rad); x, dx and dtheta are zero.
            T0 = 2*0.05*rand - 0.05;
            % Thetadot
            Td0 = 0;
            % X
            X0 = 0;
            % Xdot
            Xd0 = 0;
            initialState = [X0;Xd0;T0;Td0];
            this.State = initialState;
        end

        function varargout = plot(this)
            % Visualizes the environment. Creates the visualizer on first
            % use (or after it was closed), otherwise raises the existing
            % one. Returns the visualizer handle when an output is requested.
            if isempty(this.Visualizer) || ~isvalid(this.Visualizer)
                this.Visualizer = rl.env.viz.CartPoleVisualizer(this);
            else
                bringToFront(this.Visualizer);
            end
            if nargout
                varargout{1} = this.Visualizer;
            end
        end
    end

    methods (Access = protected)
        function force = getForce(this,action)
            % Validate the action against the finite action set and return
            % it unchanged as the applied force.
            if ~ismember(action,this.ActionInfo.Elements)
                error(message('rl:env:CartPoleDiscreteInvalidAction',sprintf('%g',-this.MaxForce),sprintf('%g',this.MaxForce)));
            end
            force = action;
        end

        function updateActionInfo(this)
            % Update the action info based on max force.
            % The set must be symmetric: {-MaxForce, +MaxForce}. The previous
            % multiplier [-1 10] produced {-MaxForce, +10*MaxForce}, i.e. one
            % action ten times stronger than the other, which destabilises
            % the plant and prevents training from converging.
            this.ActionInfo.Elements = this.MaxForce*[-1 1];
        end

        function Reward = getReward(this,~,~)
            % Constant reward while balanced, penalty on the terminating
            % step. Relies on this.IsDone having been set by step.
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end
    end
end

採用された回答

Hiro Yoshino
Hiro Yoshino 2024 年 11 月 5 日 7:17
Reinforcement Learning Toolbox に定義済みの離散 cartpole 環境が有るので、動作するものを開いて中身を調べてみると参考になる (答えが有る) かもしれません

その他の回答 (0 件)

カテゴリ

Help Center および File Exchange で「ビッグ データの処理」についてさらに検索

タグ

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!