メインコンテンツ

getCritic

強化学習エージェントからのクリティックの抽出

説明

critic = getCritic(agent) は、指定された強化学習エージェントから取得したクリティック オブジェクトを返します。

すべて折りたたむ

既存の学習済み強化学習エージェントがあると仮定します。この例では、Compare DDPG Agent to LQR Controllerの学習済みエージェントを読み込みます。

load("DoubleIntegDDPG.mat","agent") 

エージェントからクリティックを取得します。

critic = getCritic(agent);

近似器オブジェクトでは、ドット表記を使用して Learnables プロパティにアクセスできます。

まず、パラメーターを表示します。

critic.Learnables{1}
ans = 
  1×6 single dlarray

   -5.0017   -1.5513   -0.3424   -0.1116   -0.0506   -0.0047

パラメーター値を変更します。この例では、すべてのパラメーターを単純に 2 倍します。

critic.Learnables{1} = critic.Learnables{1}*2;

新しいパラメーターを表示します。

critic.Learnables{1}
ans = 
  1×6 single dlarray

  -10.0034   -3.1026   -0.6848   -0.2232   -0.1011   -0.0094

あるいは、getLearnableParameterssetLearnableParameters を使用することもできます。

まず、学習可能なパラメーターをクリティックから取得します。

params = getLearnableParameters(critic)
params=2×1 cell array
    {[-10.0034 -3.1026 -0.6848 -0.2232 -0.1011 -0.0094]}
    {[                                               0]}

パラメーター値を変更します。この例では、すべてのパラメーターを単純に 2 で割ります。

modifiedParams = cellfun(@(x) x/2,params,"UniformOutput",false);

クリティックのパラメーター値を新しく変更した値に設定します。

critic = setLearnableParameters(critic,modifiedParams);

エージェント内のクリティックを新しく変更したクリティックに設定します。

setCritic(agent,critic);

新しいパラメーター値を表示します。

getLearnableParameters(getCritic(agent))
ans=2×1 cell array
    {[-5.0017 -1.5513 -0.3424 -0.1116 -0.0506 -0.0047]}
    {[                                              0]}

連続行動空間をもつ環境を作成し、その観測仕様とアクション仕様を取得します。この例では、Compare DDPG Agent to LQR Controllerの例で使用されている環境を読み込みます。

事前定義された環境を読み込みます。

env = rlPredefinedEnv("DoubleIntegrator-Continuous");

観測仕様とアクション仕様を取得します。

obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);

環境の観測仕様とアクション仕様から PPO エージェントを作成します。このエージェントは、アクターとクリティックに既定の深層ニューラル ネットワークを使用します。

agent = rlPPOAgent(obsInfo,actInfo);

強化学習エージェント内の深層ニューラル ネットワークを変更するには、まずアクター関数近似器とクリティック関数近似器を抽出しなければなりません。

actor = getActor(agent);
critic = getCritic(agent);

アクター関数近似器とクリティック関数近似器の両方から深層ニューラル ネットワークを抽出します。

actorNet = getModel(actor);
criticNet = getModel(critic);

アクター ネットワークをプロットします。

plot(actorNet)

Figure contains an axes object. The axes object contains an object of type graphplot.

ネットワークを検証するには、analyzeNetwork を使用します。たとえば、クリティック ネットワークを検証します。

analyzeNetwork(criticNet)

アクターおよびクリティックのネットワークを変更し、エージェントに保存し直すことができます。ネットワークを変更するには、ディープ ネットワーク デザイナーアプリを使用できます。各ネットワーク用にアプリを開くには、次のコマンドを使用します。

deepNetworkDesigner(criticNet)
deepNetworkDesigner(actorNet)

"ディープ ネットワーク デザイナー" でネットワークを変更します。たとえば、ネットワークに他の層を追加できます。ネットワークを変更するときに、getModel によって返されるネットワークの入力層と出力層を変更しないでください。ネットワークの構築の詳細については、ディープ ネットワーク デザイナーを使用したネットワークの構築を参照してください。

変更したネットワークを "ディープ ネットワーク デザイナー" で検証するには、[解析] セクションの [解析] をクリックしなければなりません。変更したネットワーク構造を MATLAB® ワークスペースにエクスポートするには、新しいネットワークを作成するコードを生成し、コマンド ラインからそのコードを実行します。"ディープ ネットワーク デザイナー" のエクスポート オプションは使用しないでください。コードを生成して実行する方法を示す例については、ディープ ネットワーク デザイナーを使用した DQN エージェントの作成およびイメージ観測値を使用した学習を参照してください。

この例では、変更したアクター ネットワークとクリティック ネットワークを作成するコードが createModifiedNetworks 補助スクリプトに含まれています。

createModifiedNetworks

変更した各ネットワークには、メインの共通パスに追加の fullyConnectedLayerreluLayer が含まれます。変更したアクター ネットワークをプロットします。

plot(modifiedActorNet)

Figure contains an axes object. The axes object contains an object of type graphplot.

ネットワークをエクスポートした後、ネットワークをアクター関数近似器とクリティック関数近似器に挿入します。

actor = setModel(actor,modifiedActorNet);
critic = setModel(critic,modifiedCriticNet);

最後に、変更したアクター関数近似器とクリティック関数近似器をアクター オブジェクトとクリティック オブジェクトに挿入します。

agent = setActor(agent,actor);
agent = setCritic(agent,critic);

入力引数

すべて折りたたむ

クリティックを含む強化学習エージェント。次のオブジェクトのいずれかとして指定します。

出力引数

すべて折りたたむ

クリティック オブジェクト。次のいずれかとして返されます。

  • rlValueFunction オブジェクト — agentrlACAgentrlPGAgent、または rlPPOAgent オブジェクトの場合に返されます。

  • rlQValueFunction オブジェクト — agent が、単一のクリティックをもつ rlQAgentrlSARSAAgentrlDQNAgentrlDDPGAgent、または rlTD3Agent オブジェクトの場合に返されます。

  • rlVectorQValueFunction オブジェクト — agent が、ベクトル Q 値関数クリティックをもつ rlQAgentrlSARSAAgentrlDQNAgent、または非連続行動空間をもつ rlSACAgent オブジェクトの場合に返されます。

  • rlQValueFunction オブジェクトの 2 要素の行ベクトル — agent が連続行動空間と 2 つのクリティックをもつ rlTD3Agent オブジェクトまたは rlSACAgent オブジェクトの場合に返されます。

  • rlVectorQValueFunction オブジェクトの 2 要素の行ベクトル — agent が非連続行動空間と 2 つのクリティックをもつ rlSACAgent オブジェクトの場合に返されます。

バージョン履歴

R2019a で導入