
Train Reinforcement Learning Agent for Simple Contextual Bandit Problem

This example shows how to solve a contextual bandit problem [1] using reinforcement learning by training DQN and Q-learning agents. For more information on these agents, see Deep Q-Network (DQN) Agents and Q-Learning Agents.

In a contextual bandit problem, the agent selects an action given an initial observation (the context), receives a reward, and the episode then terminates. Hence, the agent's action does not affect the next observation.

Contextual bandits can be used for applications such as hyperparameter tuning, recommender systems, medical treatment, and 5G communication.

The following figure illustrates the differences between reinforcement learning, multi-armed bandits, and contextual bandits.

Environment

The contextual bandit environment in this example is defined as follows:

Observation (discrete): {1, 2}

The context (initial observation) is sampled randomly.

Pr(s=1) = 0.5,  Pr(s=2) = 0.5

Action (discrete): {1, 2, 3}

Reward:

Rewards in this environment are stochastic. The reward distribution for each observation and action pair is defined below.

1. s=1, a=1:  Pr(r=5   | s=1, a=1) = 0.3,  Pr(r=2   | s=1, a=1) = 0.7
2. s=1, a=2:  Pr(r=10  | s=1, a=2) = 0.1,  Pr(r=1   | s=1, a=2) = 0.9
3. s=1, a=3:  Pr(r=3.5 | s=1, a=3) = 1

4. s=2, a=1:  Pr(r=10  | s=2, a=1) = 0.2,  Pr(r=2   | s=2, a=1) = 0.8
5. s=2, a=2:  Pr(r=3   | s=2, a=2) = 1
6. s=2, a=3:  Pr(r=5   | s=2, a=3) = 0.5,  Pr(r=0.5 | s=2, a=3) = 0.5

Note that the agent doesn't know these distributions.

Is-Done signal: This is a contextual bandit problem, and each episode has only one step. Hence, the Is-Done signal is always 1.
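
Although the agent never sees these distributions, they are known in this example, so you can sanity-check the expected reward of each state-action pair with a short Monte Carlo sketch. The snippet below is illustrative only; it samples directly from the distributions listed above and does not use the environment object.

% Illustrative only: estimate E[R|s,a] by sampling the distributions above.
numSamples = 1e5;
u = rand(numSamples,1);
estimated = zeros(2,3);
estimated(1,1) = mean(5*(u < 0.3)  + 2*(u >= 0.3));     % s=1, a=1
estimated(1,2) = mean(10*(u < 0.1) + 1*(u >= 0.1));     % s=1, a=2
estimated(1,3) = 3.5;                                   % s=1, a=3 (deterministic)
estimated(2,1) = mean(10*(u < 0.2) + 2*(u >= 0.2));     % s=2, a=1
estimated(2,2) = 3;                                     % s=2, a=2 (deterministic)
estimated(2,3) = mean(5*(u < 0.5)  + 0.5*(u >= 0.5));   % s=2, a=3
disp(estimated)   % Each entry should be close to the expected rewards derived later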

Create Environment Interface

Create the contextual bandit environment using ToyContextualBanditEnvironment located in this example folder.

env = ToyContextualBanditEnvironment;
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
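
Both specifications describe finite (discrete) sets. As an optional check, display them to confirm the context values and available actions.

% Display the observation (context) and action specifications.
disp(obsInfo)
disp(actInfo)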

Fix the random generator seed for reproducibility.

rng(1) 

Create DQN Agent

Create a DQN agent with a default network structure by using an rlAgentInitializationOptions object, and specify the agent options by using rlDQNAgentOptions.

agentOpts = rlDQNAgentOptions(...
    UseDoubleDQN = false, ...    
    TargetSmoothFactor = 1, ...
    TargetUpdateFrequency = 4, ...     
    MiniBatchSize = 64);
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.0005;

initOpts = rlAgentInitializationOptions(NumHiddenUnit = 16);

DQNagent = rlDQNAgent(obsInfo, actInfo, initOpts, agentOpts);
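
The EpsilonDecay value controls how quickly the agent shifts from exploration to exploitation. As a rough, standalone illustration (not part of the training workflow), the following sketch plots an epsilon schedule assuming the multiplicative update Epsilon = Epsilon*(1 - EpsilonDecay) applied once per step, starting from Epsilon = 1 with a lower bound of 0.01, which match the default exploration settings.

% Illustrative sketch of the exploration schedule over 3000 one-step episodes.
numEpisodes = 3000;
epsilon = zeros(numEpisodes,1);
epsilon(1) = 1;
for k = 2:numEpisodes
    epsilon(k) = max(epsilon(k-1)*(1 - 0.0005), 0.01);
end
figure
plot(epsilon)
xlabel("Episode")
ylabel("Epsilon")
title("Approximate Epsilon-Greedy Schedule")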

Train Agent

To train the agent, first, specify the training options. For this example, use the following options:

  • Train for 3000 episodes.

  • Since this is a contextual bandit problem, and each episode has only one step, set MaxStepsPerEpisode to 1.

For more information, see rlTrainingOptions.

Train the agent using the train function. To save time while running this example, load a pre-trained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

MaxEpisodes = 3000;
trainOpts = rlTrainingOptions(...
    MaxEpisodes = MaxEpisodes, ...
    MaxStepsPerEpisode = 1, ...
    Verbose = false, ...
    Plots = "training-progress",...
    StopTrainingCriteria = "EpisodeCount",...
    StopTrainingValue = MaxEpisodes); 

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(DQNagent,env,trainOpts);
else
    % Load the pre-trained agent for the example.
    load("ToyContextualBanditDQNAgent.mat","DQNagent")
end
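
As an optional check, you can simulate the trained agent for a few one-step episodes and inspect the rewards it collects. This sketch assumes the environment follows the standard reinforcement learning environment interface used by sim.

% Run several one-step episodes with the trained agent and list the rewards.
simOpts = rlSimulationOptions(MaxSteps=1, NumSimulations=5);
experiences = sim(env, DQNagent, simOpts);
rewards = arrayfun(@(e) sum(e.Reward.Data), experiences);
disp(rewards)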

Validate DQN Agent

Assume that you know the reward distributions, so you can compute the optimal actions. Validate the agent's performance by comparing these optimal actions with the actions selected by the agent. First, compute the true expected rewards from the true distributions.

1. The expected reward of each action at s=1 is as follows.

If a=1:  E[R] = 0.3*5 + 0.7*2 = 2.9
If a=2:  E[R] = 0.1*10 + 0.9*1 = 1.9
If a=3:  E[R] = 3.5

Hence, the optimal action is 3 when s=1.

2. The expected reward of each action at s=2 is as follows.

If a=1:  E[R] = 0.2*10 + 0.8*2 = 3.6
If a=2:  E[R] = 3.0
If a=3:  E[R] = 0.5*5 + 0.5*0.5 = 2.75

Hence, the optimal action is 1 when s=2.

With enough sampling, the learned Q-values should approach these true expected rewards. Visualize the true expected rewards.

ExpectedRewards = zeros(2,3);
ExpectedRewards(1,1) = 0.3*5 + 0.7*2;
ExpectedRewards(1,2) = 0.1*10 + 0.9*1;
ExpectedRewards(1,3) = 3.5;
ExpectedRewards(2,1) = 0.2*10 + 0.8*2;
ExpectedRewards(2,2) = 3.0;
ExpectedRewards(2,3) = 0.5*5 + 0.5*0.5;

localPlotQvalues(ExpectedRewards, "Expected Rewards")

Figure: "Expected Rewards" heatmap showing the true expected reward for each state-action pair.
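
The optimal action in each context is the one with the largest expected reward; you can read it off the table programmatically.

% The optimal action in each state maximizes the expected reward.
[~,optimalActions] = max(ExpectedRewards,[],2);
disp(optimalActions')   % Expect [3 1]: action 3 in state 1, action 1 in state 2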

Now, validate whether the DQN agent learns the optimal behavior.

If the state is 1, the optimal action is 3.

observation = 1;
getAction(DQNagent,observation)
ans = 1x1 cell array
    {[3]}

The agent selects the optimal action.

If the state is 2, the optimal action is 1.

observation = 2;
getAction(DQNagent,observation)
ans = 1x1 cell array
    {[1]}

Again, the agent selects the optimal action. Thus, the DQN agent has learned the optimal behavior.

Next, compare the learned Q-value function to the true expected rewards.

% Get critic
figure(1)
DQNcritic = getCritic(DQNagent);
QValues = zeros(2,3);
for s = 1:2
    QValues(s,:) = getValue(DQNcritic, {s});
end

% Visualize Q values
localPlotQvalues(QValues, "Q values")

Figure: "Q values" heatmap showing the Q-values learned by the DQN agent for each state-action pair.

The learned Q-values are close to the true expected rewards computed above.
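
To quantify how close they are, you can compare the two tables directly.

% Largest absolute gap between the learned Q-values and the true expected rewards.
maxError = max(abs(QValues - ExpectedRewards),[],"all");
disp(maxError)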

Create Q-learning Agent

Next, train a Q-learning agent. To create the agent, first create a Q-value table using the observation and action specifications from the environment, then create a critic based on this table.

rng(1); % For reproducibility

qTable = rlTable(obsInfo, actInfo);
critic = rlQValueFunction(qTable, obsInfo, actInfo);

opt = rlQAgentOptions;
opt.EpsilonGreedyExploration.Epsilon = 1;
opt.EpsilonGreedyExploration.EpsilonDecay = 0.0005;

Qagent = rlQAgent(critic,opt);
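
As an optional check before training, you can inspect the values stored in the rlTable object, which are initialized to zero (this sketch assumes the Table property of the rlTable object).

% Optional check: the Q-table starts with all entries equal to zero.
disp(qTable.Table)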

Train Q-learning Agent

To save time while running this example, load a pre-trained agent by setting doTraining to false. To train the agent yourself, set doTraining to true.

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(Qagent,env,trainOpts);
else
    % Load the pre-trained agent for the example.
    load("ToyContextualBanditQAgent.mat","Qagent")
end

Validate Q-learning Agent

When the state is 1, the optimal action is 3.

observation = 1;
getAction(Qagent,observation)
ans = 1x1 cell array
    {[3]}

The agent also selects the optimal action.

When the state is 2, the optimal action is 1.

observation = 2;
getAction(Qagent,observation)
ans = 1x1 cell array
    {[1]}

The agent also selects the optimal action. Hence, the Q-learning agent has learned the optimal behavior.

Next, compare the learned Q-value function to the true expected rewards.

% Get critic
figure(2)
Qcritic = getCritic(Qagent);
QValues = zeros(2,3);
for s = 1:2
    for a = 1:3
        QValues(s,a) = getValue(Qcritic, {s}, {a});
    end
end

% Visualize Q values
localPlotQvalues(QValues, "Q values")

Figure: "Q values" heatmap showing the Q-values learned by the Q-learning agent for each state-action pair.

Again, the learned Q-values are close to the true expected rewards computed above. The Q-values for the deterministic rewards, Q(s=1, a=3) and Q(s=2, a=2), are exactly equal to the true expected rewards. The corresponding Q-values learned by the DQN agent, while close, are not identical to the true values, because the DQN agent uses a neural network instead of a table as its internal function approximator.
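
You can confirm this observation about the deterministic entries numerically.

% The deterministic entries should match the true expected rewards.
disp(QValues(1,3) - ExpectedRewards(1,3))   % Q(s=1,a=3) - 3.5
disp(QValues(2,2) - ExpectedRewards(2,2))   % Q(s=2,a=2) - 3.0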

Local Function

function localPlotQvalues(QValues, titleText)
    % Visualize Q values 
    figure;
    imagesc(QValues,[1,4])
    colormap("autumn")
    title(titleText)
    colorbar
    set(gca,'Xtick',1:3,'XTickLabel',{"a=1", "a=2", "a=3"})
    set(gca,'Ytick',1:2,'YTickLabel',{"s=1", "s=2"})

    % Plot values on the image
    x = repmat(1:size(QValues,2), size(QValues,1), 1);
    y = repmat(1:size(QValues,1), size(QValues,2), 1)';
    QValuesStr = num2cell(QValues);
    QValuesStr = cellfun(@num2str, QValuesStr, UniformOutput=false);
    text(x(:), y(:), QValuesStr, HorizontalAlignment = "Center")
end

Reference

[1] Sutton, Richard S., and Andrew G. Barto. Reinforcement Learning: An Introduction. 2nd ed. Cambridge, MA: MIT Press, 2018.