RegressionGAM
Description
A RegressionGAM object is a generalized additive model (GAM) object for regression. It is an interpretable model that explains a response variable using a sum of univariate and bivariate shape functions.
You can predict responses for new observations by using the predict function, and plot the effect of each shape function on the prediction (response value) for an observation by using the plotLocalEffects function. For the full list of object functions for RegressionGAM, see Object Functions.
Creation
Create a RegressionGAM object by using fitrgam. You can specify both linear terms and interaction terms for predictors to include univariate shape functions (predictor trees) and bivariate shape functions (interaction trees), respectively, in a trained model.
You can update a trained model by using resume or addInteractions. The resume function resumes training for the existing terms in a model. The addInteractions function adds interaction terms to a model that contains only linear terms.
Properties
GAM Properties
BinEdges
— Bin edges for numeric predictors
cell array of numeric vectors | []
This property is read-only.
Bin edges for numeric predictors, specified as a cell array of p numeric vectors, where p is the number of predictors. Each vector includes the bin edges for a numeric predictor. The element in the cell array for a categorical predictor is empty because the software does not bin categorical predictors.
The software bins numeric predictors only if you specify the 'NumBins' name-value argument as a positive integer scalar when training a model with tree learners. The BinEdges property is empty if the 'NumBins' value is empty (default).
You can reproduce the binned predictor data Xbinned by using the BinEdges property of the trained model mdl.
X = mdl.X; % Predictor data
Xbinned = zeros(size(X)); % Categorical predictors keep the value 0
edges = mdl.BinEdges;
% Find indices of binned (numeric) predictors.
idxNumeric = find(~cellfun(@isempty,edges));
if iscolumn(idxNumeric)
    idxNumeric = idxNumeric';
end
for j = idxNumeric
    x = X(:,j);
    % Convert x to an array if x is a table.
    if istable(x)
        x = table2array(x);
    end
    % Group x into bins by using the discretize function.
    xbinned = discretize(x,[-inf; edges{j}; inf]);
    Xbinned(:,j) = xbinned;
end
Xbinned contains the bin indices, ranging from 1 to the number of bins, for numeric predictors. Xbinned values are 0 for categorical predictors. If X contains NaNs, then the corresponding Xbinned values are NaNs.
Data Types: cell
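The binning step in the code above relies on discretize with the edges padded by -inf and inf. For sorted finite edges, that call places a value in bin 1 plus the number of edges that are less than or equal to the value. The base-MATLAB sketch below makes this rule explicit without calling discretize; the edge values and data are hypothetical.

```matlab
% Hypothetical bin edges for one numeric predictor (sorted, ascending).
e = [10; 20; 30];
% Sample predictor values, including edge values, Inf, and a NaN.
x = [5; 10; 25; 30; 99; Inf; NaN];

% With edges [-Inf; e; Inf], bin k covers [edge_k, edge_k+1) and the last
% bin also includes Inf, so the bin index is 1 + (number of edges <= x).
xbinned = 1 + sum(x >= e', 2);
% Comparisons with NaN are false, so restore NaN explicitly.
xbinned(isnan(x)) = NaN;
disp(xbinned')
```

For these values, discretize(x,[-inf; e; inf]) should return the same indices; the sketch only spells out which interval each value falls into.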
Interactions
— Interaction term indices
two-column matrix of positive integers | []
This property is read-only.
Interaction term indices, specified as a t-by-2 matrix of positive integers, where t is the number of interaction terms in the model. Each row of the matrix represents one interaction term and contains the column indexes of the predictor data X for the interaction term. If the model does not include an interaction term, then this property is empty ([]).
The software adds interaction terms to the model in the order of importance based on the p-values. Use this property to check the order of the interaction terms added to the model.
Data Types: double
Intercept
— Intercept term of model
numeric scalar
This property is read-only.
Intercept (constant) term of the model, which is the sum of the intercept terms in the predictor trees and interaction trees, specified as a numeric scalar.
Data Types: single | double
IsStandardDeviationFit
— Flag indicating whether standard deviation model is fit
false | true
Flag indicating whether a model for the standard deviation of the response variable is fit, specified as false or true.
Specify the 'FitStandardDeviation' name-value argument of fitrgam as true to fit the model for the standard deviation.
If IsStandardDeviationFit is true, then you can evaluate the standard deviation at a new observation or at a training observation of predictor values by using predict or resubPredict, respectively. These functions also return the prediction intervals of the response variable, evaluated at the given observations.
Data Types: logical
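To see how a standard deviation model leads to a prediction interval, assume (as in the normal model for y described under More About) that the response is normally distributed given the predictors. A two-sided interval at level 1 - alpha is then the predicted mean plus or minus the (1 - alpha/2) standard normal quantile times the predicted standard deviation. The sketch below applies this rule to hypothetical means and standard deviations; it illustrates the idea and is not necessarily the exact computation inside predict or resubPredict.

```matlab
% Hypothetical predicted means and standard deviations for three observations.
yFit = [20; 25; 31];
ySD  = [2.0; 1.5; 3.0];
alpha = 0.05;                          % 95% prediction interval

% Standard normal quantile z = Phi^-1(1 - alpha/2), written with erfinv
% so that no toolbox function is needed.
z = sqrt(2) * erfinv(1 - alpha);
% Each row of yInt is [lower upper] for one observation.
yInt = [yFit - z*ySD, yFit + z*ySD];
disp(yInt)
```

Wider intervals correspond to observations with a larger estimated standard deviation, which is the information a mean-only model cannot provide.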
ModelParameters
— Parameters used to train model
model parameter object
This property is read-only.
Parameters used to train the model, specified as a model parameter object.
ModelParameters contains parameter values such as those for the name-value arguments used to train the model. ModelParameters does not contain estimated parameters.
Access the fields of ModelParameters by using dot notation. For example, access the maximum number of decision splits per interaction tree by using Mdl.ModelParameters.MaxNumSplitsPerInteraction.
PairDetectionBinEdges
— Bin edges for interaction term detection
cell array of numeric vectors
This property is read-only.
Bin edges for interaction term detection for numeric predictors, specified as a cell array of p numeric vectors, where p is the number of predictors. Each vector includes the bin edges for a numeric predictor. The element in the cell array for a categorical predictor is empty because the software does not bin categorical predictors.
To speed up the interaction term detection process, the software bins numeric predictors into at most 8 equiprobable bins. The number of bins can be less than 8 if a predictor has fewer than 8 unique values.
Data Types: cell
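To make "at most 8 equiprobable bins" concrete, the sketch below computes interior edges that split a numeric predictor into bins with roughly equal counts, using empirical quantiles of the sorted data, and shows how tied values reduce the number of bins. The quantile rule and the tie handling are assumptions chosen for illustration; the software's exact edge computation may differ.

```matlab
maxBins = 8;
% Hypothetical predictors: 16 distinct values, and 8 values with only 3 unique.
xs = {(1:16)', [1 1 1 2 2 2 3 3]'};
edgesOut = cell(size(xs));
for k = 1:numel(xs)
    x = xs{k};
    x = sort(x(~isnan(x)));             % drop NaNs, sort ascending
    n = numel(x);
    % Candidate cut points at the 1/maxBins, 2/maxBins, ... empirical quantiles.
    idx = floor((1:maxBins-1)' / maxBins * n) + 1;
    e = unique(x(idx));                 % tied cut points collapse into one edge
    e = e(e > x(1));                    % an edge at the minimum would leave an empty bin
    edgesOut{k} = e;
end
disp(edgesOut{1}')   % edges 3, 5, ..., 15: 8 bins of 2 values each
disp(edgesOut{2}')   % edges 2 and 3: only 3 bins, one per unique value
```

The number of bins is the number of interior edges plus one, so the second predictor ends up with fewer than 8 bins because it has only 3 unique values.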
ReasonForTermination
— Reason training stops
structure
This property is read-only.
Reason training the model stops, specified as a structure with two fields, PredictorTrees and InteractionTrees.
Use this property to check whether the model contains the specified number of trees for each linear term ('NumTreesPerPredictor') and for each interaction term ('NumTreesPerInteraction'). If the fitrgam function terminates training before adding the specified number of trees, this property contains the reason for the termination.
Data Types: struct
Other Regression Properties
CategoricalPredictors
— Categorical predictor indices
vector of positive integers | []
This property is read-only.
Categorical predictor indices, specified as a vector of positive integers. CategoricalPredictors contains index values indicating that the corresponding predictors are categorical. The index values are between 1 and p, where p is the number of predictors used to train the model. If none of the predictors are categorical, then this property is empty ([]).
Data Types: double
ExpandedPredictorNames
— Expanded predictor names
cell array of character vectors
This property is read-only.
Expanded predictor names, specified as a cell array of character vectors.
ExpandedPredictorNames is the same as PredictorNames for a generalized additive model.
Data Types: cell
NumObservations
— Number of observations
numeric scalar
This property is read-only.
Number of observations in the training data stored in X and Y, specified as a numeric scalar.
Data Types: double
PredictorNames
— Predictor variable names
cell array of character vectors
This property is read-only.
Predictor variable names, specified as a cell array of character vectors. The order of the elements in PredictorNames corresponds to the order in which the predictor names appear in the training data.
Data Types: cell
ResponseName
— Response variable name
character vector
This property is read-only.
Response variable name, specified as a character vector.
Data Types: char
ResponseTransform
— Response transformation function
'none' | function handle
Response transformation function, specified as 'none' or a function handle. ResponseTransform describes how the software transforms raw response values.
For a MATLAB® function or a function that you define, enter its function handle. For example, you can enter Mdl.ResponseTransform = @function, where function accepts a numeric vector of the original responses and returns a numeric vector of the same size containing the transformed responses.
Data Types: char | function_handle
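As an illustration, suppose a model was trained on the logarithm of a positive response (a hypothetical setup, not one of the examples on this page). A transformation that maps raw values back to the original scale must accept a numeric vector and return a vector of the same size, which the sketch below checks with an elementwise anonymous function. The assignment to Mdl appears only as a comment because it requires a trained model.

```matlab
% Hypothetical transform: undo a log transformation of the response.
transformFcn = @(y) exp(y);

% With a trained model Mdl, you would assign it as:
%   Mdl.ResponseTransform = transformFcn;

rawValues = [0; log(2); log(10)];   % hypothetical raw (log-scale) values
transformed = transformFcn(rawValues);
% exp is elementwise, so the output has the same size as the input.
disp(transformed')
```

Writing the handle with elementwise operations (such as exp, .^, or ./) keeps it valid for vectors of any length.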
RowsUsed
— Rows used in fitting
[] | logical vector
This property is read-only.
Rows of the original training data used in fitting the RegressionGAM model, specified as a logical vector. This property is empty if all rows are used.
Data Types: logical
W
— Observation weights
numeric vector
This property is read-only.
Observation weights used to train the model, specified as an n-by-1 numeric vector. n is the number of observations (NumObservations).
The software normalizes the observation weights specified in the 'Weights' name-value argument so that the elements of W sum up to 1.
Data Types: double
X
— Predictors
numeric matrix | table
This property is read-only.
Predictors used to train the model, specified as a numeric matrix or table.
Each row of X corresponds to one observation, and each column corresponds to one variable.
Data Types: single | double | table
Y
— Response
numeric vector
This property is read-only.
Response, specified as a numeric vector.
Each row of Y represents the observed response of the corresponding row of X.
Data Types: single | double
Hyperparameter Optimization Properties
HyperparameterOptimizationResults
— Description of cross-validation optimization of hyperparameters
BayesianOptimization object | table
This property is read-only.
Description of the cross-validation optimization of hyperparameters, specified as a BayesianOptimization object or a table of hyperparameters and associated values. This property is nonempty when the 'OptimizeHyperparameters' name-value argument of fitrgam is not 'none' (default) when the object is created. The value of HyperparameterOptimizationResults depends on the setting of the Optimizer field in the HyperparameterOptimizationOptions structure of fitrgam when the object is created.
Value of Optimizer Option | Value of HyperparameterOptimizationResults |
---|---|
"bayesopt" (default) | Object of class BayesianOptimization |
"gridsearch" or "randomsearch" | Table of hyperparameters used, observed objective function values (cross-validation loss), and rank of observations from lowest (best) to highest (worst) |
Object Functions
Create CompactRegressionGAM
compact | Reduce size of machine learning model |
Create RegressionPartitionedGAM
crossval | Cross-validate machine learning model |
Update GAM
addInteractions | Add interaction terms to univariate generalized additive model (GAM) |
resume | Resume training of generalized additive model (GAM) |
Interpret Prediction
lime | Local interpretable model-agnostic explanations (LIME) |
partialDependence | Compute partial dependence |
plotLocalEffects | Plot local effects of terms in generalized additive model (GAM) |
plotPartialDependence | Create partial dependence plot (PDP) and individual conditional expectation (ICE) plots |
shapley | Shapley values |
Assess Predictive Performance on New Observations
Assess Predictive Performance on Training Data
resubPredict | Predict responses for training data using trained regression model |
resubLoss | Resubstitution regression loss |
Examples
Train Generalized Additive Model
Train a univariate GAM, which contains linear terms for predictors. Then, interpret the prediction for a specified data instance by using the plotLocalEffects function.
Load the data set NYCHousing2015.
load NYCHousing2015
The data set includes 10 variables with information on the sales of properties in New York City in 2015. This example uses these variables to analyze the sale prices (SALEPRICE).
Preprocess the data set. Remove outliers, convert the datetime array (SALEDATE) to month numbers, and move the response variable (SALEPRICE) to the last column.
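% Remove outliers in the response.
idx = isoutlier(NYCHousing2015.SALEPRICE);
NYCHousing2015(idx,:) = [];
% Convert the sale dates (datetime) to month numbers.
NYCHousing2015.SALEDATE = month(NYCHousing2015.SALEDATE);
% Move the response variable to the last column.
NYCHousing2015 = movevars(NYCHousing2015,'SALEPRICE','After','SALEDATE');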
Display the first three rows of the table.
head(NYCHousing2015,3)
    BOROUGH    NEIGHBORHOOD         BUILDINGCLASSCATEGORY        RESIDENTIALUNITS    COMMERCIALUNITS    LANDSQUAREFEET    GROSSSQUAREFEET    YEARBUILT    SALEDATE    SALEPRICE
    _______    ____________    ____________________________    ________________    _______________    ______________    _______________    _________    ________    _________

       2       {'BATHGATE'}    {'01 ONE FAMILY DWELLINGS'}            1                   0                4750              2619            1899           8            0
       2       {'BATHGATE'}    {'01 ONE FAMILY DWELLINGS'}            1                   0                4750              2619            1899           8            0
       2       {'BATHGATE'}    {'01 ONE FAMILY DWELLINGS'}            1                   1                1287              2528            1899          12            0
Train a univariate GAM for the sale prices. Specify the variables for BOROUGH, NEIGHBORHOOD, BUILDINGCLASSCATEGORY, and SALEDATE as categorical predictors.
Mdl = fitrgam(NYCHousing2015,'SALEPRICE','CategoricalPredictors',[1 2 3 9])
Mdl = 
  RegressionGAM
            PredictorNames: {'BOROUGH'  'NEIGHBORHOOD'  'BUILDINGCLASSCATEGORY'  'RESIDENTIALUNITS'  'COMMERCIALUNITS'  'LANDSQUAREFEET'  'GROSSSQUAREFEET'  'YEARBUILT'  'SALEDATE'}
              ResponseName: 'SALEPRICE'
     CategoricalPredictors: [1 2 3 9]
         ResponseTransform: 'none'
                 Intercept: 3.7518e+05
    IsStandardDeviationFit: 0
           NumObservations: 83517
Mdl is a RegressionGAM model object. The model display shows a partial list of the model properties. To view the full list of properties, double-click the variable name Mdl in the Workspace. The Variables editor opens for Mdl. Alternatively, you can display the properties in the Command Window by using dot notation. For example, display the estimated intercept (constant) term of Mdl.
Mdl.Intercept
ans = 3.7518e+05
Predict the sale price for the first observation of the training data, and plot the local effects of the terms in Mdl on the prediction.
yFit = predict(Mdl,NYCHousing2015(1,:))
yFit = 4.4421e+05
plotLocalEffects(Mdl,NYCHousing2015(1,:))
The predict function predicts the sale price for the first observation as 4.4421e5. The plotLocalEffects function creates a horizontal bar graph that shows the local effects of the terms in Mdl on the prediction. Each local effect value shows the contribution of each term to the predicted sale price.
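Because the model is additive, one way to read such a bar graph is as a bookkeeping of the prediction: assuming that the local effect of a term is the value of its shape function at the observation, the prediction of a regression GAM without a response transformation equals the intercept plus the sum of the local effects. The sketch below uses made-up numbers and hypothetical term names (not values computed from Mdl) to show that sum and to rank the terms by the magnitude of their contributions, which is one natural ordering for a horizontal bar graph.

```matlab
% Hypothetical intercept and local effects of four terms for one observation
% (made-up numbers, not values computed from Mdl).
intercept = 100;
termNames = {'Term1','Term2','Term3','Term4'};
localEffects = [20, -15, 60, 4];

% For an additive model with an identity link and no response transformation,
% the prediction is the intercept plus the contributions of all terms.
yFit = intercept + sum(localEffects);

% Rank the terms by the magnitude of their contributions, largest first.
[~, order] = sort(abs(localEffects), 'descend');
rankedTerms = termNames(order);
disp(yFit)
disp(rankedTerms)
```

Negative effects lower the prediction below the intercept, so ranking by absolute value keeps large negative contributions near the top.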
Train GAM with Interaction Terms
Train a generalized additive model that contains linear and interaction terms for predictors in three different ways:
- Specify the interaction terms using the formula input argument.
- Specify the 'Interactions' name-value argument.
- Build a model with linear terms first and add interaction terms to the model by using the addInteractions function.
Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.
load carbig
Create a table that contains the predictor variables (Acceleration, Displacement, Horsepower, and Weight) and the response variable (MPG).
tbl = table(Acceleration,Displacement,Horsepower,Weight,MPG);
Specify formula
Train a GAM that contains the four linear terms (Acceleration, Displacement, Horsepower, and Weight) and two interaction terms (Acceleration*Displacement and Displacement*Horsepower). Specify the terms using a formula in the form 'Y ~ terms'.
Mdl1 = fitrgam(tbl,'MPG ~ Acceleration + Displacement + Horsepower + Weight + Acceleration:Displacement + Displacement:Horsepower');
The function adds interaction terms to the model in the order of importance. You can use the Interactions property to check the interaction terms in the model and the order in which fitrgam adds them to the model. Display the Interactions property.
Mdl1.Interactions
ans = 2×2
2 3
1 2
Each row of Interactions represents one interaction term and contains the column indexes of the predictor variables for the interaction term.
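In the formula, a term such as Displacement:Horsepower names two variables of tbl, and the matching row of Interactions holds their column indexes among the predictors. The base-MATLAB sketch below performs that name-to-index lookup for the two interaction terms in the formula. It mirrors the mapping only; it does not reproduce how fitrgam parses formulas or orders the terms.

```matlab
% Predictor names in the column order of tbl (the response MPG is excluded).
predictorNames = {'Acceleration','Displacement','Horsepower','Weight'};
% Interaction terms as written in the formula.
terms = {'Acceleration:Displacement','Displacement:Horsepower'};

pairs = zeros(numel(terms), 2);
for k = 1:numel(terms)
    parts = strsplit(terms{k}, ':');    % e.g. {'Acceleration','Displacement'}
    for m = 1:2
        pairs(k, m) = find(strcmp(predictorNames, parts{m}));
    end
end
disp(pairs)
```

The rows of pairs match the rows of Mdl1.Interactions up to order: fitrgam lists the terms by importance, which is why [2 3] (Displacement:Horsepower) appears first in the output above.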
Specify 'Interactions'
Pass the training data (tbl) and the name of the response variable in tbl to fitrgam, so that the function includes the linear terms for all the other variables as predictors. Specify the 'Interactions' name-value argument using a logical matrix to include the two interaction terms, x1*x2 and x2*x3.
Mdl2 = fitrgam(tbl,'MPG','Interactions',logical([1 1 0 0; 0 1 1 0]));
Mdl2.Interactions
ans = 2×2
2 3
1 2
You can also specify 'Interactions' as the number of interaction terms or as 'all' to include all available interaction terms. Among the specified interaction terms, fitrgam identifies those whose p-values are not greater than the 'MaxPValue' value and adds them to the model. The default 'MaxPValue' value is 1, so the function adds all specified interaction terms to the model.
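For the four predictors in tbl, 'all' corresponds to every pair of distinct predictors. The sketch below enumerates those pairs with nchoosek and writes them in the logical-matrix form used for Mdl2, one row per interaction term. It only illustrates the encoding; which of these terms fitrgam actually keeps depends on 'MaxPValue' and on whether a term improves the fit.

```matlab
p = 4;                                  % Acceleration, Displacement, Horsepower, Weight
pairs = nchoosek(1:p, 2);               % all 6 index pairs, one per row
L = false(size(pairs, 1), p);
for k = 1:size(pairs, 1)
    L(k, pairs(k, :)) = true;           % flag the two predictors of term k
end
disp(L)

% Converting a logical row back to its column indexes recovers the pair.
[~, cols] = find(L(2, :));
disp(cols)                              % 1 3
```

Rows 1 and 4 of L are exactly the two rows passed to fitrgam for Mdl2 (x1*x2 and x2*x3).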
Specify 'Interactions','all' and set the 'MaxPValue' name-value argument to 0.05.
Mdl3 = fitrgam(tbl,'MPG','Interactions','all','MaxPValue',0.05);
Warning: Model does not include interaction terms because all interaction terms have p-values greater than the 'MaxPValue' value, or the software was unable to improve the model fit.
Mdl3.Interactions
ans = 0x2 empty double matrix
Mdl3 includes no interaction terms, which implies one of the following: all interaction terms have p-values greater than 0.05, or adding the interaction terms does not improve the model fit.
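The selection rule described above (keep the terms whose p-values are not greater than 'MaxPValue', and add them in order of importance) can be sketched with hypothetical p-values. The values below are invented for illustration; fitrgam computes the actual p-values from F-tests and can still omit a term that does not improve the model fit.

```matlab
% Hypothetical interaction terms (column index pairs) and their p-values.
pairs   = [1 2; 1 3; 1 4; 2 3; 2 4; 3 4];
pValues = [0.030; 0.40; 0.012; 0.20; 0.75; 0.060];
maxPValue = 0.05;

% Keep the terms whose p-values do not exceed maxPValue ...
keep = pValues <= maxPValue;
selected = pairs(keep, :);
% ... and order them from most to least important (smallest p-value first).
[~, order] = sort(pValues(keep));
selected = selected(order, :);
disp(selected)
```

With maxPValue = 1 (the default), keep is true for every term, which matches the default behavior of adding all specified interaction terms.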
Use addInteractions Function
Train a univariate GAM that contains linear terms for predictors, and then add interaction terms to the trained model by using the addInteractions function. Specify the second input argument of addInteractions in the same way you specify the 'Interactions' name-value argument of fitrgam. You can specify the list of interaction terms using a logical matrix, the number of interaction terms, or 'all'.
Specify the number of interaction terms as 3 to add the three most important interaction terms to the trained model.
Mdl4 = fitrgam(tbl,'MPG');
UpdatedMdl4 = addInteractions(Mdl4,3);
UpdatedMdl4.Interactions
ans = 3×2
2 3
1 2
3 4
Mdl4 is a univariate GAM, and UpdatedMdl4 is an updated GAM that contains all the terms in Mdl4 and three additional interaction terms.
Resume Training Interaction Trees in GAM
Train a regression GAM that contains both linear and interaction terms, training the interaction terms for only a small number of iterations. Then, train the interaction terms for more iterations and compare the resubstitution loss before and after.
Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.
load carbig
Specify Acceleration, Displacement, Horsepower, and Weight as the predictor variables (X) and MPG as the response variable (Y).
X = [Acceleration,Displacement,Horsepower,Weight];
Y = MPG;
Train a GAM that includes all available linear and interaction terms in X. Specify the number of trees per interaction term as 2. fitrgam iterates the boosting algorithm 300 times (default) for linear terms, and iterates the algorithm the specified number of times for interaction terms. For each boosting iteration, the function adds one tree per linear term or one tree per interaction term. Specify 'Verbose' as 1 to display diagnostic messages every 10 iterations.
Mdl = fitrgam(X,Y,'Interactions','all','NumTreesPerInteraction',2,'Verbose',1);
|========================================================|
| Type | NumTrees |  Deviance   |   RelTol   | LearnRate |
|========================================================|
|    1D|         0|   2.4432e+05|          - |         - |
|    1D|         1|       9507.4|         Inf|          1|
|    1D|        10|       4470.6|  0.00025206|          1|
|    1D|        20|       3895.3|  0.00011448|          1|
|    1D|        30|       3617.7|  3.5365e-05|          1|
|    1D|        40|       3402.5|  3.7992e-05|          1|
|    1D|        50|       3257.1|  2.4983e-05|          1|
|    1D|        60|       3131.8|  2.3873e-05|          1|
|    1D|        70|       3019.8|  2.2967e-05|          1|
|    1D|        80|       2925.9|  2.8071e-05|          1|
|    1D|        90|       2845.3|  1.6811e-05|          1|
|    1D|       100|       2772.7|   1.852e-05|          1|
|    1D|       110|       2707.8|  1.6754e-05|          1|
|    1D|       120|       2649.8|   1.651e-05|          1|
|    1D|       130|       2596.6|  1.1723e-05|          1|
|    1D|       140|       2547.4|   1.813e-05|          1|
|    1D|       150|       2501.1|  1.8659e-05|          1|
|    1D|       160|       2455.7|   1.386e-05|          1|
|    1D|       170|       2416.9|  1.0615e-05|          1|
|    1D|       180|       2377.2|   8.534e-06|          1|
|    1D|       190|         2339|  7.6771e-06|          1|
|    1D|       200|       2303.3|  9.5866e-06|          1|
|    1D|       210|       2270.7|  8.4276e-06|          1|
|    1D|       220|       2240.1|  8.5778e-06|          1|
|    1D|       230|       2209.2|  9.6761e-06|          1|
|    1D|       240|       2178.7|  7.0622e-06|          1|
|    1D|       250|       2150.3|  8.3082e-06|          1|
|    1D|       260|       2122.3|  7.9542e-06|          1|
|    1D|       270|       2097.7|  7.6328e-06|          1|
|    1D|       280|       2070.4|  9.4322e-06|          1|
|    1D|       290|       2044.3|  7.5722e-06|          1|
|    1D|       300|       2019.7|  6.6719e-06|          1|
|========================================================|
| Type | NumTrees |  Deviance   |   RelTol   | LearnRate |
|========================================================|
|    2D|         0|       2019.7|          - |         - |
|    2D|         1|       1795.5|   0.0005975|          1|
|    2D|         2|       1523.4|   0.0010079|          1|
To check whether fitrgam trains the specified number of trees, display the ReasonForTermination property of the trained model and view the displayed messages.
Mdl.ReasonForTermination
ans = struct with fields:
PredictorTrees: 'Terminated after training the requested number of trees.'
InteractionTrees: 'Terminated after training the requested number of trees.'
Compute the regression loss for the training data.
resubLoss(Mdl)
ans = 3.8277
Resume training the model for another 100 iterations. Because Mdl contains both linear and interaction terms, the resume function resumes training for the interaction terms and adds more trees for them (interaction trees).
UpdatedMdl = resume(Mdl,100);
|========================================================|
| Type | NumTrees |  Deviance   |   RelTol   | LearnRate |
|========================================================|
|    2D|         0|       1523.4|          - |         - |
|    2D|         1|       1363.9|  0.00039695|          1|
|    2D|        10|       594.04|  8.0295e-05|          1|
|    2D|        20|       359.44|  4.3201e-05|          1|
|    2D|        30|       238.51|  2.6869e-05|          1|
|    2D|        40|       153.98|  2.6271e-05|          1|
|    2D|        50|       91.464|  8.0936e-06|          1|
|    2D|        60|       61.882|  3.8528e-06|          1|
|    2D|        70|       43.206|  5.9888e-06|          1|
UpdatedMdl.ReasonForTermination
ans = struct with fields:
PredictorTrees: 'Terminated after training the requested number of trees.'
InteractionTrees: 'Unable to improve the model fit.'
The resume function terminates training when adding more trees does not improve the deviance of the model fit.
Compute the regression loss using the updated model.
resubLoss(UpdatedMdl)
ans = 0.0944
The regression loss decreases after resume updates the model with more iterations.
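The loss values above are regression losses on the training data. Assuming the default loss is the mean squared error weighted by the normalized observation weights W (an assumption here; the resubLoss reference page documents the default 'LossFun'), the computation looks like the sketch below, which uses hypothetical responses, fitted values, and weights.

```matlab
% Hypothetical observed responses, fitted values, and raw observation weights.
y    = [18; 15; 36; 24];
yFit = [17; 16; 33; 24];
w    = [1; 1; 2; 1];

W = w / sum(w);                         % normalize so that the weights sum to 1
mse = sum(W .* (y - yFit).^2);          % weighted mean squared error
disp(mse)
```

Because W sums to 1, the weighted sum is already an average; observations with larger weights (here the third) contribute proportionally more to the loss.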
More About
Generalized Additive Model (GAM) for Regression
A generalized additive model (GAM) is an interpretable model that explains a response variable using a sum of univariate and bivariate shape functions of predictors.
fitrgam uses a boosted tree as a shape function for each predictor and, optionally, each pair of predictors; therefore, the function can capture a nonlinear relation between a predictor and the response variable. Because contributions of individual shape functions to the prediction (response value) are well separated, the model is easy to interpret.
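The sketch below illustrates the cyclic boosting idea behind predictor trees: in each pass, it fits one small tree per predictor to the current residuals and adds that tree to the predictor's shape function. To stay short and toolbox-free, it makes simplifying assumptions that are not the fitrgam implementation: depth-1 trees (stumps), squared-error loss, a learning rate of 1, no binning, no interaction trees, and synthetic deterministic data.

```matlab
% Deterministic synthetic data: two predictors on [0, 1).
n = 200;
k = (1:n)';
X = [mod(k*0.6180339887, 1), mod(k*0.4142135624, 1)];
y = sin(2*pi*X(:,1)) + X(:,2).^2;
p = size(X, 2);

numIter = 20;
c = mean(y);                            % intercept
F = zeros(n, p);                        % F(:,j) holds f_j(x_j) at the training points
dev = zeros(numIter, 1);                % residual sum of squares after each pass
for it = 1:numIter
    for j = 1:p                         % cycle through the predictors
        r = y - c - sum(F, 2);          % residuals of the current model
        xj = X(:, j);
        cand = unique(xj);
        bestSSE = sum(r.^2); bestT = cand(1); bestL = 0; bestR = 0;
        for t = cand(2:end)'            % candidate split: left if xj < t
            left = xj < t;
            mL = mean(r(left)); mR = mean(r(~left));
            sse = sum((r(left) - mL).^2) + sum((r(~left) - mR).^2);
            if sse < bestSSE
                bestSSE = sse; bestT = t; bestL = mL; bestR = mR;
            end
        end
        % Add the stump (a one-split tree fit to the residuals) to f_j.
        F(:, j) = F(:, j) + bestL*(xj < bestT) + bestR*(xj >= bestT);
    end
    dev(it) = sum((y - c - sum(F, 2)).^2);
end
disp(dev([1 end])')
```

Each stump can only lower or keep the residual sum of squares, so dev is nonincreasing across passes, and each column of F ends up as a piecewise-constant shape function of one predictor.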
The standard GAM uses a univariate shape function for each predictor:
$$y \sim N(\mu, \sigma^2), \qquad g(\mu) = \mu = c + f_1(x_1) + f_2(x_2) + \cdots + f_p(x_p),$$
where y is a response variable that follows the normal distribution with mean μ and standard deviation σ. g(μ) is an identity link function, and c is an intercept (constant) term. $f_i(x_i)$ is a univariate shape function for the ith predictor, which is a boosted tree for a linear term for the predictor (predictor tree), and p is the number of predictors.
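You can include interactions between predictors in a model by adding bivariate shape functions of important interaction terms to the model:
$$y \sim N(\mu, \sigma^2), \qquad g(\mu) = \mu = c + f_1(x_1) + f_2(x_2) + \cdots + f_p(x_p) + \sum_{i,j \in \{\text{interaction terms}\}} f_{ij}(x_i, x_j),$$
where $f_{ij}(x_i, x_j)$ is a bivariate shape function for the ith and jth predictors, which is a boosted tree for an interaction term for the predictors (interaction tree).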
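The additive form makes a prediction easy to take apart term by term. The sketch below evaluates the model equation with interaction terms for a single observation, using hand-written function handles as stand-ins for the boosted-tree shape functions; the intercept, the shape functions, and the interaction pair are all hypothetical.

```matlab
c = 2;                                  % hypothetical intercept (constant) term
% Hypothetical univariate shape functions f_i(x_i), one per predictor (p = 3).
f = {@(x) 0.5*x, @(x) x.^2, @(x) -x};
% One hypothetical interaction term between predictors 1 and 3.
interactions = [1 3];
fij = {@(a, b) a.*b};

x = [2, 1, 3];                          % one observation
terms = zeros(1, numel(f) + size(interactions, 1));
for ii = 1:numel(f)
    terms(ii) = f{ii}(x(ii));           % contribution of predictor ii
end
for k = 1:size(interactions, 1)
    ii = interactions(k, 1); jj = interactions(k, 2);
    terms(numel(f) + k) = fij{k}(x(ii), x(jj));   % contribution of term k
end
mu = c + sum(terms);                    % identity link: g(mu) = mu
disp(terms)
disp(mu)
```

Each entry of terms is the contribution of one shape function, so the prediction mu is recovered exactly as the intercept plus their sum.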
fitrgam finds important interaction terms based on the p-values of F-tests. For details, see Interaction Term Detection.
If you specify 'FitStandardDeviation' of fitrgam as false (default), then fitrgam trains a model for the mean μ. If you specify 'FitStandardDeviation' as true, then fitrgam trains an additional model for the standard deviation σ and sets the IsStandardDeviationFit property of the GAM object to true.
Version History
Introduced in R2021a