Main Content

振子の振り上げと平衡化のための、イメージ観測を使用した DDPG エージェントの学習

この例では、MATLAB® でモデル化されたイメージ観測を使用して、振子の振り上げと平衡化を行うように深層決定論的方策勾配 (DDPG) エージェントに学習させる方法を説明します。

DDPG エージェントの詳細については、Deep Deterministic Policy Gradient (DDPG) Agents (Reinforcement Learning Toolbox)を参照してください。

イメージ MATLAB 環境を使用した単純な振子

この例では、初期状態で下向きにぶら下がっている摩擦がない単純な振子を、強化学習の環境として使用します。学習の目標は、最小限の制御操作を使用して、振子が倒れず直立した状態を維持することです。

この環境では、次のようにします。

  • 振子が倒立平衡状態となっている位置を 0 ラジアンとし、鉛直下向きとなっている位置を pi ラジアンとする。

  • エージェントから環境へのトルク アクション信号は、–2 ~ 2 N m とする。

  • 環境からの観測値は、振子の質量の位置および振子の角速度を表すイメージとする。

  • 各タイム ステップで与えられる報酬 rt は次のとおりとする。

rt=-(θt2+0.1θt˙2+0.001ut-12)

ここで、以下となります。

  • θt は直立位置からの変位角。

  • θt˙ は変位角の微分。

  • ut-1 は前のタイム ステップからの制御量。

このモデルの詳細については、Load Predefined Control System Environments (Reinforcement Learning Toolbox)を参照してください。

環境インターフェイスの作成

振子用の事前定義された環境インターフェイスを作成します。

env = rlPredefinedEnv("SimplePendulumWithImage-Continuous")
env = 
  SimplePendlumWithImageContinuousAction with properties:

             Mass: 1
        RodLength: 1
       RodInertia: 0
          Gravity: 9.8100
     DampingRatio: 0
    MaximumTorque: 2
               Ts: 0.0500
            State: [2×1 double]
                Q: [2×2 double]
                R: 1.0000e-03

このインターフェイスは、エージェントが –2 ~ 2 N m のトルクを適用できる、連続したアクション空間をもちます。

環境インターフェイスから観測値とアクションの仕様を取得します。

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

再現性をもたせるために、乱数発生器のシードを固定します。

rng(0)

DDPG エージェントの作成

DDPG エージェントは、パラメーター化された Q 値関数近似器を使用して方策の価値を推定します。Q 値関数クリティックは、現在の観測値とアクションを入力として取り、単一のスカラーを出力として返します (状態からのアクションの受け取りに関する割引累積長期報酬の推定値は、現在の観測値に対応し、その後の方策に従います)。

パラメーター化された Q 値関数をクリティック内でモデル化するには、3 つの入力層 (obsInfo で指定された観測チャネル用に 1 つずつ使用し、もう 1 つは actInfo で指定されたアクション チャネル用に使用) と 1 つの出力層 (これはスカラー値を返します) をもつ畳み込みニューラル ネットワーク (CNN) を使用します。

各ネットワーク パスを layer オブジェクトの配列として定義します。また、各パスの入力層と出力層、および追加層と連結層に名前を割り当てます。これらの名前を使用すると、パスを接続してから、ネットワークの入力層と出力層に適切な環境チャネルを明示的に関連付けることができます。表現の作成の詳細については、Create Policies and Value Functions (Reinforcement Learning Toolbox)を参照してください。

hiddenLayerSize1 = 256;
hiddenLayerSize2 = 256;

% Image input path
imgPath = [
    imageInputLayer(obsInfo(1).Dimension, ...
        Name="imgInLyr")
    convolution2dLayer(5,8,Stride=3,Padding=0)
    reluLayer
    convolution2dLayer(5,8,Stride=3,Padding=0)
    reluLayer    
    fullyConnectedLayer(32)
    concatenationLayer(1,2,Name="cat1")
    fullyConnectedLayer(hiddenLayerSize1)
    reluLayer
    fullyConnectedLayer(hiddenLayerSize2)
    additionLayer(2,Name="add")
    reluLayer
    fullyConnectedLayer(1,Name="fc4")
    ];

% d(theta)/dt input path
dthPath = [
    featureInputLayer(prod(obsInfo(2).Dimension), ...
        Name="dthInLyr")
    fullyConnectedLayer(1,Name="fc5", ...
        BiasLearnRateFactor=0, ...
        Bias=0)
    ];

% Action path
actPath =[
    featureInputLayer(prod(obsInfo(2).Dimension), ...
        Name="actInLyr")
    fullyConnectedLayer(hiddenLayerSize2, ...
        Name="fc6", ...
        BiasLearnRateFactor=0, ...
        Bias=zeros(hiddenLayerSize2,1))
    ];

