Main Content

カートポール システムの平衡化のための DQN エージェントの学習

この例では、MATLAB® でモデル化されたカートポール システムの平衡化を行うために深層 Q 学習ネットワーク (DQN) エージェントに学習させる方法を示します。

DQN エージェントの詳細については、深層 Q ネットワーク (DQN) エージェントを参照してください。Simulink® で DQN エージェントに学習させる例については、振子の振り上げと平衡化のための DQN エージェントの学習を参照してください。

カートポール MATLAB 環境

この例の強化学習環境は、摩擦のないトラックに沿って移動するカート上の非駆動ジョイントに取り付けられたポールです。学習の目標は、ポールが倒れず直立した状態を維持することです。

この環境では、次のようにします。

  • ポールが倒立平衡状態となっている位置を 0 ラジアンとし、鉛直下向きとなっている位置を pi ラジアンとする。

  • ポールは、-0.05 ~ 0.05 ラジアンの初期角度で直立して開始する。

  • エージェントから環境への力のアクション信号は、-10 または 10 N とする。

  • 環境からの観測値は、カートの位置と速度、ポールの角度、およびポールの角度の微分とする。

  • ポールが垂直から 12 度を超えるか、カートが元の位置から 2.4 m を超えて移動した場合、エピソードの終了とする。

  • ポールが直立を保っているすべてのタイム ステップに対し、+1 の報酬が提供される。ポールが落下すると、-5 のペナルティが適用される。

このモデルの詳細については、Load Predefined Control System Environmentsを参照してください。

環境インターフェイスの作成

システム用の事前定義された環境インターフェイスを作成します。

env = rlPredefinedEnv("CartPole-Discrete")
env = 
  CartPoleDiscreteAction with properties:

                  Gravity: 9.8000
                 MassCart: 1
                 MassPole: 0.1000
                   Length: 0.5000
                 MaxForce: 10
                       Ts: 0.0200
    ThetaThresholdRadians: 0.2094
               XThreshold: 2.4000
      RewardForNotFalling: 1
        PenaltyForFalling: -5
                    State: [4x1 double]

このインターフェイスは、エージェントが 2 つの有効な力の値 (-10 または 10 N) のいずれかをカートに適用できる離散行動空間をもちます。

観測仕様とアクション仕様の情報を取得します。

obsInfo = getObservationInfo(env)
obsInfo = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "CartPole States"
    Description: "x, dx, theta, dtheta"
      Dimension: [4 1]
       DataType: "double"

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-10 10]
           Name: "CartPole Action"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

再現性をもたせるために、乱数発生器のシードを固定します。

rng(0)

DQN エージェントの作成

DQN エージェントは、ベクトル Q 値関数クリティックを使用できます。これは通常、同等の単出力クリティックよりも効率的です。ベクトル Q 値関数クリティックは、入力として観測値をもち、出力として状態アクション値をもちます。各出力要素は、観測入力によって示された状態から対応する離散行動を実行した場合の累積長期報酬の期待値を表します。価値関数の作成の詳細については、Create Policies and Value Functionsを参照してください。

クリティック内で Q 値関数を近似するには、1 つの入力チャネル (観測された 4 次元の状態のベクトル) と 2 つの要素 (1 つは 10 N アクション用、もう 1 つは –10 N アクション用) から成る 1 つの出力チャネルをもつニューラル ネットワークを使用します。ネットワークを layer オブジェクトの配列として定義し、環境仕様オブジェクトから観測空間の次元と可能なアクションの数を取得します。

net = [
    featureInputLayer(obsInfo.Dimension(1))
    fullyConnectedLayer(20)
    reluLayer
    fullyConnectedLayer(length(actInfo.Elements))
    ];

dlnetwork に変換し、重みの数を表示します。

net = dlnetwork(net);
summary(net)
   Initialized: true

   Number of learnables: 142

   Inputs:
      1   'input'   4 features

ネットワーク構成を表示します。

plot(net)

net と環境仕様を使用して、クリティック近似器を作成します。詳細については、rlVectorQValueFunctionを参照してください。

critic = rlVectorQValueFunction(net,obsInfo,actInfo);

ランダムな観測入力を使用して、クリティックをチェックします。

getValue(critic,{rand(obsInfo.Dimension)})
ans = 2x1 single column vector

   -0.2257
    0.4299

critic を使用して、DQN エージェントを作成します。詳細については、rlDQNAgentを参照してください。

agent = rlDQNAgent(critic);

ランダムな観測入力を使用して、エージェントをチェックします。

getAction(agent,{rand(obsInfo.Dimension)})
ans = 1x1 cell array
    {[10]}

クリティックの学習オプションを含む、DQN エージェントのオプションを指定します。あるいは、rlDQNAgentOptionsオブジェクトとrlOptimizerOptionsオブジェクトを使用することもできます。

agent.AgentOptions.UseDoubleDQN = false;
agent.AgentOptions.TargetSmoothFactor = 1;
agent.AgentOptions.TargetUpdateFrequency = 4;
agent.AgentOptions.ExperienceBufferLength = 1e5;
agent.AgentOptions.MiniBatchSize = 256;
agent.AgentOptions.CriticOptimizerOptions.LearnRate = 1e-3;
agent.AgentOptions.CriticOptimizerOptions.GradientThreshold = 1;

エージェントの学習

エージェントに学習させるには、まず、学習オプションを指定します。この例では、次のオプションを使用します。

  • 最大 1000 個のエピソードを含む 1 つの学習セッションを実行 (各エピソードは最大 500 タイム ステップ持続)。

  • [強化学習の学習モニター] ダイアログ ボックスに学習の進行状況を表示し (Plots オプションを設定)、コマンド ラインの表示を無効化 (Verbose オプションを false に設定)。

  • 480 を超える移動平均累積報酬をエージェントが受け取ったときに学習を停止。この時点で、エージェントはカートポール システムを直立位置で平衡化できるようになります。

詳細については、rlTrainingOptionsを参照してください。

trainOpts = rlTrainingOptions(...
    MaxEpisodes=1000, ...
    MaxStepsPerEpisode=500, ...
    Verbose=false, ...
    Plots="training-progress",...
    StopTrainingCriteria="AverageReward",...
    StopTrainingValue=480); 

学習中またはシミュレーション中に、関数 plot を使用してカートポール システムを可視化できます。

plot(env)

関数 train を使用して、エージェントに学習させます。このエージェントの学習は計算量が多いプロセスのため、完了するのに数分かかります。この例の実行時間を節約するために、doTrainingfalse に設定して事前学習済みのエージェントを読み込みます。エージェントに学習させるには、doTrainingtrue に設定します。

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load the pretrained agent for the example.
    load("MATLABCartpoleDQNMulti.mat","agent")
end

DQN エージェントのシミュレーション

学習済みエージェントの性能を検証するには、カートポール環境内でエージェントをシミュレーションします。エージェントのシミュレーションの詳細については、rlSimulationOptions および sim を参照してください。シミュレーション時間を 500 ステップに増やしても、エージェントはカートポールを平衡化することができています。

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

totalReward = sum(experience.Reward)
totalReward = 500

参考

アプリ

関数

オブジェクト

関連する例

詳細