Tips on Importing Models from TensorFlow, PyTorch, and ONNX
This topic provides tips on how to overcome common hurdles in importing a model from TensorFlow™, PyTorch®, or ONNX™ as a MATLAB® network or layer graph. You can read each section of this topic independently. For a high-level overview of the import and export functions in Deep Learning Toolbox™, see Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX.
Import Functions of Deep Learning Toolbox
This table lists the Deep Learning Toolbox import functions. Use these functions to import networks or layer graphs from TensorFlow, PyTorch, and ONNX.
You must have the relevant support package to run these import functions. If the support package is not installed, each function provides a download link to the corresponding support package in the Add-On Explorer. A recommended practice is to download the support package to the default location for the version of MATLAB you are running. You can also directly download the support packages from File Exchange.
Recommended Functions to Import TensorFlow Models
The Deep Learning Toolbox Converter for TensorFlow Models support package offers these functions:
importTensorFlowNetwork and importKerasNetwork — Import a TensorFlow model as a network.
importTensorFlowLayers and importKerasLayers — Import a TensorFlow model as a layer graph.
Note
The importTensorFlowNetwork and importTensorFlowLayers functions are recommended over the importKerasNetwork and importKerasLayers functions.
This table compares the Deep Learning Toolbox Converter for TensorFlow Models functions. The comparison highlights the reasons that the importTensorFlowNetwork and importTensorFlowLayers functions are recommended over the importKerasNetwork and importKerasLayers functions.
Features | importTensorFlowNetwork and importTensorFlowLayers | importKerasNetwork and importKerasLayers |
---|---|---|
Automatically generates custom layers | Yes | No |
Supports TensorFlow 2 | Yes | Limited |
Supports SavedModel format | Yes | No |
Can import network as dlnetwork (or LayerGraph compatible with dlnetwork) | Yes | No |
For more information on the advantages of migrating from TensorFlow 1 to TensorFlow 2, see Migrate from TensorFlow 1.x to TensorFlow 2. For more information on the TensorFlow versions that the import functions support, see Limitations (importTensorFlowNetwork and importTensorFlowLayers) and Limitations (importKerasNetwork and importKerasLayers).
To import a TensorFlow model that is saved in the HDF5 format, do not import the HDF5 file directly by using importKerasNetwork. Instead, convert the model to the SavedModel format and import it by using the importTensorFlowNetwork function.
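A minimal sketch of the import step, assuming the HDF5 model has already been converted to the SavedModel format outside MATLAB and stored in a folder named modelFolder (a hypothetical name). The classification output layer is also an assumption about the model.

```matlab
% Hypothetical folder that holds the model after conversion to SavedModel
modelFolder = "modelFolder";

% Import the SavedModel as a DAGNetwork with a classification output layer.
% importTensorFlowNetwork takes the SavedModel folder, not the .h5 file.
net = importTensorFlowNetwork(modelFolder, OutputLayerType="classification");

% Inspect the imported layers, including any autogenerated custom layers
disp(net.Layers)
```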
Autogenerated Custom Layers
The importTensorFlowNetwork and importTensorFlowLayers functions can automatically generate custom layers when you import custom TensorFlow layers or when the software cannot convert TensorFlow layers into equivalent built-in MATLAB layers. For an example, see Import TensorFlow Network with Autogenerated Custom Layers. For a list of layers for which the software supports conversion, see TensorFlow-Keras Layers Supported for Conversion into Built-In MATLAB Layers.
The importONNXNetwork and importONNXLayers functions can also generate custom layers when the software cannot convert ONNX operators into equivalent built-in MATLAB layers. For an example, see Import ONNX Network with Autogenerated Custom Layers. For a list of operators for which the software supports conversion, see ONNX Operators Supported for Conversion into Built-In MATLAB Layers.
In rare cases, when importONNXNetwork and importONNXLayers cannot import an ONNX model into layers, you can use importONNXFunction to import the model as a function. For more information on how to select an ONNX import function, see Select Function to Import ONNX Pretrained Network.
The importNetworkFromPyTorch function imports a PyTorch layer into MATLAB by trying these steps in order:
The function tries to import the PyTorch layer as a built-in MATLAB layer. For more information, see Conversion of PyTorch Layers.
The function tries to import the PyTorch layer as a built-in MATLAB function. For more information, see Conversion of PyTorch Layers.
The function tries to import the PyTorch layer as a custom layer. For an example, see Import Network from PyTorch and Find Generated Custom Layers.
The function imports the PyTorch layer as a custom layer with a placeholder function. For more information, see Placeholder Functions.
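For the rare ONNX case described above, where importONNXFunction is the fallback, the code below is a rough sketch of generating a model function and calling it. The file name model.onnx, the function name modelFcn, and the 224-by-224-by-3 input size are hypothetical. The generated function accepts further options that this sketch does not use; check the generated file for the options it supports.

```matlab
% Hypothetical ONNX model file and name of the model function to generate
modelFile = "model.onnx";
fcnName = "modelFcn";

% Generate modelFcn.m in the current folder and return the model
% parameters as an ONNXParameters object
params = importONNXFunction(modelFile, fcnName);

% Example input; the 224-by-224-by-3 size is an assumption about the model
X = rand(224, 224, 3, "single");

% Call the generated function. feval is used only because the function
% name is stored in a variable; you can also call modelFcn(X, params).
Y = feval(fcnName, X, params);
```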
The importTensorFlowNetwork, importTensorFlowLayers, importONNXNetwork, importONNXLayers, and importNetworkFromPyTorch functions save the automatically generated custom layers to a package in the current folder. For more information on the custom layers package, see the PackageName name-value argument of each function.
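A short sketch of choosing the package name yourself, using importTensorFlowNetwork as the example. The folder name modelFolder, the package name myModelLayers, and the classification output layer are hypothetical.

```matlab
% Hypothetical SavedModel folder and custom layer package name
modelFolder = "modelFolder";
pkgName = "myModelLayers";

% Save any autogenerated custom layers to the package +myModelLayers
net = importTensorFlowNetwork(modelFolder, ...
    OutputLayerType="classification", ...
    PackageName=pkgName);

% List the generated custom layer class files. The package folder exists
% only if the import actually generated custom layers.
pkgFolder = fullfile(pwd, "+" + pkgName);
if isfolder(pkgFolder)
    dir(pkgFolder)
end
```

The custom layer classes are ordinary MATLAB class files, so the folder that contains the package presumably needs to stay on the MATLAB path whenever you load or use the imported network.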
Placeholder Layers
The importTensorFlowLayers and importONNXLayers functions insert placeholder layers in place of TensorFlow layers or ONNX operators when both of these conditions apply:
The function cannot convert the TensorFlow layers or ONNX operators to built-in MATLAB layers. For lists of TensorFlow layers and ONNX operators for which the functions support conversion, see TensorFlow-Keras Layers Supported for Conversion into Built-In MATLAB Layers and ONNX Operators Supported for Conversion into Built-In MATLAB Layers, respectively.
The function cannot generate custom layers in place of the TensorFlow layers or ONNX operators that it cannot convert to built-in MATLAB layers.
If these conditions apply, the importTensorFlowNetwork and importONNXNetwork functions return an error instead. These flowcharts describe these workflows.
To find the names and indices of the placeholder layers in the layer graph, use the findPlaceholderLayers function. You can then replace a placeholder layer with a built-in MATLAB layer, a custom layer, or a functionLayer object. For more information about custom layers, see Define Custom Deep Learning Layers. For an example with a functionLayer object, see Replace Unsupported Keras Layer with Function Layer. To replace a layer, use the replaceLayer function. For an example, see Import ONNX Model as Layer Graph with Placeholder Layers.
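The find-and-replace workflow above can be sketched as follows, assuming a hypothetical file model.onnx that produces at least one single-input, single-output placeholder layer. The elementwise function in the functionLayer is purely illustrative; a real replacement must implement whatever the unsupported operator computes.

```matlab
% Hypothetical ONNX model that contains an operator without a built-in
% MATLAB equivalent
lgraph = importONNXLayers("model.onnx", OutputLayerType="classification");

% Find the placeholder layers and their indices in the layer graph
[placeholders, indices] = findPlaceholderLayers(lgraph);
assert(~isempty(placeholders), "No placeholder layers found.")
disp(indices)

% Replace the first placeholder with a functionLayer. The elementwise
% x*sigmoid(x) function is only an illustration; implement the operation
% that the unsupported operator actually performs.
newLayer = functionLayer(@(X) X ./ (1 + exp(-X)), ...
    Name="replacement_1", ...
    Description="Replacement for " + placeholders(1).Name);
lgraph = replaceLayer(lgraph, placeholders(1).Name, newLayer);

% Count the placeholder layers that still need to be replaced
numRemaining = numel(findPlaceholderLayers(lgraph));
```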
The importNetworkFromPyTorch function generates a custom layer with a placeholder function instead of a placeholder layer. For more information, see Placeholder Functions.
Input Dimension Ordering
The dimension ordering of the input data differs between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX. This table compares the input dimension ordering on each platform for different input types.
Input Type | MATLAB | TensorFlow | PyTorch | ONNX |
---|---|---|---|---|
Features | CN | NC | NC | NC |
2-D image | HWCN | NHWC | NCHW | NCHW |
3-D image | HWDCN | NHWDC | NCDHW | NCHWD |
Vector sequence | CSN | NSC | SNC | NSC |
2-D image sequence | HWCSN | NSHWC | NCSHW | NSCHW |
3-D image sequence | HWDCSN | NSHWDC | NCSDHW | NSCHWD |
Variable names in the table:
N — Number of observations
C — Number of features or channels
H — Height of images
W — Width of images
D — Depth of images
S — Sequence length
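As a small illustration of the table, the sketch below reorders a batch of 2-D images into the MATLAB HWCN ordering with permute. It assumes the array arrives in MATLAB with its dimensions in the same order as in the source framework; depending on how you transfer the data, the order can differ, so check size first. Random data stands in for real images.

```matlab
% A batch of 8 RGB images of size 224-by-224 in TensorFlow ordering (NHWC)
XTensorFlow = rand(8, 224, 224, 3, "single");

% NHWC -> HWCN: H is dimension 2, W is 3, C is 4, N is 1
XFromTensorFlow = permute(XTensorFlow, [2 3 4 1]);

% The same kind of batch in PyTorch or ONNX ordering (NCHW)
XPyTorch = rand(8, 3, 224, 224, "single");

% NCHW -> HWCN: H is dimension 3, W is 4, C is 2, N is 1
XFromPyTorch = permute(XPyTorch, [3 4 2 1]);

% Both results are 224-by-224-by-3-by-8
disp(size(XFromTensorFlow))
disp(size(XFromPyTorch))
```

Each index vector passed to permute lists, for every MATLAB output dimension, which dimension of the source array supplies it.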
Data Formats for Prediction with dlnetwork
The importTensorFlowNetwork and importONNXNetwork functions can import a TensorFlow or ONNX model as a DAGNetwork or dlnetwork object. Specify the type of imported network by setting the TargetNetwork name-value argument. For more details, see TargetNetwork for importTensorFlowNetwork and TargetNetwork for importONNXNetwork.
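A minimal sketch of setting TargetNetwork, assuming a hypothetical SavedModel folder modelFolder and a hypothetical ONNX file model.onnx.

```matlab
% Hypothetical model locations
tfModelFolder = "modelFolder";
onnxModelFile = "model.onnx";

% Import each model as a dlnetwork object instead of a DAGNetwork object
dlnetTF = importTensorFlowNetwork(tfModelFolder, TargetNetwork="dlnetwork");
dlnetONNX = importONNXNetwork(onnxModelFile, TargetNetwork="dlnetwork");

% Check whether each network is already initialized
disp([dlnetTF.Initialized dlnetONNX.Initialized])
```

No OutputLayerType is given here because a dlnetwork object has no output layer; the DAGNetwork imports elsewhere in this topic specify one.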
The importNetworkFromPyTorch function imports a PyTorch model as an uninitialized dlnetwork object. Before you use the network, do one of the following:
Add an input layer to the imported network and initialize the network by using the addInputLayer function. For an example, see Import Network from PyTorch and Add Input Layer.
Initialize the network by using the initialize function with example input data in the appropriate format. For an example, see Import Network from PyTorch and Initialize.
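Both options can be sketched as follows, assuming a hypothetical traced PyTorch model file model.pt that expects a single 224-by-224 RGB image input. The input size and the choice of imageInputLayer are assumptions about the model.

```matlab
% Hypothetical traced PyTorch model file
modelFile = "model.pt";
net = importNetworkFromPyTorch(modelFile);

% Option 1: add an image input layer and initialize the network
inputLayer = imageInputLayer([224 224 3], Normalization="none");
net1 = addInputLayer(net, inputLayer, Initialize=true);

% Option 2: initialize the network with example input in the "SSCB" format
X = dlarray(rand(224, 224, 3, "single"), "SSCB");
net2 = initialize(net, X);
```

Option 2 leaves the network without an input layer; the data format of the example input tells the network how to interpret each dimension.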
To predict using a dlnetwork object, you must convert the input data to a dlarray object with the appropriate data format. For an example, see Import TensorFlow Network as dlnetwork to Classify Image. Use this table to choose the right data format for each input type and layer.
Input Type | Input Layer ** | Input Format * |
---|---|---|
Features | featureInputLayer | CB |
2-D image | imageInputLayer | SSCB |
3-D image | image3dInputLayer | SSSCB |
Vector sequence | sequenceInputLayer | CBT |
2-D image sequence | sequenceInputLayer | SSCBT |
3-D image sequence | sequenceInputLayer | SSSCBT |
* In Deep Learning Toolbox, each dimension label in a data format must be one of these labels:
S — Spatial
C — Channel
B — Batch observations
T — Time or sequence
U — Unspecified
** A dlnetwork object does not require an input layer. The network can infer the input layer type from the input data format.
For more information on data formats, see dlarray.
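A minimal prediction sketch, assuming a hypothetical SavedModel folder modelFolder that holds a model with a 224-by-224 RGB image input. Random data stands in for a real image, and the vector sequence line only illustrates the "CBT" format from the table.

```matlab
% Hypothetical SavedModel folder imported as a dlnetwork object
net = importTensorFlowNetwork("modelFolder", TargetNetwork="dlnetwork");

% One 224-by-224 RGB image; the size is an assumption about the model.
% For a single image, the batch dimension is a trailing singleton.
img = rand(224, 224, 3, "single");

% Label the dimensions as spatial, spatial, channel, batch
X = dlarray(img, "SSCB");

% Run inference and extract the numeric scores
Y = predict(net, X);
scores = extractdata(Y);

% For comparison: 16 vector sequences with 12 features and 50 time steps
Xseq = dlarray(rand(12, 16, 50, "single"), "CBT");
```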
Input Data Preprocessing
Preprocessing data is a common first step in the deep learning workflow to prepare data in a format that the network can accept. You must preprocess the input data in the same way as the training data.
The input layer of the pretrained deep learning networks available in Deep Learning Toolbox performs some of the input data preprocessing. For example, the input layer of the pretrained mobilenetv2 network normalizes the image input data. Display the Normalization property of the network input layer.
net = mobilenetv2;
net.Layers(1).Normalization
ans = 'zscore'
Networks that you import from TensorFlow or ONNX might not have built-in preprocessing in the input layer. For example, the input layer of the MobileNetV2 network imported from TensorFlow does not normalize the input image. Import MobileNetV2, saved in the SavedModel format in the folder MobileNetV2, and display the Normalization property of the network input layer.
net = importTensorFlowNetwork("MobileNetV2", ...
    OutputLayerType="classification");
net.Layers(1).Normalization
ans = 'none'
Often, open-source repositories provide information about the required input data preprocessing. For example, see tf.keras.applications.mobilenet_v2.preprocess_input and ShuffleNet in ONNX Model Zoo. To learn more about how to preprocess images and other types of data in Deep Learning Toolbox, see Preprocess Images for Deep Learning and Preprocess Data for Deep Neural Networks.
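For the imported MobileNetV2, a sketch of doing the preprocessing yourself follows. It assumes the same MobileNetV2 SavedModel folder as above and a 224-by-224 input size, and it applies the scaling from tf.keras.applications.mobilenet_v2.preprocess_input, which maps pixel values to the range [-1, 1]. The imresize function requires Image Processing Toolbox.

```matlab
% Read an example image that ships with MATLAB and resize it to the
% assumed 224-by-224 input size
I = imread("peppers.png");
I = imresize(I, [224 224]);

% Scale pixel values from [0, 255] to [-1, 1]
X = single(I) / 127.5 - 1;

% Classify with the DAGNetwork imported from TensorFlow
net = importTensorFlowNetwork("MobileNetV2", OutputLayerType="classification");
label = classify(net, X);
```

Because the imported input layer reports a Normalization of 'none', this scaling has to happen before the data reaches the network.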
See Also
importTensorFlowNetwork | importNetworkFromPyTorch | importONNXNetwork | importTensorFlowLayers | importONNXLayers | dlarray
Related Topics
- Interoperability Between Deep Learning Toolbox, TensorFlow, PyTorch, and ONNX
- Pretrained Deep Neural Networks
- Select Function to Import ONNX Pretrained Network