カートポール システムの平衡化のための DQN エージェントの学習
この例では、MATLAB® でモデル化されたカートポール システムの平衡化を行うために深層 Q 学習ネットワーク (DQN) エージェントに学習させる方法を示します。
DQN エージェントの詳細については、深層 Q ネットワーク (DQN) エージェントを参照してください。Simulink® で DQN エージェントに学習させる例については、振子の振り上げと平衡化のための DQN エージェントの学習を参照してください。
カートポール MATLAB 環境
この例の強化学習環境は、摩擦のないトラックに沿って移動するカート上の非駆動ジョイントに取り付けられたポールです。学習の目標は、ポールが倒れず直立した状態を維持することです。
この環境では、次のようにします。
ポールが倒立平衡状態となっている位置を 0 ラジアンとし、鉛直下向きとなっている位置を
pi
ラジアンとする。ポールは、-0.05 ~ 0.05 ラジアンの初期角度で直立して開始する。
エージェントから環境への力のアクション信号は、-10 または 10 N とする。
環境からの観測値は、カートの位置と速度、ポールの角度、およびポールの角度の微分とする。
ポールが垂直から 12 度を超えるか、カートが元の位置から 2.4 m を超えて移動した場合、エピソードの終了とする。
ポールが直立を保っているすべてのタイム ステップに対し、+1 の報酬が提供される。ポールが落下すると、-5 のペナルティが適用される。
このモデルの詳細については、Load Predefined Control System Environmentsを参照してください。
環境インターフェイスの作成
システム用の事前定義された環境インターフェイスを作成します。
env = rlPredefinedEnv("CartPole-Discrete")
env = CartPoleDiscreteAction with properties: Gravity: 9.8000 MassCart: 1 MassPole: 0.1000 Length: 0.5000 MaxForce: 10 Ts: 0.0200 ThetaThresholdRadians: 0.2094 XThreshold: 2.4000 RewardForNotFalling: 1 PenaltyForFalling: -5 State: [4x1 double]
このインターフェイスは、エージェントが 2 つの有効な力の値 (-10 または 10 N) のいずれかをカートに適用できる離散行動空間をもちます。
観測仕様とアクション仕様の情報を取得します。
obsInfo = getObservationInfo(env)
obsInfo = rlNumericSpec with properties: LowerLimit: -Inf UpperLimit: Inf Name: "CartPole States" Description: "x, dx, theta, dtheta" Dimension: [4 1] DataType: "double"
actInfo = getActionInfo(env)
actInfo = rlFiniteSetSpec with properties: Elements: [-10 10] Name: "CartPole Action" Description: [0x0 string] Dimension: [1 1] DataType: "double"
再現性をもたせるために、乱数発生器のシードを固定します。
rng(0)
DQN エージェントの作成
DQN エージェントは、ベクトル Q 値関数クリティックを使用できます。これは通常、同等の単出力クリティックよりも効率的です。ベクトル Q 値関数クリティックは、入力として観測値をもち、出力として状態アクション値をもちます。各出力要素は、観測入力によって示された状態から対応する離散行動を実行した場合の累積長期報酬の期待値を表します。価値関数の作成の詳細については、Create Policies and Value Functionsを参照してください。
クリティック内で Q 値関数を近似するには、1 つの入力チャネル (観測された 4 次元の状態のベクトル) と 2 つの要素 (1 つは 10 N アクション用、もう 1 つは –10 N アクション用) から成る 1 つの出力チャネルをもつニューラル ネットワークを使用します。ネットワークを layer オブジェクトの配列として定義し、環境仕様オブジェクトから観測空間の次元と可能なアクションの数を取得します。
net = [ featureInputLayer(obsInfo.Dimension(1)) fullyConnectedLayer(20) reluLayer fullyConnectedLayer(length(actInfo.Elements)) ];
dlnetwork
に変換し、重みの数を表示します。
net = dlnetwork(net); summary(net)
Initialized: true Number of learnables: 142 Inputs: 1 'input' 4 features
ネットワーク構成を表示します。
plot(net)
net
と環境仕様を使用して、クリティック近似器を作成します。詳細については、rlVectorQValueFunction
を参照してください。
critic = rlVectorQValueFunction(net,obsInfo,actInfo);
ランダムな観測入力を使用して、クリティックをチェックします。
getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector
-0.2257
0.4299
critic
を使用して、DQN エージェントを作成します。詳細については、rlDQNAgent
を参照してください。
agent = rlDQNAgent(critic);
ランダムな観測入力を使用して、エージェントをチェックします。
getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
{[10]}
クリティックの学習オプションを含む、DQN エージェントのオプションを指定します。あるいは、rlDQNAgentOptions
オブジェクトとrlOptimizerOptions
オブジェクトを使用することもできます。
agent.AgentOptions.UseDoubleDQN = false; agent.AgentOptions.TargetSmoothFactor = 1; agent.AgentOptions.TargetUpdateFrequency = 4; agent.AgentOptions.ExperienceBufferLength = 1e5; agent.AgentOptions.MiniBatchSize = 256; agent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3; agent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;
エージェントの学習
エージェントに学習させるには、まず、学習オプションを指定します。この例では、次のオプションを使用します。
最大 1000 個のエピソードを含む 1 つの学習セッションを実行 (各エピソードは最大 500 タイム ステップ持続)。
[強化学習の学習モニター] ダイアログ ボックスに学習の進行状況を表示し (
Plots
オプションを設定)、コマンド ラインの表示を無効化 (Verbose
オプションをfalse
に設定)。480 を超える移動平均累積報酬をエージェントが受け取ったときに学習を停止。この時点で、エージェントはカートポール システムを直立位置で平衡化できるようになります。
詳細については、rlTrainingOptions
を参照してください。
trainOpts = rlTrainingOptions(... MaxEpisodes=1000, ... MaxStepsPerEpisode=500, ... Verbose=false, ... Plots="training-progress",... StopTrainingCriteria="AverageReward",... StopTrainingValue=480);
学習中またはシミュレーション中に、関数 plot
を使用してカートポール システムを可視化できます。
plot(env)
関数 train
を使用して、エージェントに学習させます。このエージェントの学習は計算量が多いプロセスのため、完了するのに数分かかります。この例の実行時間を節約するために、doTraining
を false
に設定して事前学習済みのエージェントを読み込みます。エージェントに学習させるには、doTraining
を true
に設定します。
doTraining = false; if doTraining % Train the agent. trainingStats = train(agent,env,trainOpts); else % Load the pretrained agent for the example. load("MATLABCartpoleDQNMulti.mat","agent") end
DQN エージェントのシミュレーション
学習済みエージェントの性能を検証するには、カートポール環境内でエージェントをシミュレーションします。エージェントのシミュレーションの詳細については、rlSimulationOptions
および sim
を参照してください。シミュレーション時間を 500 ステップに増やしても、エージェントはカートポールを平衡化することができています。
simOptions = rlSimulationOptions(MaxSteps=500); experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward)
totalReward = 500
参考
アプリ
関数
オブジェクト
関連する例
- 振子の振り上げと平衡化のための DQN エージェントの学習
- Train PG Agent to Balance Cart-Pole System
- Train Reinforcement Learning Agents