Main Content

このページの内容は最新ではありません。最新版の英語を参照するには、ここをクリックします。

ディープ ネットワーク デザイナーを使用した DQN エージェントの作成およびイメージ観測値を使用した学習

この例では、MATLAB® でモデル化された振子の振り上げと平衡化を行うことができる深層 Q 学習ネットワーク (DQN) エージェントを作成する方法を説明します。この例では、ディープ ネットワーク デザイナーを使用して DQN エージェントを作成します。DQN エージェントの詳細については、Deep Q-Network (DQN) Agents (Reinforcement Learning Toolbox)を参照してください。

イメージ MATLAB 環境を使用した振子の振り上げ

この例では、初期状態で下向きにぶら下がっている摩擦がない単純な振子を、強化学習の環境として使用します。学習の目標は、最小限の制御操作を使用して、振子が倒れず直立した状態を維持することです。

この環境では、次のようにします。

  • 振子が倒立平衡状態となっている位置を 0 ラジアンとし、鉛直下向きとなっている位置を pi ラジアンとする。

  • エージェントから環境へのトルク アクション信号は、–2 ~ 2 N m の範囲にある 5 つの整数のいずれかの値を取れるものとする。

  • 環境からの観測値は、簡略化された振子のグレースケール イメージおよび振子の角度の微分とする。

  • 各タイム ステップで与えられる報酬 rt は次のとおりとする。

rt=-(θt2+0.1θt˙2+0.001ut-12)

ここで、以下となります。

  • θt は直立位置からの変位角。

  • θt˙ は変位角の微分。

  • ut-1 は前のタイム ステップからの制御量。

このモデルの連続行動空間版の詳細については、Train DDPG Agent to Swing Up and Balance Pendulum with Image Observation (Reinforcement Learning Toolbox)を参照してください。

環境インターフェイスの作成

振子用の事前定義された環境インターフェイスを作成します。

env = rlPredefinedEnv("SimplePendulumWithImage-Discrete");

このインターフェイスは 2 つの観測値をもちます。1 つ目の観測値は、"pendImage" という名前で、50 x 50 のグレースケール イメージです。

obsInfo = getObservationInfo(env);
obsInfo(1)
ans = 
  rlNumericSpec with properties:

     LowerLimit: 0
     UpperLimit: 1
           Name: "pendImage"
    Description: [0x0 string]
      Dimension: [50 50]
       DataType: "double"

2 つ目の観測値は、"angularRate" という名前で、振子の角速度です。

obsInfo(2)
ans = 
  rlNumericSpec with properties:

     LowerLimit: -Inf
     UpperLimit: Inf
           Name: "angularRate"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

このインターフェイスは、エージェントが 5 つの可能なトルク値 (-2、-1、0、1、または 2 N·m) のいずれかを振子に適用できる、離散行動空間をもちます。

actInfo = getActionInfo(env)
actInfo = 
  rlFiniteSetSpec with properties:

       Elements: [-2 -1 0 1 2]
           Name: "torque"
    Description: [0x0 string]
      Dimension: [1 1]
       DataType: "double"

再現性をもたせるために、乱数発生器のシードを固定します。

rng(0)

ディープ ネットワーク デザイナーを使用したクリティック ネットワークの構築

DQN エージェントは、パラメーター化された Q 値関数近似器を使用して方策の価値を推定します。DQN エージェントは離散行動空間をもちますが、(多出力の) ベクトル Q 値関数クリティックを使用できます。これは通常、単出力クリティックよりも効率的です。ただし、この例では、標準の単出力 Q 値関数クリティックを使用します。

パラメーター化された Q 値関数をクリティック内でモデル化するには、3 つの入力層 (そのうち 2 つは obsInfo で指定された観測チャネル用で、もう 1 つは actInfo で指定されたアクション チャネル用) と 1 つの出力層 (これはスカラー値を返します) をもつニューラル ネットワークを使用します。深層ニューラル ネットワークに基づく Q 値関数表現の作成の詳細については、Create Policies and Value Functions (Reinforcement Learning Toolbox)を参照してください。

