# Train Reinforcement Learning Agent for Simple Contextual Bandit Problem

This example shows how to solve a contextual bandit problem [1] using reinforcement learning by training DQN and Q-learning agents. For more information on these agents, see Deep Q-Network (DQN) Agents and Q-Learning Agents.

In a contextual bandit problem, an agent selects an action given an initial observation (the context), receives a reward, and the episode then terminates. Hence, the agent's action does not affect the next observation.

Contextual bandits can be used for various applications such as hyperparameter tuning, recommender systems, medical treatment, and 5G communication.

The following figure illustrates the differences between reinforcement learning, multi-armed bandits, and contextual bandits.

### Environment

The contextual bandit environment in this example is defined as follows:

Observation (discrete): {1, 2}

The context (initial observation) is sampled randomly.

`$\Pr(s=1)=0.5,\qquad \Pr(s=2)=0.5$`

Action (discrete): {1, 2, 3}

Reward:

Rewards in this environment are stochastic. The reward distribution for each observation-action pair is defined below.

`$\begin{array}{ll}1.\ s=1,\ a=1: & \Pr(r=5\mid s=1,a=1)=0.3,\quad \Pr(r=2\mid s=1,a=1)=0.7\\ 2.\ s=1,\ a=2: & \Pr(r=10\mid s=1,a=2)=0.1,\quad \Pr(r=1\mid s=1,a=2)=0.9\\ 3.\ s=1,\ a=3: & \Pr(r=3.5\mid s=1,a=3)=1\end{array}$`

`$\begin{array}{ll}4.\ s=2,\ a=1: & \Pr(r=10\mid s=2,a=1)=0.2,\quad \Pr(r=2\mid s=2,a=1)=0.8\\ 5.\ s=2,\ a=2: & \Pr(r=3\mid s=2,a=2)=1\\ 6.\ s=2,\ a=3: & \Pr(r=5\mid s=2,a=3)=0.5,\quad \Pr(r=0.5\mid s=2,a=3)=0.5\end{array}$`

Note that the agent doesn't know these distributions.

Is-Done signal: Because this is a contextual bandit problem and each episode has only one step, the Is-Done signal is always 1.
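For reference, the environment defined above can be mirrored in a short, self-contained Python sketch. The class-free `reset`/`step` structure here is illustrative only; it is not the `ToyContextualBanditEnvironment` implementation used in this example.

```python
import random

# Reward distributions: (state, action) -> list of (reward, probability),
# mirroring cases 1-6 above.
REWARDS = {
    (1, 1): [(5, 0.3), (2, 0.7)],
    (1, 2): [(10, 0.1), (1, 0.9)],
    (1, 3): [(3.5, 1.0)],
    (2, 1): [(10, 0.2), (2, 0.8)],
    (2, 2): [(3, 1.0)],
    (2, 3): [(5, 0.5), (0.5, 0.5)],
}

def reset():
    """Sample the context uniformly: Pr(s=1) = Pr(s=2) = 0.5."""
    return random.choice([1, 2])

def step(state, action):
    """Return (reward, is_done); every episode ends after one step."""
    rewards, probs = zip(*REWARDS[(state, action)])
    reward = random.choices(rewards, weights=probs)[0]
    return reward, True
```

Sampling `step` repeatedly for a fixed observation-action pair recovers the expected rewards computed later in this example.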

### Create Environment Interface

Create the contextual bandit environment using `ToyContextualBanditEnvironment`, located in this example folder.

```
env = ToyContextualBanditEnvironment;
obsInfo = getObservationInfo(env);
actInfo = getActionInfo(env);
```

Fix the random generator seed for reproducibility.

```
rng(1)
```

### Create DQN Agent

Create a DQN agent with a default network structure using `rlAgentInitializationOptions`.

```
agentOpts = rlDQNAgentOptions(...
    UseDoubleDQN = false, ...
    TargetSmoothFactor = 1, ...
    TargetUpdateFrequency = 4, ...
    MiniBatchSize = 64);
agentOpts.EpsilonGreedyExploration.EpsilonDecay = 0.0005;

initOpts = rlAgentInitializationOptions(NumHiddenUnit = 16);
DQNagent = rlDQNAgent(obsInfo, actInfo, initOpts, agentOpts);
```
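Setting `EpsilonDecay = 0.0005` controls how quickly exploration fades. Assuming a multiplicative update of the form `Epsilon = Epsilon*(1 - EpsilonDecay)` after each learning step, floored at an `EpsilonMin` of 0.01 (an assumption about the toolbox's schedule, used here only for illustration), a short Python sketch shows roughly how much exploration remains after the 3000 one-step episodes used below:

```python
def epsilon_after(steps, epsilon=1.0, decay=0.0005, epsilon_min=0.01):
    """Multiplicative epsilon-greedy decay (assumed update rule):
    eps <- max(epsilon_min, eps * (1 - decay)) at each step."""
    for _ in range(steps):
        epsilon = max(epsilon_min, epsilon * (1 - decay))
    return epsilon

print(round(epsilon_after(3000), 3))  # -> 0.223
```

Under these assumptions, roughly 22% of actions are still exploratory at the end of training, which is enough for every observation-action pair to keep being sampled in this small problem.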

### Train Agent

To train the agent, first, specify the training options. For this example, use the following options:

• Train for 3000 episodes.

• Since this is a contextual bandit problem, and each episode has only one step, set `MaxStepsPerEpisode` to 1.

For more information, see `rlTrainingOptions`.

Train the agent using the `train` function. To save time while running this example, load a pre-trained agent by setting `doTraining` to `false`. To train the agent yourself, set `doTraining` to `true`.

```
MaxEpisodes = 3000;
trainOpts = rlTrainingOptions(...
    MaxEpisodes = MaxEpisodes, ...
    MaxStepsPerEpisode = 1, ...
    Verbose = false, ...
    Plots = "training-progress",...
    StopTrainingCriteria = "EpisodeCount",...
    StopTrainingValue = MaxEpisodes);

doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(DQNagent,env,trainOpts);
else
    % Load the pre-trained agent for the example.
    load("ToyContextualBanditDQNAgent.mat","DQNagent")
end
```

### Validate DQN Agent

Assume that you know the reward distributions, so that you can compute the optimal actions. Validate the agent's performance by comparing these optimal actions with the actions selected by the agent. First, compute the true expected rewards using the true distributions.

1. The expected reward of each action at s=1 is as follows.

`$\begin{array}{l}a=1:\quad E[R]=0.3\cdot 5+0.7\cdot 2=2.9\\ a=2:\quad E[R]=0.1\cdot 10+0.9\cdot 1=1.9\\ a=3:\quad E[R]=3.5\end{array}$`

Hence, the optimal action is 3 when s=1.

2. The expected reward of each action at s=2 is as follows.

`$\begin{array}{l}a=1:\quad E[R]=0.2\cdot 10+0.8\cdot 2=3.6\\ a=2:\quad E[R]=3.0\\ a=3:\quad E[R]=0.5\cdot 5+0.5\cdot 0.5=2.75\end{array}$`

Hence, the optimal action is 1 when s=2.
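The expected-reward arithmetic above can be double-checked with a few lines of Python (a standalone sketch, independent of the MATLAB code in this example):

```python
# Expected reward E[R] for each (state, action), from the distributions above.
expected = {
    (1, 1): 0.3*5 + 0.7*2,    # 2.9
    (1, 2): 0.1*10 + 0.9*1,   # 1.9
    (1, 3): 3.5,
    (2, 1): 0.2*10 + 0.8*2,   # 3.6
    (2, 2): 3.0,
    (2, 3): 0.5*5 + 0.5*0.5,  # 2.75
}

# The optimal action in each context maximizes the expected reward.
best = {s: max((1, 2, 3), key=lambda a: expected[(s, a)]) for s in (1, 2)}
print(best)  # -> {1: 3, 2: 1}
```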

With enough sampling, the learned Q-values should approach the true expected rewards. Visualize the true expected rewards.

```
ExpectedRewards = zeros(2,3);
ExpectedRewards(1,1) = 0.3*5 + 0.7*2;
ExpectedRewards(1,2) = 0.1*10 + 0.9*1;
ExpectedRewards(1,3) = 3.5;
ExpectedRewards(2,1) = 0.2*10 + 0.8*2;
ExpectedRewards(2,2) = 3.0;
ExpectedRewards(2,3) = 0.5*5 + 0.5*0.5;

localPlotQvalues(ExpectedRewards, "Expected Rewards")
```

Now, validate whether the DQN agent has learned the optimal behavior.

If the state is 1, the optimal action is 3.

```
observation = 1;
getAction(DQNagent,observation)
```
```
ans = 1x1 cell array
    {[3]}
```

The agent also selects the optimal action.

If the state is 2, the optimal action is 1.

```
observation = 2;
getAction(DQNagent,observation)
```
```
ans = 1x1 cell array
    {[1]}
```

The agent also selects the optimal action. Thus, the DQN agent has learned the optimal behavior.

Next, compare the learned Q-values to the true expected rewards.

```
% Get the critic
figure(1)
DQNcritic = getCritic(DQNagent);
QValues = zeros(2,3);
for s = 1:2
    QValues(s,:) = getValue(DQNcritic, {s});
end

% Visualize Q values
localPlotQvalues(QValues, "Q values")
```

The learned Q-values are close to the true expected rewards computed above.

### Create Q-learning Agent

Next, train a Q-learning agent. To create a Q-learning agent, first create a table using the observation and action specifications from the environment.

```
rng(1); % For reproducibility
qTable = rlTable(obsInfo, actInfo);
critic = rlQValueFunction(qTable, obsInfo, actInfo);
opt = rlQAgentOptions;
opt.EpsilonGreedyExploration.Epsilon = 1;
opt.EpsilonGreedyExploration.EpsilonDecay = 0.0005;
Qagent = rlQAgent(critic,opt);
```
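Because each episode terminates after one step, the bootstrap term in the Q-learning target vanishes and the tabular update reduces to Q(s,a) &lt;- Q(s,a) + α(r - Q(s,a)). With a count-based step size α = 1/N(s,a), each Q-value is simply the running mean of the observed rewards and converges to E[R | s,a]. A minimal Python sketch of this simplified update (not the toolbox implementation, and using uniform-random exploration rather than an epsilon-greedy schedule; the reward distributions are those defined above):

```python
import random

# Reward distributions: (state, action) -> list of (reward, probability).
REWARDS = {
    (1, 1): [(5, 0.3), (2, 0.7)],
    (1, 2): [(10, 0.1), (1, 0.9)],
    (1, 3): [(3.5, 1.0)],
    (2, 1): [(10, 0.2), (2, 0.8)],
    (2, 2): [(3, 1.0)],
    (2, 3): [(5, 0.5), (0.5, 0.5)],
}

random.seed(1)
Q = {(s, a): 0.0 for s in (1, 2) for a in (1, 2, 3)}
N = dict.fromkeys(Q, 0)

for _ in range(60000):
    s = random.choice([1, 2])      # sample the context
    a = random.choice([1, 2, 3])   # pure exploration, for the sketch
    rewards, probs = zip(*REWARDS[(s, a)])
    r = random.choices(rewards, weights=probs)[0]
    N[(s, a)] += 1
    Q[(s, a)] += (r - Q[(s, a)]) / N[(s, a)]   # alpha = 1/N -> running mean

print(round(Q[(1, 3)], 2), round(Q[(2, 2)], 2))  # deterministic arms: 3.5 3.0
```

The Q-values for the two deterministic rewards are recovered exactly, while the stochastic ones converge to their expected values at a rate set by the sample count.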

### Train Q-learning Agent

To save time while running this example, load a pre-trained agent by setting `doTraining` to `false`. To train the agent yourself, set `doTraining` to `true`.

```
doTraining = false;
if doTraining
    % Train the agent.
    trainingStats = train(Qagent,env,trainOpts);
else
    % Load the pre-trained agent for the example.
    load("ToyContextualBanditQAgent.mat","Qagent")
end
```

### Validate Q-learning Agent

When the state is 1, the optimal action is 3.

```
observation = 1;
getAction(Qagent,observation)
```
```
ans = 1x1 cell array
    {[3]}
```

The agent also selects the optimal action.

When the state is 2, the optimal action is 1.

```
observation = 2;
getAction(Qagent,observation)
```
```
ans = 1x1 cell array
    {[1]}
```

The agent also selects the optimal action. Hence, the Q-learning agent has learned the optimal behavior.

Next, compare the learned Q-values to the true expected rewards.

```
% Get the critic
figure(2)
Qcritic = getCritic(Qagent);
QValues = zeros(2,3);
for s = 1:2
    for a = 1:3
        QValues(s,a) = getValue(Qcritic, {s}, {a});
    end
end

% Visualize Q values
localPlotQvalues(QValues, "Q values")
```

Again, the learned Q-values are close to the true expected rewards computed above. The Q-values for the deterministic rewards, Q(s=1, a=3) and Q(s=2, a=2), match the true expected rewards exactly. Note that the corresponding Q-values learned by the DQN agent, while close, are not identical to the true values. This difference arises because the DQN agent uses a neural network instead of a table as its internal function approximator.

### Local Function

```
function localPlotQvalues(QValues, titleText)
    % Visualize Q values
    figure;
    imagesc(QValues,[1,4])
    colormap("autumn")
    title(titleText)
    colorbar
    set(gca,'Xtick',1:3,'XTickLabel',{"a=1", "a=2", "a=3"})
    set(gca,'Ytick',1:2,'YTickLabel',{"s=1", "s=2"})

    % Plot the values on the image
    x = repmat(1:size(QValues,2), size(QValues,1), 1);
    y = repmat(1:size(QValues,1), size(QValues,2), 1)';
    QValuesStr = num2cell(QValues);
    QValuesStr = cellfun(@num2str, QValuesStr, UniformOutput=false);
    text(x(:), y(:), QValuesStr, HorizontalAlignment="Center")
end
```

### Reference

[1] Sutton, Richard S., and Andrew G. Barto. *Reinforcement Learning: An Introduction*. 2nd ed. Cambridge, MA: MIT Press, 2018.