
Backtest Strategies Using Deep Learning

Construct trading strategies using a deep learning model and then backtest the strategies using the Financial Toolbox™ backtesting framework. The example uses Deep Learning Toolbox™ to train a predictive model from a set of time series and demonstrates the steps necessary to convert the model output into trading signals. It then builds several trading strategies from those signals and backtests them over a period of nearly six years, from the end of 2015 through November 2021.

This example focuses on the workflow from data, to a trained model, to trading strategies, and finally to a backtest of the strategies. The deep learning model, its output, the subsequent trading signals, and the strategies are fictional. The intent is only to show the steps for developing and deploying this type of model.

Load Data

Load the historical price data. This dataset contains daily spot prices for 12 energy products from 1986 to 2021 and consists of the following time series:

  • WTI — West Texas Intermediate light crude oil

  • Brent — Brent light crude oil

  • NaturalGas — Henry Hub natural gas

  • Propane — Mont Belvieu propane

  • Kerosene — US Gulf Coast kerosene-type jet fuel

  • HeatingOil — New York Harbor no. 2 heating oil

  • GulfRegular — US Gulf Coast conventional gasoline

  • LARegular — Los Angeles reformulated RBOB regular gasoline

  • NYRegular — New York Harbor conventional gasoline

  • GulfDiesel — US Gulf Coast ultra-low sulfur no. 2 diesel

  • LADiesel — Los Angeles ultra-low sulfur CARB diesel

  • NYDiesel — New York Harbor ultra-low sulfur no. 2 diesel

The source of this data is the US Energy Information Administration (Nov 2021).

priceData = load('energyPrices.mat','energyPrices');
priceData = priceData.energyPrices;
tail(priceData)
       Time         WTI     Brent    NaturalGas    Propane    Kerosene    HeatingOil    GulfRegular    LARegular    NYRegular    GulfDiesel    LADiesel    NYDiesel
    ___________    _____    _____    __________    _______    ________    __________    ___________    _________    _________    __________    ________    ________

    22-Oct-2021    84.53    85.43        5.1        1.485      2.312        2.414          2.481         2.671        2.571         2.49        2.559       2.558  
    25-Oct-2021    84.64    84.85       5.72        1.378      2.326        2.429          2.506         2.691        2.591        2.501        2.573       2.572  
    26-Oct-2021    85.64    85.11       5.59        1.398      2.339        2.436          2.552         2.636        2.591        2.511        2.598       2.573  
    27-Oct-2021    82.66    84.12       5.91        1.365      2.271        2.368          2.469         2.566        2.508        2.443        2.535       2.505  
    28-Oct-2021    82.78     83.4       5.68         1.36      2.278        2.363          2.471         2.583        2.518        2.448         2.57        2.51  
    29-Oct-2021     83.5     83.1       5.49        1.383      2.285        2.342          2.485         2.662        2.537        2.429        2.573       2.487  
    01-Nov-2021    84.08    84.51       5.22        1.385      2.301        2.364          2.457         2.597        2.494        2.445        2.599       2.511  
    02-Nov-2021    83.91    84.42       5.33        1.388        2.3        2.405          2.466         2.601        2.596        2.441        2.595        2.51  

Clean and Trim Data

The price series do not all start on the same date. Some series start later than others and therefore have fewer data points. The following plot shows the time span of each price series.

seriesLifespanPlot(priceData)

Figure: Time Series Life Span, one horizontal bar for each of the 12 price series.

To avoid large spans of missing data, remove the series with shorter histories.

priceData = removevars(priceData,["NYDiesel","GulfDiesel","LARegular"]);
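The decision of which series to drop can also be made programmatically. The following base-MATLAB sketch uses a small hypothetical price matrix (not the energy data) and drops every column whose first valid observation comes after a hypothetical cutoff row. The variable names and the cutoff value are assumptions for illustration only.

```matlab
% Hypothetical prices: rows are days, columns are series.
% Leading NaNs mean the series has not started yet; series C has one later gap.
prices = [ 10   NaN  5;
           11   NaN  6;
           12   20   NaN;
           13   21   7;
           14   22   8 ];
names = {'A','B','C'};

% Row index of the first non-NaN value in each column (its start date).
numSeries  = size(prices,2);
firstValid = zeros(1,numSeries);
for k = 1:numSeries
    firstValid(k) = find(~isnan(prices(:,k)),1,'first');
end

% Hypothetical rule: keep only series that start on or before row 2.
cutoffRow = 2;
keep   = firstValid <= cutoffRow;
prices = prices(:,keep);
names  = names(keep);
```

Series B starts at row 3, so it is dropped, while series C is kept even though it has a sporadic gap; such gaps are handled in the next step.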

The remaining table variables contain sporadic missing elements (NaNs) due to holidays or other reasons. The appropriate way to handle missing data depends on the dataset. In some cases, it is appropriate to interpolate or to use the fillmissing function. In this example, remove every row that still contains a NaN price.

priceData = rmmissing(priceData);
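For a numeric matrix, rmmissing drops every row that contains at least one NaN. The sketch below reproduces that behavior with logical indexing on a small hypothetical matrix and contrasts it with carrying the last valid price forward, which is the idea behind fillmissing with the 'previous' method. The fill is written out by hand so the block runs in base MATLAB.

```matlab
% Hypothetical prices for two series with scattered missing values.
X = [1   10;
     2   NaN;
     NaN 12;
     4   13];

% Equivalent of rmmissing for a numeric matrix: drop any row with a NaN.
rowHasNaN = any(isnan(X),2);
Xdropped  = X(~rowHasNaN,:);        % keeps rows 1 and 4 only

% Alternative: carry the previous valid value forward, row by row.
Xfilled = X;
for r = 2:size(Xfilled,1)
    gap = isnan(Xfilled(r,:));
    Xfilled(r,gap) = Xfilled(r-1,gap);
end
```

Dropping rows shortens every series at once, whereas filling keeps all dates but introduces zero returns on the filled days.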

Then, convert the price data to a return series using the tick2ret (Financial Toolbox) function. The final dataset consists of daily returns for nine energy products from 1997 through 2021.

returnData = tick2ret(priceData)
returnData=6167×9 timetable
       Time           WTI           Brent       NaturalGas     Propane       Kerosene     HeatingOil    GulfRegular    NYRegular      LADiesel 
    ___________    __________    ___________    __________    __________    __________    __________    ___________    __________    __________

    08-Jan-1997      0.011429     0.00080775    -0.0052356             0      0.012931      0.010974     0.0014347     -0.0028369    -0.0065789
    09-Jan-1997    -0.0094162      0.0020178         -0.05     -0.036969    -0.0085106    -0.0013569     -0.024355      -0.024182             0
    10-Jan-1997    -0.0057034      -0.024567      0.085873     0.0095969     -0.010014     -0.012228    -0.0088106     -0.0058309    -0.0092715
    13-Jan-1997     -0.036329      -0.033443      0.020408     -0.024715     -0.034682     -0.037139      -0.02963      -0.036657    -0.0066845
    15-Jan-1997      0.029762     -0.0042717         0.085     -0.048733      0.023952      0.021429     0.0030534      0.0060883     -0.013459
    16-Jan-1997     -0.019268              0      0.085253     -0.028689     -0.019006     -0.020979     0.0060883      0.0030257      0.020464
    17-Jan-1997    -0.0019646      -0.018876      -0.16985     -0.016878     -0.020864         -0.02    -0.0060514     -0.0075415             0
    20-Jan-1997     -0.011811    -0.00043725      -0.16624     -0.027897     -0.022831     -0.021866     0.0015221      -0.013678    -0.0040107
    21-Jan-1997     -0.011952      0.0052493     -0.082822     -0.004415     -0.014019     -0.020864      0.021277       0.012327    -0.0067114
    22-Jan-1997     -0.016129     -0.0021758      0.020067    -0.0044346      0.031596      0.019787      0.013393       0.016743             0
    23-Jan-1997     -0.022541              0     -0.029508    -0.0022272    -0.0061256      -0.01194     -0.035242       0.010479      0.040541
    24-Jan-1997             0     -0.0056694      -0.11486     -0.075893      0.010786     0.0060423     0.0060883     -0.0088889     0.0064935
    27-Jan-1997             0      -0.010526        0.1374      0.016908      0.012195      0.009009     0.0030257     -0.0029895      0.029677
    28-Jan-1997     0.0020964      0.0026596       0.02349    -0.0047506     0.0090361    -0.0089286     -0.010558      -0.011994       0.04386
    29-Jan-1997      0.025105       0.017241     -0.045902     -0.042959      0.059701      0.033033      0.042683       0.018209      0.014406
    30-Jan-1997      0.012245       0.018253     -0.017182             0      0.016901      0.023256    -0.0087719       0.026826      0.047337
      ⋮
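By default, tick2ret computes simple (arithmetic) returns, R_t = P_t/P_{t-1} - 1, so the return timetable has one fewer row than the cleaned price timetable. The sketch below computes the same quantity in base MATLAB for a short hypothetical price vector, together with log returns for comparison.

```matlab
P = [100; 110; 99; 108.9];               % hypothetical daily closing prices

simpleRet = P(2:end)./P(1:end-1) - 1;    % simple returns, the tick2ret default
logRet    = diff(log(P));                % continuously compounded alternative

% Compounding the simple returns recovers the final price.
finalPrice = P(1)*prod(1 + simpleRet);
```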

Prepare Data for Training LSTM Model

Prepare and partition the dataset to train the long short-term memory (LSTM) model. The model uses a 30-day rolling window of trailing returns for all nine assets as features and predicts the next-day returns for four of the assets: Brent crude oil, natural gas, propane, and kerosene.

% Model is trained using a 30-day rolling window to predict 1 day in the
% future.
historySize = 30;
futureSize = 1;

% Model predicts returns for oil, natural gas, propane, and kerosene.
outputVarName = ["Brent" "NaturalGas" "Propane" "Kerosene"];
numOutputs = numel(outputVarName);

% start_idx and end_idx are the row indices in the returnData timetable
% of the first and last returns that the model predicts.
start_idx = historySize + 1;
end_idx   = height(returnData) - futureSize + 1;
numSamples = end_idx - start_idx + 1;

% The date_vector variable stores the date on which each prediction is
% made, that is, the last day of each 30-day history window.
date_vector = returnData.Time(start_idx-1:end_idx-1);

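Convert the returnData timetable to a numSamples-by-1 cell array, network_features. Each cell contains a numFeatures-by-historySize matrix (9-by-30), with one row per asset and one column per day in the window. The responses, network_responses, form a numSamples-by-numOutputs matrix.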

network_features  = cell(numSamples,1);
network_responses = zeros(numSamples,numOutputs);

for j = 1:numSamples
    network_features{j} = (returnData(j:j+historySize-1,:).Variables)';
    network_responses(j,:) = ...
        (returnData(j+historySize:j+historySize+futureSize-1,outputVarName).Variables)';
end
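To make the indexing concrete, the sketch below applies the same loop to a tiny hypothetical return matrix with 8 days and 3 features, a 3-day history, and one target column. The toy sizes and variable names are assumptions chosen so the result is easy to check, and they are deliberately different from the example's variables so running the sketch does not overwrite them.

```matlab
% Hypothetical returns: 8 days x 3 features, with R(t,:) = [t, t+8, t+16].
R = reshape(1:24,8,3);
toyHistory = 3;                 % plays the role of historySize
toyFuture  = 1;                 % plays the role of futureSize
targetCol  = 2;                 % hypothetical single output variable

firstTarget = toyHistory + 1;
lastTarget  = size(R,1) - toyFuture + 1;
nWin        = lastTarget - firstTarget + 1;   % number of windows (samples)

features  = cell(nWin,1);
responses = zeros(nWin,1);
for j = 1:nWin
    % Transpose so each cell is features-by-time, as the network expects.
    features{j}  = R(j:j+toyHistory-1,:)';
    % One-step-ahead target: the row right after the window.
    responses(j) = R(j+toyHistory,targetCol);
end
```

The first window covers days 1 to 3 and its target is day 4 of the second feature; the last window ends on day 7 and targets day 8.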

Split network_features and network_responses into three parts: training, validation, and backtesting. The backtesting set is the final contiguous block of samples, starting at the end of 2015. The remaining data is randomly split into training and validation sets. Use the validation set to prevent overfitting while training the model. The backtesting set is not used in training; it is reserved for the final strategy backtest.

% Specify rows to use in the backtest (31-Dec-2015 to 2-Nov-2021).
backtest_start_idx = find(date_vector < datetime(2016,1,1),1,'last');
backtest_indices = backtest_start_idx:size(network_responses,1);

% Specify data reserved for the backtest.
Xbacktest = network_features(backtest_indices);
Tbacktest = network_responses(backtest_indices,:);

% Remove the backtest data.
network_features = network_features(1:backtest_indices(1)-1);
network_responses = network_responses(1:backtest_indices(1)-1,:);

% Partition the remaining data into training and validation sets.
rng('default');
cv_partition = cvpartition(size(network_features,1),'HoldOut',0.2);

% Training set
Xtraining = network_features(~cv_partition.test,:);
Ttraining = network_responses(~cv_partition.test,:);

% Validation set
Xvalidation = network_features(cv_partition.test,:);
Tvalidation = network_responses(cv_partition.test,:);
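The same two-stage split can be sketched without the Statistics and Machine Learning Toolbox: first hold out the most recent block of samples by date, then randomly assign about 20% of the remaining samples to validation, here with randperm. This is an illustrative re-implementation on hypothetical dates, not what cvpartition does internally, and its rounding of the validation count may differ from cvpartition's.

```matlab
% Hypothetical prediction dates: one sample per day, 25-Dec-2015 to 03-Jan-2016.
sampleDates = datenum(2015,12,25) + (0:9)';
nS = numel(sampleDates);

% Stage 1: the backtest block starts at the last date before 1-Jan-2016.
btStart = find(sampleDates < datenum(2016,1,1),1,'last');
btIdx   = (btStart:nS)';
devIdx  = (1:btStart-1)';          % samples available for model development

% Stage 2: random 80/20 split of the development samples.
rng(0);
perm     = devIdx(randperm(numel(devIdx)));
nVal     = round(0.2*numel(devIdx));
valIdx   = sort(perm(1:nVal));
trainIdx = sort(perm(nVal+1:end));
```

Only the development samples are shuffled, so no backtest date can leak into training or validation.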

Define LSTM Network Architecture

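Specify the network architecture as a series of layers. The network stacks three LSTM layers with 10 hidden units each, and each LSTM layer is followed by a layer normalization layer. The last LSTM layer outputs only the final time step, which a fully connected layer maps to the four predicted returns. For more information on LSTM networks, see Long Short-Term Memory Networks. You can also design deep learning models interactively with the Deep Network Designer app.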

numFeatures = width(returnData);
numHiddenUnits_LSTM = 10;

layers_LSTM = [ ...
    sequenceInputLayer(numFeatures) 
    lstmLayer(numHiddenUnits_LSTM)
    layerNormalizationLayer
    lstmLayer(numHiddenUnits_LSTM)
    layerNormalizationLayer
    lstmLayer(numHiddenUnits_LSTM,'OutputMode','last')
    layerNormalizationLayer
    fullyConnectedLayer(numOutputs)
    regressionLayer];
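To illustrate what 'OutputMode','last' means, the sketch below runs a single LSTM layer over one hypothetical 9-by-30 input sequence (features by time steps, the same layout as each cell of network_features) and keeps only the final hidden state. It uses the standard LSTM gate equations with random, untrained weights stacked in input, forget, cell candidate, output order, and it omits layer normalization and the stacked layers. It illustrates the computation only; it is not a reimplementation of lstmLayer.

```matlab
rng(1);
nFeat = 9; nSteps = 30; nHidden = 10;   % mirror numFeatures, historySize, numHiddenUnits_LSTM
x = randn(nFeat,nSteps);                % one hypothetical input sequence

% Random, untrained parameters for the four gates.
W = 0.1*randn(4*nHidden,nFeat);         % input weights
Rw = 0.1*randn(4*nHidden,nHidden);      % recurrent weights
b = zeros(4*nHidden,1);                 % biases
sig = @(z) 1./(1 + exp(-z));            % logistic sigmoid

h = zeros(nHidden,1);                   % hidden state
c = zeros(nHidden,1);                   % cell state
H = zeros(nHidden,nSteps);              % all hidden states ('sequence' output)
for t = 1:nSteps
    z  = W*x(:,t) + Rw*h + b;
    ig = sig(z(1:nHidden));                  % input gate
    fg = sig(z(nHidden+1:2*nHidden));        % forget gate
    gg = tanh(z(2*nHidden+1:3*nHidden));     % cell candidate
    og = sig(z(3*nHidden+1:end));            % output gate
    c  = fg.*c + ig.*gg;
    h  = og.*tanh(c);
    H(:,t) = h;
end
hLast = H(:,end);   % 'OutputMode','last' passes only this column onward
```

The first two LSTM layers in the example keep the full sequence H so the next LSTM layer can read every time step; only the third layer reduces the sequence to a single vector.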

Specify Training Options for LSTM Model

Next, specify training options using the trainingOptions function. Many training options are available, and which ones to use depends on your use case. You can use the Experiment Manager app to explore different network architectures and sets of network hyperparameters.

max_epochs = 500;
mini_batch_size = 128;
learning_rate = 1e-4;

options_LSTM = trainingOptions('adam', ...
    'Plots','training-progress', ...
    'Verbose',0, ...
    'MaxEpochs',max_epochs, ...
    'MiniBatchSize',mini_batch_size, ...
    'Shuffle','every-epoch', ...
    'ValidationData',{Xvalidation,Tvalidation}, ...
    'ValidationFrequency',50, ...
    'ValidationPatience',10, ...
    'InitialLearnRate',learning_rate, ...
    'GradientThreshold',1);
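The trainingOptions documentation describes ValidationPatience as the number of times the validation loss can be larger than or equal to the previously smallest loss before training stops; here validation runs every 50 iterations. The sketch below applies one reading of that rule to a hypothetical sequence of validation losses. The loss values and the smaller patience of 3 are assumptions chosen to keep the example short, and this is not trainNetwork's code.

```matlab
% Hypothetical validation losses, one per validation check.
valLoss  = [1.00 0.80 0.70 0.72 0.69 0.71 0.70 0.73 0.75 0.74 0.76];
patience = 3;                 % the example uses 10

bestLoss  = Inf;
badCount  = 0;                % checks since the last new minimum
stopCheck = NaN;              % validation check at which training stops
for k = 1:numel(valLoss)
    if valLoss(k) < bestLoss
        bestLoss = valLoss(k);
        badCount = 0;
    else
        badCount = badCount + 1;
        if badCount >= patience
            stopCheck = k;
            break
        end
    end
end
```

The new minimum at the fifth check resets the counter, so training stops at the eighth check, after three checks in a row fail to beat 0.69.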

Train LSTM Model

Train the LSTM network. Use the trainNetwork function to train the network until it meets a stopping criterion. This process can take several minutes depending on the computer running the example. For more information on increasing network training performance, see Scale Up Deep Learning in Parallel, on GPUs, and in the Cloud.

To avoid waiting for the network training, load the pretrained network by setting the doTrain flag to false. To train the network using trainNetwork, set the doTrain flag to true.

doTrain = false;

if doTrain
    % Train the LSTM network.
    net_LSTM = trainNetwork(Xtraining,Ttraining,layers_LSTM,options_LSTM);
else
    % Load the pretrained network.
    load lstmBacktestNetwork
end

Visualize Training Results

Visualize the results of the trained model by comparing the predicted values against the actual values from the validation set.

% Compare the actual returns to model predicted returns.
actual    = Tvalidation;
predicted = predict(net_LSTM,Xvalidation,'MiniBatchSize',mini_batch_size);
% Overlay histogram of actual vs. predicted returns for the validation set.
output_idx = 1;
figure;

[~,edges] = histcounts(actual(:,output_idx),100);
histogram(actual(:,output_idx),edges);
hold on
histogram(predicted(:,output_idx),edges)
hold off
xlabel('Daily Return (Fractional Change in Closing Price)')
legend('Actual','Predicted')
title(sprintf('%s: Distribution of Returns, Actual vs. Predicted', outputVarName(output_idx)))

Figure: Brent: Distribution of Returns, Actual vs. Predicted (overlaid histograms labeled Actual and Predicted).

% Display the predicted vs. actual daily returns for the validation set.
figure
plot(actual(:,output_idx))
hold on
plot(predicted(:,output_idx))
yline(0)
legend({'Actual','Predicted'})
title(sprintf('%s: Daily Returns, Actual vs. Predicted', outputVarName(output_idx)))

Figure: Brent: Daily Returns, Actual vs. Predicted (lines labeled Actual and Predicted, with a reference line at zero).

% Examine the residuals.
residuals = actual(:,output_idx) - predicted(:,output_idx);
figure;
normplot(residuals);

Figure: Normal Probability Plot of the residuals.

The actual returns have fatter tails than the model predictions, and the predictions are not accurate. However, the goal of this example is to show the workflow from loading data, to model development, to backtesting, not to build an accurate model. A more sophisticated model trained on a larger and more varied set of data might have more predictive power.
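One way to quantify these observations is to compare the sample kurtosis of the actual and predicted returns and to measure how often the predicted sign matches the actual sign. The sketch below defines these measures, plus root-mean-square error, as anonymous functions and checks them on small hypothetical vectors; in the example you could pass actual(:,output_idx) and predicted(:,output_idx) instead. The kurtosis formula is the uncorrected moment estimator, which is what the kurtosis function in Statistics and Machine Learning Toolbox returns by default; it is written out here so no toolbox is needed.

```matlab
% Moment-based kurtosis m4/m2^2 (equals 3 for a normal distribution).
kurt = @(v) mean((v - mean(v)).^4) ./ mean((v - mean(v)).^2).^2;
% Fraction of days on which predicted and actual returns have the same sign.
hitRate = @(a,p) mean(sign(a) == sign(p));
% Root-mean-square prediction error.
rmseFcn = @(a,p) sqrt(mean((a - p).^2));

a = [0.02; -0.01; 0.00; 0.05; -0.04];   % hypothetical actual returns
p = [0.01;  0.01; 0.00; 0.01; -0.01];   % hypothetical predicted returns
```

A kurtosis well above 3 for the actual returns, combined with a lower value for the predictions, is a numeric version of the fat-tail pattern visible in the histogram and the normal probability plot.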

Prepare Backtest Data

Use the predictions from the LSTM model to build the backtest strategies. You can post-process the model output in a number of ways to create trading signals. For this example, use the raw regression output as the signal and store it in a timetable.

Use predict with the trained network to generate model predictions over the backtest period.

backtestPred_LSTM = predict(net_LSTM,Xbacktest,'MiniBatchSize',mini_batch_size);

Convert the predictions to a trading signal timetable.

backtestSignalTT = timetable(date_vector(backtest_indices),backtestPred_LSTM);

Construct the prices timetable corresponding to the backtest time span. The backtest trades in and out of the four energy commodities. The prices timetable has the closing price for the day on which the prediction is made.

backtestPriceTT = priceData(date_vector(backtest_indices),outputVarName);

Set the risk-free rate to be 1% annualized. The backtest engine also supports setting the risk-free rate to a timetable containing the historical daily rates.

risk_free_rate = 0.01;

Create Backtest Strategies

Use backtestStrategy (Financial Toolbox) to create four trading strategies based on the signal indicators. The following trading strategies are intended as examples to show how to convert the trading signals into actionable asset allocation strategies that you can then backtest:

  • Long Only — Invest all capital across the assets with positive predicted return, proportional to their signal strength (predicted return).

  • Long Short — Invest capital across the assets, both long and short positions, proportional to their signal strength.

  • Best Bet — Invest all capital into the single asset with the highest predicted return.

  • Equal Weight — Rebalance each day to an equal-weighted allocation.

% Specify 10 basis points as the trading cost.
tradingCosts = 0.001;

% Invest in long positions proportionally to their predicted return.
LongStrategy = backtestStrategy('LongOnly',@LongOnlyRebalanceFcn, ...
    'TransactionCosts',tradingCosts, ...
    'LookbackWindow',1);

% Invest in both long and short positions proportionally to their predicted returns.
LongShortStrategy = backtestStrategy('LongShort',@LongShortRebalanceFcn, ...
    'TransactionCosts',tradingCosts, ...
    'LookbackWindow',1);

% Invest 100% of capital into single asset with highest predicted returns.
BestBetStrategy = backtestStrategy('BestBet',@BestBetRebalanceFcn, ...
    'TransactionCosts',tradingCosts, ...
    'LookbackWindow',1);

% For comparison, invest in an equal-weighted (buy low and sell high) strategy.
equalWeightFcn = @(current_weights,prices,signal) ones(size(current_weights)) / numel(current_weights);
EqualWeightStrategy = backtestStrategy('EqualWeight',equalWeightFcn, ...
    'TransactionCosts',tradingCosts, ...
    'LookbackWindow',0);
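To see how the rebalancing rules turn one day's predictions into weights, the sketch below applies the same rules as the local functions at the end of this example to a hypothetical signal row for the four assets. The rules are restated as plain script code so the block runs without Financial Toolbox; the signal values are invented for illustration.

```matlab
% Hypothetical predicted returns for Brent, NaturalGas, Propane, Kerosene.
signal = [0.004 -0.002 0.001 0.003];

% Long only: keep positive signals, scale them to sum to 1 (all cash if none).
pos = max(signal,0);
if any(pos > 0)
    wLong = pos / sum(pos);
else
    wLong = zeros(size(signal));
end

% Long/short: scale by the sum of absolute signals (gross exposure of 1).
wLongShort = signal / sum(abs(signal));

% Best bet: 100% in the asset with the largest predicted return.
[~,best] = max(signal);
wBest = zeros(size(signal));
wBest(best) = 1;

% Equal weight: ignore the signal entirely.
wEqual = ones(size(signal)) / numel(signal);
```

Here the long/short weights are [0.4 -0.2 0.1 0.3]: the gross exposure is 1 but the net exposure is only 0.6, so this strategy generally holds less net market exposure than the long-only one.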

Put the strategies into an array and then use backtestEngine (Financial Toolbox) to create the backtesting engine.

strategies = [LongStrategy LongShortStrategy BestBetStrategy EqualWeightStrategy];

bt = backtestEngine(strategies,'RiskFreeRate',risk_free_rate);

Run Backtest

Use runBacktest (Financial Toolbox) to backtest the strategies over the backtest range.

bt = runBacktest(bt,backtestPriceTT,backtestSignalTT)
bt = 
  backtestEngine with properties:

               Strategies: [1×4 backtestStrategy]
             RiskFreeRate: 0.0100
           CashBorrowRate: 0
          RatesConvention: "Annualized"
                    Basis: 0
    InitialPortfolioValue: 10000
           DateAdjustment: "Previous"
                NumAssets: 4
                  Returns: [1462×4 timetable]
                Positions: [1×1 struct]
                 Turnover: [1462×4 timetable]
                  BuyCost: [1462×4 timetable]
                 SellCost: [1462×4 timetable]
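The Turnover, BuyCost, and SellCost timetables record trading activity at each rebalance. The sketch below shows one plausible way to compute these quantities for a single rebalance, assuming that turnover is half the sum of absolute changes in asset weights and that costs are the traded value times the 10-basis-point rate. These definitions are assumptions rather than the engine's documented formulas. Under them, moving all capital from one asset to another gives a turnover of 1 and moving from all cash into four equal positions gives 0.5, which would be consistent with the MaxTurnover values that summary reports for BestBet and EqualWeight.

```matlab
portfolioValue = 10000;     % matches InitialPortfolioValue in the output above
costRate = 0.001;           % 10 basis points, as in tradingCosts

% Hypothetical rebalance: switch all capital from asset 1 to asset 2.
wOld = [1 0 0 0];
wNew = [0 1 0 0];
dw = wNew - wOld;

turnover = sum(abs(dw)) / 2;                              % assumed definition
buyCost  = costRate * portfolioValue * sum(dw(dw > 0));   % cost of purchases
sellCost = costRate * portfolioValue * sum(-dw(dw < 0));  % cost of sales

% Hypothetical first rebalance from all cash to equal weights.
turnoverEq = sum(abs(ones(1,4)/4 - zeros(1,4))) / 2;
```

This simplification ignores that selling costs slightly reduce the capital available for the subsequent purchase.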

Examine Backtest Results

Use the summary (Financial Toolbox) and equityCurve (Financial Toolbox) functions to summarize and plot the backtest results. This model and the trading strategies derived from it are not expected to be profitable in a realistic trading scenario. However, this example illustrates a workflow that can be useful for practitioners with more comprehensive datasets and more sophisticated models and strategies.

summary(bt)
ans=9×4 table
                       LongOnly     LongShort    BestBet     EqualWeight
                       _________    _________    ________    ___________

    TotalReturn           5.6962      8.3314       3.0248        4.8347 
    SharpeRatio         0.062549    0.071321     0.044571      0.056775 
    Volatility          0.025296    0.025795     0.031625      0.026712 
    AverageTurnover       0.1828     0.22754       0.2459     0.0095931 
    MaxTurnover          0.96059     0.97368            1           0.5 
    AverageReturn      0.0016216    0.001879     0.001449      0.001556 
    MaxDrawdown          0.73831     0.62935      0.81738       0.70509 
    AverageBuyCost        3.6293      7.1838       3.8139       0.20262 
    AverageSellCost       3.6225      7.2171       3.8071       0.19578 

figure;
equityCurve(bt)

Figure: Equity Curve, with one line each for LongOnly, LongShort, BestBet, and EqualWeight.
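The summary statistics are computed from each strategy's daily returns. The sketch below computes analogous quantities for a short hypothetical return series: total return by compounding, a per-period Sharpe ratio using a daily risk-free rate, and maximum drawdown as the largest peak-to-trough decline of the equity curve. Converting the 1% annual rate to a daily rate by dividing by 252 is an assumption. The reported Sharpe ratios appear to be per-period rather than annualized figures, since for LongOnly (0.0016216 - 0.01/252)/0.025296 ≈ 0.0625, but the engine's exact rate convention may differ.

```matlab
r = [0.01; -0.02; 0.015; -0.005; 0.02];   % hypothetical daily strategy returns
rfDaily = 0.01/252;                        % assumed daily risk-free rate

equity      = cumprod(1 + r);              % growth of 1 unit of capital
totalReturn = equity(end) - 1;
sharpe      = mean(r - rfDaily) / std(r);  % per-period, not annualized

% Drawdown relative to the running peak, including the starting value of 1.
path        = [1; equity];
runningPeak = cummax(path);
maxDrawdown = max(1 - path ./ runningPeak);
```

In this series the worst decline is from 1.01 after day 1 to 0.9898 after day 2, a drawdown of 2%.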

Local Functions

function new_weights = LongOnlyRebalanceFcn(current_weights,pricesTT,signalTT) %#ok<INUSD> 
% Long only strategy, in proportion to the signal.

signal = signalTT.backtestPred_LSTM(end,:);

if any(0 < signal)
    signal(signal < 0) = 0;
    new_weights = signal / sum(signal);
else
    new_weights = zeros(size(current_weights));
end

end


function new_weights = LongShortRebalanceFcn(current_weights,pricesTT,signalTT) %#ok<INUSD> 
% Long/Short strategy, in proportion to the signal

signal = signalTT.backtestPred_LSTM(end,:);
abssum = sum(abs(signal));

if 0 < abssum
    new_weights = signal / abssum;
else
    new_weights = zeros(size(current_weights));
end

end


function new_weights = BestBetRebalanceFcn(current_weights,pricesTT,signalTT) %#ok<INUSD> 
% Best bet strategy, invest in the asset with the most upside.

signal = signalTT.backtestPred_LSTM(end,:);
new_weights = zeros(size(current_weights));
% Use the index returned by max so that ties still produce a single 100% position.
[~,best_idx] = max(signal);
new_weights(best_idx) = 1;

end


function seriesLifespanPlot(priceData)
% Plot the lifespan of each time series.

% Assume all time series end on the same day.
d2 = numel(priceData.Time);

% Plot the lifespan patch for each series.
numSeries = size(priceData,2);
for i = 1:numSeries
    % Find start date index.
    d1 = find(~isnan(priceData{:,i}),1,'first');
    % Plot patch.
    x = [d1 d1 d2 d2];
    y = i + [-0.4 0.4 0.4 -0.4];
    patch(x,y,[0 0.4470 0.7410])

    hold on
end
hold off

% Set the plot properties.
xlim([-100 d2]);
ylim([0.2 numSeries + 0.8]);

yticks(1:numSeries);
yticklabels(priceData.Properties.VariableNames');
set(gca,'YDir','reverse');   % List the first series at the top.

years = 1990:5:2021;
xtick_idx = zeros(size(years));
for yidx = 1:numel(years)
    xtick_idx(yidx) = find(years(yidx) == year(priceData.Time),1,'first');
end
xticks(xtick_idx);
xticklabels(string(years));

title('Time Series Life Span');

end