ディープ ネットワーク デザイナーアプリを使用して、クリティック ネットワークを対話的に構築します。これを行うには、まず、観測値とアクションごとに個別の入力パスを作成します。これらのパスは、各入力から下位レベルの特徴を学習します。その後、入力パスからの出力を結合する共通の出力パスを作成します。

イメージ観測パスの作成

イメージ観測パスを作成するには、まず、imageInputLayer[層のライブラリ] ペインからキャンバスにドラッグします。層の [InputSize] をイメージ観測用に 50,50,1 と設定し、[Normalization]none に設定します。

次に、convolution2DLayer をキャンバスにドラッグし、この層の入力を imageInputLayer の出力に接続します。2 個のフィルターをもち (NumFilters プロパティ)、高さと幅が 10 である (FilterSize プロパティ) 畳み込み層を作成します。また、水平方向と垂直方向のストライドを 5 とします (Stride プロパティ)。

最後に、reLULayer 層と fullyConnectedLayer 層の組を 2 つ使用して、イメージ パスのネットワークを完成させます。1 番目と 2 番目の fullyConnectedLayer 層の出力サイズは、それぞれ 400 と 300 です。

すべての入力パスと出力パスの作成

同様の方法で、その他の入力パスと出力パスを作成します。この例では、次のオプションを使用します。

角速度パス (スカラー入力):

  • imageInputLayerInputSize1,1 に設定し、Normalizationnone に設定。

  • fullyConnectedLayerOutputSize400 に設定。

  • reLULayer

  • fullyConnectedLayerOutputSize300 に設定。

アクション パス (スカラー入力):

  • imageInputLayerInputSize1,1 に設定し、Normalizationnone に設定。

  • fullyConnectedLayerOutputSize300 に設定。

出力パス:

  • additionLayer — すべての入力パスの出力をこの層の入力に接続。

  • reLULayer

  • fullyConnectedLayerOutputSize1 に設定 (スカラーの価値関数用)。

ディープ ネットワーク デザイナーからのネットワークのエクスポート

ネットワークを MATLAB ワークスペースにエクスポートするには、ディープ ネットワーク デザイナー[エクスポート] をクリックします。ディープ ネットワーク デザイナーは、ネットワーク層を格納する新しい変数としてネットワークをエクスポートします。このネットワーク層変数を使用して、クリティック表現を作成できます。

または、このネットワークと同等の MATLAB コードを生成するために、[エクスポート]、[コード生成] をクリックします。

生成されるコードは次のとおりです。

lgraph = layerGraph();

tempLayers = [
    imageInputLayer([1 1 1],"Name","angularRate","Normalization","none")
    fullyConnectedLayer(400,"Name","dtheta_fc1")
    reluLayer("Name","dtheta_relu1")
    fullyConnectedLayer(300,"Name","dtheta_fc2")];
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    imageInputLayer([1 1 1],"Name","torque","Normalization","none")
    fullyConnectedLayer(300,"Name","torque_fc1")];
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    imageInputLayer([50 50 1],"Name","pendImage","Normalization","none")
    convolution2dLayer([10 10],2,"Name","img_conv1","Padding","same","Stride",[5 5])
    reluLayer("Name","relu_1")
    fullyConnectedLayer(400,"Name","critic_theta_fc1")
    reluLayer("Name","theta_relu1")
    fullyConnectedLayer(300,"Name","critic_theta_fc2")];
lgraph = addLayers(lgraph,tempLayers);

tempLayers = [
    additionLayer(3,"Name","addition")
    reluLayer("Name","relu_2")
    fullyConnectedLayer(1,"Name","stateValue")];
lgraph = addLayers(lgraph,tempLayers);

lgraph = connectLayers(lgraph,"torque_fc1","addition/in3");
lgraph = connectLayers(lgraph,"critic_theta_fc2","addition/in1");
lgraph = connectLayers(lgraph,"dtheta_fc2","addition/in2");

クリティック ネットワークの構成を表示します。

figure
plot(lgraph)

Figure contains an axes object. The axes object contains an object of type graphplot.

dlnetwork オブジェクトに変換し、パラメーターの数を表示します。

net = dlnetwork(lgraph);
summary(net)
   Initialized: true

   Number of learnables: 322.9k

   Inputs:
      1   'angularRate'   1x1x1 images
      2   'torque'        1x1x1 images
      3   'pendImage'     50x50x1 images