dlnetwork オブジェクトを組み立てます。

criticNetwork = dlnetwork();
criticNetwork = addLayers(criticNetwork,imgPath);
criticNetwork = addLayers(criticNetwork,dthPath);
criticNetwork = addLayers(criticNetwork,actPath);
criticNetwork = connectLayers(criticNetwork,"fc5","cat1/in2");
criticNetwork = connectLayers(criticNetwork,"fc6","add/in2");

クリティック ネットワークの構成を表示し、パラメーターの数を表示します。

plot(criticNetwork)

Figure contains an axes object. The axes object contains an object of type graphplot.

ネットワークを初期化し、パラメーターの数を表示します。

criticNetwork = initialize(criticNetwork);
summary(criticNetwork)
   Initialized: true

   Number of learnables: 81.2k

   Inputs:
      1   'imgInLyr'   50×50×1 images
      2   'dthInLyr'   1 features
      3   'actInLyr'   1 features

指定したニューラル ネットワーク、および環境のアクション仕様と観測仕様を使用して、クリティックを作成します。追加の引数として、観測チャネルとアクション チャネルに接続するネットワーク層の名前も渡します。詳細については、rlQValueFunction (Reinforcement Learning Toolbox)を参照してください。

critic = rlQValueFunction(criticNetwork, ...
    obsInfo,actInfo,...
    ObservationInputNames=["imgInLyr","dthInLyr"], ...
    ActionInputNames="actInLyr");

DDPG エージェントは、連続行動空間において、パラメーター化された決定論的方策を使用します。この方策は、連続決定論的アクターによって実装されます。連続決定論的アクターは、連続行動空間に関するパラメーター化された決定論的方策を実装します。このアクターは、現在の観測値を入力として取り、観測値の決定論的関数であるアクションを出力として返します。

パラメーター化された方策をアクター内でモデル化するには、2 つの入力層 (これは、obsInfo で指定された、2 つの環境観測チャネルのコンテンツを受け取ります) と 1 つの出力層 (これは、actInfo で指定された、環境アクション チャネルへのアクションを返します) をもつニューラル ネットワークを使用します。

ネットワークを layer オブジェクトの配列として定義します。

% Image input path
imgPath = [
    imageInputLayer(obsInfo(1).Dimension, ...        
        Name="imgInLyr")
    convolution2dLayer(5,8,Stride=3,Padding=0)
    reluLayer
    convolution2dLayer(5,8,Stride=3,Padding=0)
    reluLayer    
    fullyConnectedLayer(32)
    concatenationLayer(1,2,Name="cat1")
    fullyConnectedLayer(hiddenLayerSize1)
    reluLayer
    fullyConnectedLayer(hiddenLayerSize2)
    reluLayer
    fullyConnectedLayer(1)
    tanhLayer
    scalingLayer(Name="scale1", ...
        Scale=max(actInfo.UpperLimit))
    ];

% d(theta)/dt input layer
dthPath = [
    featureInputLayer(prod(obsInfo(2).Dimension), ...
        Name="dthInLyr")
    fullyConnectedLayer(1, ...
        Name="fc5", ...
        BiasLearnRateFactor=0, ...
        Bias=0) 
    ];

dlnetwork オブジェクトを組み立てます。

actorNetwork = dlnetwork();
actorNetwork = addLayers(actorNetwork,imgPath);
actorNetwork = addLayers(actorNetwork,dthPath);
actorNetwork = connectLayers(actorNetwork,"fc5","cat1/in2");

アクター ネットワークの構成を表示し、重みの数を表示します。

figure
plot(actorNetwork)

Figure contains an axes object. The axes object contains an object of type graphplot.

ネットワークを初期化し、パラメーターの数を表示します。

actorNetwork = initialize(actorNetwork);
summary(actorNetwork)
   Initialized: true

   Number of learnables: 80.6k

   Inputs:
      1   'imgInLyr'   50×50×1 images
      2   'dthInLyr'   1 features

指定したニューラル ネットワークを使用して、アクターを作成します。詳細については、rlContinuousDeterministicActor (Reinforcement Learning Toolbox)を参照してください。

actor = rlContinuousDeterministicActor(actorNetwork, ...
    obsInfo,actInfo, ...
    ObservationInputNames=["imgInLyr","dthInLyr"]);

rlOptimizerOptions (Reinforcement Learning Toolbox)を使用して、アクターとクリティックのオプションを指定します。

criticOptions = rlOptimizerOptions( ...
    LearnRate=1e-03, ...
    GradientThreshold=1);
actorOptions = rlOptimizerOptions( ...
    LearnRate=1e-04, ...
    GradientThreshold=1);

