Create and Train Custom LQR Agent

This example shows how to create and train a custom linear quadratic regulation (LQR) agent to control a discrete-time linear system modeled in MATLAB®. For an introduction to custom agents, see Create Custom Reinforcement Learning Agents. For a step-by-step example of how to create a custom PG agent (using the REINFORCE algorithm), see Create and Train Custom PG Agent. For an example of how a DDPG agent can be used as an optimal controller for a discrete-time system, see Compare DDPG Agent to LQR Controller.

Fix Random Number Stream for Reproducibility

The example code might involve computation of random numbers at various stages. Fixing the random number stream at the beginning of various sections in the example code preserves the random number sequence in the section every time you run it, and increases the likelihood of reproducing the results. For more information, see Results Reproducibility.

Fix the random number stream with seed 0 and random number algorithm Mersenne Twister. For more information on controlling the seed used for random number generation, see rng.

previousRngState = rng(0,"twister");

The output previousRngState is a structure that contains information about the previous state of the stream. You will restore the state at the end of the example.

Create Linear System Environment

The reinforcement learning environment for this example is a discrete-time linear system. The dynamics for the system are given by:

$x_{t+1} = A x_t + B u_t$

The feedback control law is:

$u_t = -K x_t$

The control objective is to minimize the quadratic cost $J = \sum_{t=0}^{\infty} \left( x_t^T Q x_t + u_t^T R u_t \right)$.

In this example, the system matrices are:

$A = \begin{bmatrix} 1.05 & 0.05 & 0.05 \\ 0.05 & 1.05 & 0.05 \\ 0 & 0.05 & 1.05 \end{bmatrix} \qquad B = \begin{bmatrix} 0.1 & 0 & 0.2 \\ 0.1 & 0.5 & 0 \\ 0 & 0 & 0.5 \end{bmatrix}$

A = [1.05,0.05,0.05;0.05,1.05,0.05;0,0.05,1.05];
B = [0.1,0,0.2;0.1,0.5,0;0,0,0.5]; 

The quadratic cost matrices are:

$Q = \begin{bmatrix} 10 & 3 & 1 \\ 3 & 5 & 4 \\ 1 & 4 & 9 \end{bmatrix} \qquad R = \begin{bmatrix} 0.5 & 0 & 0 \\ 0 & 0.5 & 0 \\ 0 & 0 & 0.5 \end{bmatrix}$

Q = [10,3,1;3,5,4;1,4,9]; 
R = 0.5*eye(3);

For this environment, the reward at time $t$ is given by $r_t = -x_t^T Q x_t - u_t^T R u_t$, which is the negative of the quadratic cost. Therefore, maximizing the reward minimizes the cost. The initial conditions are set randomly by the reset function.
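As a quick illustration, you can evaluate the stage cost, and hence the reward, for an arbitrary state and input (the values of xs and us below are chosen only for illustration).

% Stage cost and corresponding reward for an illustrative state and input.
xs = [1; -1; 0.5];              % example state
us = [0.2; 0; -0.1];            % example control input
stageCost = xs'*Q*xs + us'*R*us;
stageReward = -stageCost;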

The myDiscreteEnv function creates an environment by defining custom step and reset functions. For more information on creating such a custom environment, see Create Custom Environment Using Step and Reset Functions.

type("myDiscreteEnv");
function env = myDiscreteEnv(A,B,Q,R)

% This function creates a discrete-time linear system environment.
%
% (A,B) are the system matrices of the discrete-time system x[k+1] = A*x[k] + B*u[k].
% (Q,R) are the quadratic cost matrices; the reward is r = -x'Qx - u'Ru.

% Copyright 2018-2019 The MathWorks Inc.

% Create observation and action specification objects.
% The observation is the state, x, the action is the control input, u.
obsInfo = rlNumericSpec([size(A,1),1]);
actInfo = rlNumericSpec([size(B,2),1]);

% Create environment. Define anonymous step and reset functions that 
% take the required number of arguments and in turn call functions 
% that use the parameters A,B,Q, and R. 
env = rlFunctionEnv(obsInfo,actInfo,...
    @(action,state) myStepFunction(action,state,A,B,Q,R), ...
    @() myResetFunction(Q));

end


function [Observation,Reward,IsDone,x] = myStepFunction(u,x,A,B,Q,R)
% This is the step function for the environment. It returns the next 
% observation, reward, and is-done signal for a given action and state.

% Advance system according to its dynamics.
xp = A*x + B*u;
Observation = xp;

% Set new state.
x = xp;

% Set IsDone to false.
IsDone = false; 

% Calculate reward.
Reward = -x'*Q*x -u'*R*u;

end


function [InitialObservation, State] = myResetFunction(Q)
% This is the reset function for the environment, which sets 
% random initial conditions for observation and state.

n = size(Q,1);
x0 = rand(n,1);
InitialObservation = x0;
State = InitialObservation;

end

Create the environment object.

env = myDiscreteEnv(A,B,Q,R);

Create Custom LQR Agent

For an LQR problem, the optimal Q-value function can be represented by the quadratic form $Q_V(x,u) = \begin{bmatrix} x^T & u^T \end{bmatrix} S \begin{bmatrix} x \\ u \end{bmatrix}$, where $S = \begin{bmatrix} S_{xx} & S_{xu} \\ S_{ux} & S_{uu} \end{bmatrix}$ is a symmetric, positive definite matrix defined in [1]. Taking the partial derivative of $Q_V(x,u)$ with respect to $u$, setting it to zero, and solving for $u$ (as in [2]) yields $u = -S_{uu}^{-1} S_{ux} x$, which is the control law that maximizes the value of $Q_V(x,u)$.

Because it is symmetric, the matrix $S$ contains $m = \frac{1}{2}n(n+1)$ distinct element values, where $n$ is the sum of the number of states and the number of inputs. Define $\theta^*$ as the vector containing these $m$ elements, in which the off-diagonal elements of $S$ are multiplied by two. You can then express $Q_V(x,u)$ as the inner product of the vectors $\theta^*$ and $h(x,u)$, where $h(x,u)$ is a vector of quadratic monomials built from all the combinations of the elements of $x$ and $u$. For an example, see the $W_Q$ matrix in Compare DDPG Agent to LQR Controller.

In this example, the custom agent needs to learn the parameter vector $\theta^*$ (starting from an initial guess $\theta$) so that $\theta^{*T} h(x,u) = Q_V(x,u)$.
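To make this parameterization concrete, the following sketch builds $\theta$ and $h(x,u)$ for a small random symmetric matrix $S$ and checks that the inner product $\theta^T h(x,u)$ reproduces the quadratic form $[x;u]^T S [x;u]$. The sizes and variable names (nx, nu, zs, and so on) are chosen here only for illustration.

% Illustrative check: theta'*h equals z'*S*z when theta is built from the
% upper triangle of S with the off-diagonal entries doubled.
nx = 2; nu = 1; nz = nx + nu;    % illustrative sizes
S = rand(nz); S = (S + S')/2;    % random symmetric matrix
zs = rand(nz,1);                 % stacked vector [x;u]

theta = zeros(nz*(nz+1)/2,1);
h     = zeros(nz*(nz+1)/2,1);
idx = 1;
for r = 1:nz
    for c = r:nz
        w = S(r,c);
        if c > r
            w = 2*w;             % off-diagonal elements appear twice in z'*S*z
        end
        theta(idx) = w;
        h(idx)     = zs(r)*zs(c);
        idx = idx + 1;
    end
end

% Both expressions agree up to floating-point error.
quadraticForm = zs'*S*zs;
innerProduct  = theta'*h;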

Fix the random generator seed for reproducibility.

rng(0,"twister")

The LQR agent starts with a stabilizing controller $K_0$. To get an initial stabilizing controller, place the poles of the closed-loop system $A - B K_0$ inside the unit circle.

K0 = place(A,B,[0.4,0.8,0.5]);
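As an optional check, you can confirm that $K_0$ is stabilizing by verifying that all eigenvalues of $A - B K_0$ have magnitude less than one.

% Verify that all closed-loop eigenvalues lie strictly inside the unit circle.
closedLoopEig = abs(eig(A - B*K0));
assert(all(closedLoopEig < 1),"K0 does not stabilize the system.")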

To create a custom agent, you must create a subclass of the rl.agent.CustomAgent abstract class. For the custom LQR agent, the defined custom subclass is LQRCustomAgent. For more information, see Create Custom Reinforcement Learning Agents.

type("LQRCustomAgent");
classdef LQRCustomAgent < rl.agent.CustomAgent
    % LQRCustomAgent: Creates an LQR Agent for a linear system.
    %
    %   agent = LQRCustomAgent(Q,R,K0) creates a LQR agent
    %             Q and R are penalty matrices for the state and action,
    %             K0 is an initial stabilizing feedback gain.
    %
    % Copyright 2018-2019 The MathWorks Inc.
    
    %% Public Properties
    properties
        % Q - Weights state deviation from zero
        Q

        % R - Weights control action
        R

        % K - Feedback gain: u = -K*x
        K

        % Discount Factor
        Gamma = 0.95

        % Critic Weights
        Theta

        % Buffer for K
        KBuffer  
        
        % Number of updates for K
        KUpdate = 1

        % Number of steps before updating the critic
        EstimateNum = 10
    end
    
    properties (Access = private)
        Counter = 1
        YBuffer
        HBuffer 
    end
    
    
    %% MAIN METHODS
    methods
        % Constructor
        function obj = LQRCustomAgent(Q,R,K0)

            % Check the number of input arguments.
            narginchk(3,3);

            % Call the abstract class constructor.
            obj = obj@rl.agent.CustomAgent();

            % Set the Q and R properties.
            obj.Q = Q;
            obj.R = R;

            % Define the observation and action spaces.
            obj.ObservationInfo_ = rlNumericSpec([size(Q,1),1]);
            obj.ActionInfo_ = rlNumericSpec([size(R,1),1]);

            % Create the critic.
            obj.Theta = createCriticWeights(obj);

            % Initialize the gain matrix.
            obj.K = K0;

            % Initialize the experience buffers.
            obj.YBuffer = zeros(obj.EstimateNum,1);
            num = size(Q,1) + size(R,1);
            obj.HBuffer = zeros(obj.EstimateNum,0.5*num*(num+1));
            obj.KBuffer = cell(1,1000);
            obj.KBuffer{1} = obj.K;
        end
    end
    
    %% Implementation of abstract parent protected methods
    methods (Access = protected)

        % Action methods

        function action = getActionImpl(obj,Observation)
            % Given the current state of the system, return an action.
            action = -obj.K*Observation{:};
        end

        function action = getActionWithExplorationImpl(obj,Observation)

            % Given the current observation, select an action
            action = getActionImpl(obj,Observation);
            
            % Add random noise to action
            num = size(obj.R,1);
            action = action + 0.1*randn(num,1);

        end

        % learn from current experiences, return action with exploration
        % exp = {state,action,reward,nextstate,isdone}
        function action = learnImpl(obj,exp)
            % Parse the experience input
            x = exp{1}{1};
            u = exp{2}{1};
            dx = exp{4}{1};            
            y = (x'*obj.Q*x + u'*obj.R*u);  % stage cost, equal to -exp{3}{1}
            num = size(obj.Q,1) + size(obj.R,1);

            % Wait N steps before updating critic parameters
            N = obj.EstimateNum;

            % Evaluating the critic at the points (x,u) and (dx,-K*dx) 
            % yields Q1 = theta'*h1 and Q2 = theta'*h2, respectively, 
            % where h1 = h(x,u) and h2 = h(dx,-K*dx).
            % The agent tries to learn theta so that Q1 = y + gamma*Q2, 
            % that is, theta'*(h1-gamma*h2) = theta'*H = y, where 
            % y = (x'*obj.Q*x + u'*obj.R*u) is the stage cost 
            % (the negative of the reward).

            % The following is the least-squares solution.
            h1 = computeQuadraticBasis(x,u,num);
            h2 = computeQuadraticBasis(dx,-obj.K*dx,num);
            H = h1 - obj.Gamma*h2;

            if obj.Counter<=N

                obj.YBuffer(obj.Counter) = y;
                obj.HBuffer(obj.Counter,:) = H;
                obj.Counter = obj.Counter + 1;

            else

                % Update the critic parameters 
                % based on the batch of experiences
                H_buf = obj.HBuffer;
                y_buf = obj.YBuffer;
                theta = (H_buf'*H_buf)\H_buf'*y_buf;
                obj.Theta = theta;
                
                % Derive a new gain matrix 
                % based on the new critic parameters
                obj.K = getNewK(obj);
                
                % Reset the experience buffers
                obj.Counter = 1;
                obj.YBuffer = zeros(N,1);
                obj.HBuffer = zeros(N,0.5*num*(num+1));    
                obj.KUpdate = obj.KUpdate + 1;
                obj.KBuffer{obj.KUpdate} = obj.K;

            end

            % Find and return an action with exploration
            action = getActionWithExplorationImpl(obj,exp{4});
        end

        % Create critic 
        function w0 = createCriticWeights(obj)
            nQ = size(obj.Q,1);
            nR = size(obj.R,1);
            n = nQ+nR;
            w0 = 0.1*ones(0.5*(n+1)*n,1);
        end

        % Update K from critic
        function k = getNewK(obj)
            w = obj.Theta;
            nQ = size(obj.Q,1);
            nR = size(obj.R,1);
            n = nQ+nR;
            Phat = zeros(n,n);
            idx = 1;
            for r = 1:n
                for c = r:n
                    Phat(r,c) = w(idx);
                    idx = idx + 1;
                end
            end
            S  = 1/2*(Phat+Phat');
            Suu = S(nQ+1:end,nQ+1:end);
            Sux = S(nQ+1:end,1:nQ);
            if rank(Suu) == nR
                k = Suu\Sux;
            else
                k = obj.K;
            end
        end       
        
    end
        
end

%% local function
function B = computeQuadraticBasis(x,u,n)
% Return the vector of quadratic monomials z(r)*z(c), r <= c, where z = [x;u].
z = cat(1,x,u);
B = zeros(0.5*n*(n+1),1);
idx = 1;
for r = 1:n
    for c = r:n
        B(idx) = z(r)*z(c);
        idx = idx + 1;
    end
end
end

Create the custom LQR agent using Q, R, and K0. The agent does not require information on the system matrices A and B.

agent = LQRCustomAgent(Q,R,K0);
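If you want a quick sanity check, you can query the untrained agent for the action corresponding to a sample state. Because exploration noise is added only during learning, the returned action is the deterministic feedback $-K_0 x$ (the state xs below is arbitrary, and depending on your release the output of getAction may be wrapped in a cell array).

xs = [1; 0; -1];                 % arbitrary sample state
act = getAction(agent,{xs});     % equals -K0*xs (possibly inside a cell array)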

For this example, set the agent discount factor to one. To use a discounted future reward, set the discount factor to a value less than one.

agent.Gamma = 1;

Because the linear system has three states and three inputs, the total number of learnable parameters is $m = 21$. To ensure satisfactory performance of the agent, set the number of data points to be collected before updating the critic, $N_p$, to be greater than twice the number of learnable parameters. For this example, set $N_p$ to 45.

agent.EstimateNum = 45;
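For reference, a minimal sketch of this parameter count (the names nTotal and nParams are introduced here only for illustration):

nTotal  = size(A,1) + size(B,2);    % 3 states + 3 inputs = 6
nParams = nTotal*(nTotal+1)/2;      % 21 learnable critic parameters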

To get good estimation results for $\theta$, you must apply a persistently exciting exploration signal to the system. In this example, encourage exploration by adding white noise to the controller output: $u_t = -K x_t + e_t$. In general, the exploration model depends on the system model.

Train Agent

To train the agent, first specify the training options. For this example, use the following options.

  • Run the training for a maximum of 10 episodes, with each episode lasting at most 50 time steps.

  • Display the training progress in the Reinforcement Learning Training Monitor dialog box (set the Plots option) and disable command line display (set the Verbose option).

For more information on training options, see rlTrainingOptions.

trainingOpts = rlTrainingOptions(...
    MaxEpisodes=10, ...
    MaxStepsPerEpisode=50, ...
    Verbose=false, ...
    Plots="training-progress");

Train the agent using the train function.

trainingStats = train(agent,env,trainingOpts);

Simulate Agent and Compare with Optimal Solution

To validate the performance of the trained agent, simulate it within the environment. For more information on agent simulation, see rlSimulationOptions and sim.

simOptions = rlSimulationOptions(MaxSteps=20);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward)
totalReward = 
-5.1179

You can compute the optimal solution for the LQR problem using the dlqr function.

[Koptimal,P] = dlqr(A,B,Q,R); 

The reward that the optimal controller obtains is $J_{optimal} = -x_0^T P x_0$.

x0 = experience.Observation.obs1.getdatasamples(1);
Joptimal = -x0'*P*x0;

Compute the error in the reward between the trained LQR agent and the optimal LQR solution.

rewardError = totalReward - Joptimal
rewardError = 
25.5303

View the history of the norm of the difference between the gains of the trained LQR agent and the optimal LQR solution.

% Number of gain updates
len = agent.KUpdate;

% Initialize error vector
err = zeros(len,1);

% Fill elements
for i = 1:len
    err(i) = norm(agent.KBuffer{i}-Koptimal);
end

% Plot logarithm of the error vector
plot(log10(err),'b*-')
title("Log of gain difference")
xlabel("Number of updates")

Figure: "Log of gain difference" plotted against the number of updates.

Compute the norm of the final error for the feedback gain.

gainError = norm(agent.K - Koptimal)
gainError = 
2.1203e-11

Overall, the trained agent finds a solution that is very close to the true optimal LQR solution.

Restore the random number stream using the information stored in previousRngState.

rng(previousRngState);

References

[1] Bradtke, S. "Reinforcement Learning Applied to Linear Quadratic Regulation." Advances in Neural Information Processing Systems, Vol. 5, 1992.

[2] Asri, S., and Luis Rodrigues. "Data-Driven LQR using Reinforcement Learning and Quadratic Neural Networks." arXiv preprint arXiv:2311.10235 (2023).