ニューラル ネットワーク、アクション、および観測仕様を使用してクリティックを作成し、観測値とアクションのチャネルに接続する入力層に名前を付けます。詳細については、rlQValueFunction (Reinforcement Learning Toolbox)を参照してください。

critic = rlQValueFunction(net,obsInfo,actInfo,...
    "ObservationInputNames",["pendImage","angularRate"], ...
    "ActionInputNames","torque");

rlOptimizerOptions を使用して、クリティックのオプションを指定します。

criticOpts = rlOptimizerOptions(LearnRate=1e-03,GradientThreshold=1);

rlDQNAgentOptions (Reinforcement Learning Toolbox)を使用して DQN エージェントのオプションを指定します。アクターとクリティックの学習オプションを含めます。

agentOpts = rlDQNAgentOptions(...
    UseDoubleDQN=false,...    
    CriticOptimizerOptions=criticOpts,...
    ExperienceBufferLength=1e6,... 
    SampleTime=env.Ts);

ドット表記を使用してエージェントのオプションを設定または変更することもできます。

agentOpts.EpsilonGreedyExploration.EpsilonDecay = 1e-5;

または、最初にエージェントを作成してから、ドット表記を使用してそのオプションを変更することもできます。

クリティックとエージェントのオプションのオブジェクトを使用して、DQN エージェントを作成します。詳細については、rlDQNAgent (Reinforcement Learning Toolbox)を参照してください。

agent = rlDQNAgent(critic,agentOpts);

エージェントの学習

エージェントに学習させるには、まず、学習オプションを指定します。この例では、次のオプションを使用します。

  • 最大 5000 個のエピソードについて、それぞれ学習を実行 (各エピソードは最大 500 タイム ステップ持続)。

  • Episode Manager のダイアログ ボックスに学習の進行状況を表示 (Plots オプションを設定) し、コマンド ラインの表示を無効化 (Verbose オプションを false に設定)。

  • 連続する 5 個のエピソードの既定のウィンドウ長について、-1000 を超える平均累積報酬をエージェントが受け取ったときに学習を停止。この時点で、エージェントは最小限の制御操作を使用して、振子を直立位置で素早く平衡化できるようになります。

詳細については、rlTrainingOptions (Reinforcement Learning Toolbox)を参照してください。

trainOpts = rlTrainingOptions(...
    MaxEpisodes=5000,...
    MaxStepsPerEpisode=500,...
    Verbose=false,...
    Plots="training-progress",...
    StopTrainingCriteria="AverageReward",...
    StopTrainingValue=-1000);

学習中またはシミュレーション中に、関数 plot を使用して振子系を可視化します。

plot(env)

Figure Simple Pendulum Visualizer contains 2 axes objects. Axes object 1 contains 2 objects of type line, rectangle. Axes object 2 contains an object of type image.

関数 train (Reinforcement Learning Toolbox) を使用して、エージェントに学習させます。これは計算量が多いプロセスのため、完了するのに数時間かかります。この例の実行時間を節約するために、doTrainingfalse に設定して事前学習済みのエージェントを読み込みます。エージェントに学習させるには、doTrainingtrue に設定します。

doTraining = false;

if doTraining
    % Train the agent.
    trainingStats = train(agent,env,trainOpts);
else
    % Load pretrained agent for the example.
    load("MATLABPendImageDQN.mat","agent");
end

DQN エージェントのシミュレーション

学習済みエージェントの性能を検証するには、振子環境内でこれをシミュレートします。エージェントのシミュレーションの詳細については、rlSimulationOptions (Reinforcement Learning Toolbox) および sim (Reinforcement Learning Toolbox) を参照してください。

simOptions = rlSimulationOptions(MaxSteps=500);
experience = sim(env,agent,simOptions);

Figure Simple Pendulum Visualizer contains 2 axes objects. Axes object 1 contains 2 objects of type line, rectangle. Axes object 2 contains an object of type image.

totalReward = sum(experience.Reward)
totalReward = -713.0336

参考

| (Reinforcement Learning Toolbox)

関連するトピック