GPU を使用する学習のパフォーマンスは、バッチ サイズ、ネットワーク構造、およびハードウェアそのものによって左右されます。そのため、GPU を使用しても学習パフォーマンスが必ずしも向上するとは限りません。サポートされている GPU の詳細については、GPU 計算の要件 (Parallel Computing Toolbox)を参照してください。

GPU を使用してアクターに学習させるため、UseGPUCritic true に設定します。

UseGPUCritic = false;
if canUseGPU && UseGPUCritic    
    critic.UseDevice = "gpu";
end

GPU を使用してアクターに学習させるため、UseGPUActor true に設定します。

UseGPUActor = false;
if canUseGPU && UseGPUActor    
    actor.UseDevice = "gpu";
end

再現性をもたせるために、GPU で使用される乱数発生器のシードを固定します。

if canUseGPU && (UseGPUCritic || UseGPUActor)
    gpurng(0)
end

rlDDPGAgentOptions (Reinforcement Learning Toolbox)を使用して DDPG エージェントのオプションを指定します。

agentOptions = rlDDPGAgentOptions(...
    SampleTime=env.Ts,...
    TargetSmoothFactor=1e-3,...
    ExperienceBufferLength=1e6,...
    DiscountFactor=0.99,...
    MiniBatchSize=128);

ドット表記を使用してオプションを指定することもできます。

agentOptions.NoiseOptions.StandardDeviation = 0.6;
agentOptions.NoiseOptions.StandardDeviationDecayRate = 1e-6;
agentOptions.NoiseOptions.StandardDeviationMin = 0.1;

関数近似器オブジェクトの学習オプションを指定します。

agentOptions.CriticOptimizerOptions = criticOptions;
agentOptions.ActorOptimizerOptions = actorOptions;

次に、指定したアクター、クリティック、およびエージェント オプションを使用してエージェントを作成します。詳細については、rlDDPGAgent (Reinforcement Learning Toolbox)を参照してください。

agent = rlDDPGAgent(actor,critic,agentOptions);

エージェントの学習

エージェントに学習させるには、まず、学習オプションを指定します。この例では、次のオプションを使用します。

  • 最大 5000 個のエピソードについて、それぞれ学習を実行 (各エピソードは最大 400 タイム ステップ持続)。

  • [強化学習の学習モニター] ダイアログ ボックスに学習の進行状況を表示 (Plots オプションを設定)。

  • エージェントが -740 よりも大きい評価統計を受け取ったときに学習を停止。この時点で、エージェントは最小限の制御操作を使用して、振子を直立位置で素早く平衡化できるようになります。

詳細については、rlTrainingOptions (Reinforcement Learning Toolbox)を参照してください。

maxepisodes = 5000;
maxsteps = 400;
trainingOptions = rlTrainingOptions(...
    MaxEpisodes=maxepisodes,...
    MaxStepsPerEpisode=maxsteps,...
    Plots="training-progress",...
    StopTrainingCriteria="EvaluationStatistic",...
    StopTrainingValue=-740);

エピソードを 50 個学習するたびにエージェントを評価する評価器を作成します。

evaluator = rlEvaluator(EvaluationFrequency=50, NumEpisodes=1);

学習中またはシミュレーション中に、関数 plot を使用して振子を可視化できます。

plot(env)

Figure Simple Pendulum Visualizer contains 2 axes objects. Axes object 1 contains 2 objects of type line, rectangle. Axes object 2 contains an object of type image.

関数 train (Reinforcement Learning Toolbox) を使用して、エージェントに学習させます。このエージェントの学習は計算量が多いプロセスのため、完了するのに数時間かかります。この例の実行時間を節約するために、doTrainingfalse に設定して事前学習済みのエージェントを読み込みます。エージェントに学習させるには、doTrainingtrue に設定します。

doTraining = false;
if doTraining    
    % Train the agent.
    trainingStats = train(agent,env,trainingOptions, Evaluator=evaluator);
else
    % Load pretrained agent for the example.
    load("SimplePendulumWithImageDDPG.mat","agent")       
end

DDPG エージェントのシミュレーション

学習済みエージェントの性能を検証するには、振子環境内でこれをシミュレートします。エージェントのシミュレーションの詳細については、rlSimulationOptions (Reinforcement Learning Toolbox) および sim (Reinforcement Learning Toolbox) を参照してください。

rng(1); % For reproducibility
simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

Figure Simple Pendulum Visualizer contains 2 axes objects. Axes object 1 contains 2 objects of type line, rectangle. Axes object 2 contains an object of type image.

参考

(Reinforcement Learning Toolbox)

関連するトピック