dlupdate
Update parameters using custom function
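Syntax
netUpdated = dlupdate(fun,net)
params = dlupdate(fun,params)
[___] = dlupdate(fun,___,A1,...,An)
[___,X1,...,Xm] = dlupdate(fun,___)
Description
netUpdated = dlupdate(fun,net) updates the learnable parameters of the dlnetwork object net by evaluating the function fun with each learnable parameter as an input. fun is a function handle to a function that takes one parameter array as an input argument and returns an updated parameter array.
params = dlupdate(fun,params) updates the learnable parameters in params by evaluating the function fun with each learnable parameter as an input.
[___] = dlupdate(fun,___,A1,...,An) also passes the additional input arguments A1,...,An to fun, in addition to the input arguments in the previous syntaxes, for use when fun is a function handle to a function that requires n+1 input values.
[___,X1,...,Xm] = dlupdate(fun,___) returns the additional outputs X1,...,Xm when fun is a function handle to a function that returns m+1 output values.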
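As a minimal sketch of the net syntax, the following applies a function that zeros every learnable parameter of a small network. The two-layer network (a featureInputLayer followed by a fullyConnectedLayer) and the names smallNet and zeroFcn are illustrative assumptions, not part of this page; the sketch assumes Deep Learning Toolbox and a release that supports featureInputLayer and dlnetwork objects without output layers.

```matlab
% Illustrative network (assumed for this sketch): 2 input features, 1 output.
layers = [
    featureInputLayer(2)
    fullyConnectedLayer(1)];
smallNet = dlnetwork(layers);

% fun takes one parameter array and returns an array of the same size.
zeroFcn = @(p) 0 .* p;

% dlupdate calls zeroFcn once per row of smallNet.Learnables
% (the Weights and Bias of the fully connected layer).
smallNetUpdated = dlupdate(zeroFcn,smallNet);

disp(smallNetUpdated.Learnables)
```

The returned network has the same architecture; only the values in its Learnables table change.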
Examples
Perform L1 regularization on a structure of parameter gradients.
Create the sample input data.
dlX = dlarray(rand(100,100,3),'SSC');
Initialize the learnable parameters for the convolution operation.
params.Weights = dlarray(rand(10,10,3,50));
params.Bias = dlarray(rand(50,1));
Calculate the gradients for the convolution operation using the helper function convGradients, defined at the end of this example.
gradients = dlfeval(@convGradients,dlX,params);
Define the regularization factor.
L1Factor = 0.001;
Create an anonymous function that regularizes the gradients. By using an anonymous function to pass a scalar constant to the function, you can avoid having to expand the constant value to the same size and structure as the parameter variable.
L1Regularizer = @(grad,param) grad + L1Factor.*sign(param);
Use dlupdate to apply the regularization function to each of the gradients.
gradients = dlupdate(L1Regularizer,gradients,params);
The gradients in gradients are now regularized according to the function L1Regularizer.
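The same pattern works on plain numeric parameters stored in a cell array. In this sketch the names toyParams, toyGrads, and lambda and all of their values are made up for illustration. Note that sign(0) is 0, so a parameter that is exactly zero leaves its gradient unchanged.

```matlab
% Hypothetical parameters and gradients; both containers must match in shape.
toyParams = {[2 -3], 0};
toyGrads  = {[0.5 -0.5], 1};
lambda = 0.1;   % illustrative L1 factor

% Inputs after the first container are passed to fun element by element.
l1Fcn = @(g,p) g + lambda .* sign(p);
toyGradsReg = dlupdate(l1Fcn,toyGrads,toyParams);

% expected values: toyGradsReg{1} = [0.6 -0.6], toyGradsReg{2} = 1
```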
convGradients Function
The convGradients helper function takes the learnable parameters of the convolution operation and a mini-batch of input data dlX, and returns the gradients with respect to the learnable parameters.
function gradients = convGradients(dlX,params)
    dlY = dlconv(dlX,params.Weights,params.Bias);
    dlY = sum(dlY,'all');
    gradients = dlgradient(dlY,params);
end
Use dlupdate to train a network using a custom update function that implements the stochastic gradient descent algorithm (without momentum).
Load Training Data
Load the digits training data.
[XTrain,TTrain] = digitTrain4DArrayData;
classes = categories(TTrain);
numClasses = numel(classes);
Define the Network
Define the network architecture and specify the average image value using the Mean option in the image input layer.
layers = [
imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
convolution2dLayer(5,20)
reluLayer
convolution2dLayer(3,20,'Padding',1)
reluLayer
convolution2dLayer(3,20,'Padding',1)
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
Create a dlnetwork object from the layer array.
net = dlnetwork(layers);
Define Model Loss Function
Create the helper function modelLoss, listed at the end of this example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.
Define Stochastic Gradient Descent Function
Create the helper function sgdFunction, listed at the end of this example. The function takes the parameters and the gradients of the loss with respect to the parameters, and returns the updated parameters using the stochastic gradient descent algorithm, expressed as
$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}),$$
where $\ell$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function.
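For a concrete feel for one step of this rule, the sketch below applies it with dlupdate to a structure of numeric parameters. The values of theta, grad, and alpha are arbitrary illustrative numbers; with a learning rate of 0.1, each parameter moves by one tenth of its gradient in the opposite direction.

```matlab
% Illustrative parameters and gradients (hypothetical values).
theta.W = [1 2];
theta.b = 0.5;
grad.W  = [0.5 -1];
grad.b  = 2;
alpha   = 0.1;   % learning rate

% One SGD step, theta <- theta - alpha*gradient, applied field by field.
sgdStep = @(p,g) p - alpha .* g;
theta = dlupdate(sgdStep,theta,grad);

% expected values: theta.W = [0.95 2.1], theta.b = 0.3
```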
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 30;
numObservations = numel(TTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Specify the learning rate.
learnRate = 0.01;
Train Network
Calculate the total number of iterations for the training progress monitor.
numIterations = numEpochs * numIterationsPerEpoch;
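As a worked check of these counts, assume the digits training set contains 5000 images (an assumption for this sketch; numel(TTrain) gives the actual value). Each epoch then uses floor(5000/128) = 39 full mini-batches, or 4992 observations, so 8 observations are skipped per epoch. The variable names below are prefixed with "ex" so that they do not overwrite the training variables.

```matlab
exNumObservations = 5000;   % assumed size of the digits training set
exMiniBatchSize = 128;
exNumEpochs = 30;

exItersPerEpoch = floor(exNumObservations./exMiniBatchSize);       % 39
exNumUnused = exNumObservations - exItersPerEpoch*exMiniBatchSize;  % 8
exNumIterations = exNumEpochs * exItersPerEpoch;                   % 1170
```

Because the training loop reshuffles the data every epoch, the skipped observations differ from one epoch to the next.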
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters by calling dlupdate with the function sgdFunction defined at the end of this example. After each iteration, update the training progress monitor.
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
iteration = 0;
epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    idx = randperm(numel(TTrain));
    XTrain = XTrain(:,:,:,idx);
    TTrain = TTrain(idx);

    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        T = zeros(numClasses,miniBatchSize,"single");
        for c = 1:numClasses
            T(c,TTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to dlarray.
        X = dlarray(single(X),"SSCB");

        % If training on a GPU, then convert data to a gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using the SGD algorithm defined in
        % the sgdFunction helper function. dlupdate calls updateFcn once
        % for each learnable parameter and its matching gradient.
        updateFcn = @(parameters,gradients) sgdFunction(parameters,gradients,learnRate);
        net = dlupdate(updateFcn,net,gradients);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end

Test Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest,TTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.
XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
YTest = predict(net,XTest);
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);
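For this network, the scores form a numClasses-by-N array with one column per observation, so max along dimension 1 picks the highest-scoring class for each observation. A toy illustration with three hypothetical classes and two observations follows; the "toy" names and all values are made up and chosen so as not to overwrite classes or idx.

```matlab
% Hypothetical class names and scores: 3 classes (rows) x 2 observations (columns).
toyClasses = {'0';'1';'2'};
toyScores = [0.1 0.7
             0.8 0.2
             0.1 0.1];

% Row index of the highest score in each column, mapped to class names.
[~,toyIdx] = max(toyScores,[],1);   % [2 1]
toyPred = toyClasses(toyIdx);       % {'1';'0'}

% Compare with hypothetical true labels, as in the accuracy step.
toyTrue = categorical({'1';'2'},{'0','1','2'});
toyAccuracy = mean(toyPred == toyTrue);   % 0.5
```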
Evaluate the classification accuracy.
accuracy = mean(YTest==TTest)
accuracy = 0.9040
Model Loss Function
The helper function modelLoss takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.
function [loss,gradients] = modelLoss(net,X,T)
    Y = forward(net,X);
    loss = crossentropy(Y,T);
    gradients = dlgradient(loss,net.Learnables);
end
Stochastic Gradient Descent Function
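The helper function sgdFunction takes three inputs: parameters, a learnable parameter array; gradients, the gradients of the loss with respect to those parameters; and learnRate, the learning rate. When you pass sgdFunction to dlupdate, dlupdate calls it once for each learnable parameter array. The function returns the updated parameters using the stochastic gradient descent algorithm, expressed as
$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}),$$
where $\ell$ is the iteration number, $\alpha > 0$ is the learning rate, $\theta$ is the parameter vector, and $E(\theta)$ is the loss function.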
function parameters = sgdFunction(parameters,gradients,learnRate)
    parameters = parameters - learnRate .* gradients;
end
Input Arguments
fun
Function to apply to the learnable parameters, specified as a function handle. dlupdate evaluates fun with each learnable parameter as an input.
net
Network, specified as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:
- Layer: Layer name, specified as a string scalar.
- Parameter: Parameter name, specified as a string scalar.
- Value: Value of parameter, specified as a cell array containing a dlarray.
params
Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.
If you specify params as a table, it must contain the following three variables:
- Layer: Layer name, specified as a string scalar.
- Parameter: Parameter name, specified as a string scalar.
- Value: Value of parameter, specified as a cell array containing a dlarray.
You can specify params as a container of learnable parameters
for your network using a cell array, structure, or table, or nested cell arrays or
structures. The learnable parameters inside the cell array, structure, or table must be
dlarray or numeric values of data type double or
single.
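For example, a nested structure that mixes a numeric array and a cell array is a valid params container. The field names encoder and decoder and all values are hypothetical; dlupdate walks the nesting, applies fun to each leaf array, and returns a container with the same shape.

```matlab
% Hypothetical nested container of numeric learnable parameters.
nestedParams.encoder.Weights = [1 2; 3 4];
nestedParams.decoder = {single(5), [6 7]};

% Scale every leaf array by 0.5; the container structure is preserved.
halfFcn = @(p) 0.5 .* p;
nestedHalf = dlupdate(halfFcn,nestedParams);

% expected: nestedHalf.encoder.Weights = [0.5 1; 1.5 2],
% nestedHalf.decoder = {single(2.5), [3 3.5]}
```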
The input arguments A1,...,An must have exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
Data Types: single | double | struct | table | cell
A1,...,An
Additional input arguments to fun, specified as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.
The exact form of A1,...,An depends on the input network or
learnable parameters. The following table shows the required format for
A1,...,An for possible inputs to
dlupdate.
| Input | Learnable Parameters | A1,...,An |
|---|---|---|
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input arguments for the function fun to apply to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
| params | Numeric array | Numeric array with the same data type and ordering as params. |
| params | Cell array | Cell array with the same data types, structure, and ordering as params. |
| params | Structure | Structure with the same data types, fields, and ordering as params. |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. A1,...,An must have a Value variable consisting of cell arrays that contain the additional input argument for the function fun to apply to each learnable parameter. |
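A minimal sketch of the table form: the hypothetical table learnTbl below has the Layer, Parameter, and Value variables described above, and a gradient table with the same variables and row ordering is passed as A1. The layer name "fc" and all values are illustrative assumptions.

```matlab
% Hypothetical learnables table with Layer, Parameter, and Value variables.
Layer = ["fc";"fc"];
Parameter = ["Weights";"Bias"];
Value = {dlarray([1 2; 3 4]); dlarray([0.5; -0.5])};
learnTbl = table(Layer,Parameter,Value);

% Gradient table (A1) with the same variables and row ordering.
gradTbl = learnTbl;
gradTbl.Value = {dlarray(ones(2)); dlarray([1; 1])};

% Plain SGD step, applied row by row.
learnRate = 0.1;
learnTbl = dlupdate(@(p,g) p - learnRate.*g, learnTbl, gradTbl);

W = extractdata(learnTbl.Value{1});   % [0.9 1.9; 2.9 3.9]
b = extractdata(learnTbl.Value{2});   % [0.4; -0.6]
```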
Output Arguments
netUpdated
Updated network, returned as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object.
params
Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.
X1,...,Xm
Additional output arguments from the function fun, where fun is a function handle to a function that returns multiple outputs, returned as dlarray objects, numeric arrays, cell arrays, structures, or tables with a Value variable.
The exact form of X1,...,Xm depends on the input network or
learnable parameters. The following table shows the returned format of
X1,...,Xm for possible inputs to
dlupdate.
| Input | Learnable Parameters | X1,...,Xm |
|---|---|---|
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output arguments of the function fun applied to each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params. |
| params | Numeric array | Numeric array with the same data type and ordering as params. |
| params | Cell array | Cell array with the same data types, structure, and ordering as params. |
| params | Structure | Structure with the same data types, fields, and ordering as params. |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. X1,...,Xm has a Value variable consisting of cell arrays that contain the additional output argument of the function fun applied to each learnable parameter. |
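To illustrate X1,...,Xm, the sketch below uses a hypothetical function that returns two outputs: the halved parameter and the 2-norm of the original parameter. Because fun has two outputs, dlupdate returns the updated parameters plus one extra container, X1, with the same structure as the input. The anonymous function relies on deal to produce its two outputs; the names paramsIn, paramsHalf, and paramNorms are illustrative.

```matlab
% Hypothetical numeric parameters.
paramsIn.W = [3 4];
paramsIn.b = 0;

% fun returns two outputs: an updated array and a scalar summary of the input.
halveAndNorm = @(p) deal(0.5 .* p, norm(p(:)));

% The first output replaces each parameter; X1 collects the second output.
[paramsHalf,paramNorms] = dlupdate(halveAndNorm,paramsIn);

% expected: paramsHalf.W = [1.5 2], paramsHalf.b = 0,
% paramNorms.W = 5, paramNorms.b = 0
```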
Extended Capabilities
The dlupdate function
supports GPU array input with these usage notes and limitations:
When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:
- params
- A1,...,An
For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
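One way to move parameters onto the GPU is to pass gpuArray itself as fun. This sketch is guarded by canUseGPU, so on a machine without a supported GPU the parameters simply stay on the CPU. The structure gpuParams and its values are hypothetical, and using a GPU requires Parallel Computing Toolbox.

```matlab
% Hypothetical parameters stored as dlarray objects.
gpuParams.Weights = dlarray(rand(3,3));
gpuParams.Bias = dlarray(zeros(3,1));

% Convert each parameter to a gpuArray-backed dlarray when a GPU is available.
if canUseGPU
    gpuParams = dlupdate(@gpuArray,gpuParams);
end

% Later dlupdate calls on gpuParams then run on the GPU.
% gather copies the data back to host memory (a no-op for CPU data).
biasOnHost = gather(extractdata(gpuParams.Bias));
```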
Version History
Introduced in R2019b
See Also
dlnetwork | dlarray | adamupdate | rmspropupdate | sgdmupdate | dlgradient | dljacobian | dldivergence | dllaplacian | dlfeval