Import PyTorch LSTM Model into Matlab

Felix on 14 May 2025
Answered: Gayathri on 15 May 2025
Hey Guys,
I am currently trying to use my PyTorch LSTM (trained with PyTorch Lightning) in MATLAB, but I have no idea how to use the importNetworkFromPyTorch function with an LSTM. The structure of the model is the following:
LSTM -> Linear -> Sigmoid
The LSTM properties (https://docs.pytorch.org/docs/stable/generated/torch.nn.LSTM.html) are (num_inputs=3, nhid=5, nlayers=5), i.e. input_size, hidden_size and num_layers, which makes the Linear layer (in=5, out=1).
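For reference, these settings fix the shape of every weight the importer has to translate: each of the five stacked layers packs the input, forget, cell and output gates into 4 * hidden_size = 20 rows. A minimal sketch, assuming num_inputs, nhid and nlayers map onto torch.nn.LSTM's input_size, hidden_size and num_layers:

```python
import torch.nn as nn

# Assumed mapping: num_inputs -> input_size, nhid -> hidden_size, nlayers -> num_layers
lstm = nn.LSTM(input_size=3, hidden_size=5, num_layers=5)
head = nn.Linear(in_features=5, out_features=1)

# Gate weights are stacked as (input, forget, cell, output), hence 4 * hidden_size rows
shapes = {name: tuple(p.shape) for name, p in lstm.named_parameters()}
for name, shape in shapes.items():
    print(name, shape)

n_params = sum(p.numel() for p in lstm.parameters()) + sum(p.numel() for p in head.parameters())
print("total parameters:", n_params)  # total parameters: 1166
```

Only the first layer sees the 3 input features (weight_ih_l0 is 20 x 3); layers 1 to 4 take the 5-dimensional hidden state of the layer below (weight_ih_lk is 20 x 5).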
The training data has the shape [BS, 600, 3], where BS is the batch size, 600 is the length of the time series and 3 is the number of input features at each time step. The shape of the hidden state is [5, BS, 5].
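These shapes can be checked against a minimal PyTorch sketch of the model. The class name LSTMSigmoid is hypothetical, and two details are assumptions: batch_first=True (inferred from the [BS, 600, 3] layout) and the Linear layer being applied at every time step rather than only the last one. Note that batch_first does not affect h_0 and c_0, which stay [num_layers, BS, hidden_size], consistent with the [5, BS, 5] hidden state.

```python
import torch
import torch.nn as nn


class LSTMSigmoid(nn.Module):
    """Hypothetical reconstruction of the LSTM -> Linear -> Sigmoid model."""

    def __init__(self, num_inputs=3, nhid=5, nlayers=5):
        super().__init__()
        # Assumption: batch_first=True, so x is [BS, seq_len, num_inputs]
        self.lstm = nn.LSTM(num_inputs, nhid, nlayers, batch_first=True)
        self.linear = nn.Linear(nhid, 1)

    def forward(self, x, hidden):
        # hidden is the tuple (h_0, c_0), each [nlayers, BS, nhid] (never batch-first)
        out, (h_n, c_n) = self.lstm(x, hidden)
        # Assumption: the Linear layer is applied at every time step
        return torch.sigmoid(self.linear(out)), (h_n, c_n)


bs, seq_len = 8, 600
model = LSTMSigmoid()
x = torch.randn(bs, seq_len, 3)
h0 = torch.zeros(5, bs, 5)
c0 = torch.zeros(5, bs, 5)
y, (h_n, c_n) = model(x, (h0, c0))
print(y.shape, h_n.shape, c_n.shape)
```

If the original model only uses the last time step, replacing `out` with `out[:, -1, :]` in forward gives an output of shape [BS, 1] instead of [BS, 600, 1].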
So my problem is that I do not understand which input sizes I have to pass to the importNetworkFromPyTorch function.
I expect it to be something like this:
net = importNetworkFromPyTorch("example/path/model.pt",PyTorchInputSizes={[NaN,3], [2, 5, NaN, 5]})
I exported the traced model with:
traced_model = torch.jit.trace(model.model.forward, (input, hidden_input))
torch.jit.save(traced_model, "model.pt")
The shape of input is [3], and hidden_input is the tuple ([5, 1, 5], [5, 1, 5]) (one tensor for the hidden state and one for the cell, or context, state).
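Before importing into MATLAB, it can help to confirm that the saved model.pt reloads and reproduces the eager model. A minimal sketch of tracing, saving and reloading, under these assumptions: LSTMSigmoid is a hypothetical stand-in for model.model, the module itself is traced rather than its bound forward method, and the example input is one full sequence of shape [1, 600, 3] (batch size 1, matching hidden_input) rather than a single [3] step.

```python
import os
import tempfile

import torch
import torch.nn as nn


class LSTMSigmoid(nn.Module):
    # Hypothetical stand-in for model.model: LSTM -> Linear -> Sigmoid
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(3, 5, 5, batch_first=True)
        self.linear = nn.Linear(5, 1)

    def forward(self, x, hidden):
        out, new_hidden = self.lstm(x, hidden)
        return torch.sigmoid(self.linear(out)), new_hidden


model = LSTMSigmoid().eval()

# Assumed example inputs: one 600-step sequence plus (h_0, c_0), each [5, 1, 5]
example_x = torch.randn(1, 600, 3)
example_hidden = (torch.zeros(5, 1, 5), torch.zeros(5, 1, 5))

with torch.no_grad():
    traced = torch.jit.trace(model, (example_x, example_hidden))

# Save to a temporary directory instead of a fixed path
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.jit.save(traced, path)

# Reload and compare against the eager model on the same inputs
reloaded = torch.jit.load(path)
with torch.no_grad():
    y_eager, _ = model(example_x, example_hidden)
    y_loaded, _ = reloaded(example_x, example_hidden)
print(torch.allclose(y_eager, y_loaded))
```

Tracing records the operations for the example shapes, so the sequence length and batch size used here are the ones the saved graph was built from.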
Can you please tell me how to use the importNetworkFromPyTorch function for this model?

Answers (1)

Gayathri on 15 May 2025
Hi @Felix,
Can you please confirm which MATLAB release you are using? And are you getting any errors when running the "importNetworkFromPyTorch" command in MATLAB?
According to the MATLAB documentation, importing LSTM layers is supported only from MATLAB R2025a onward. Please upgrade to MATLAB R2025a to import the LSTM model.
Hope this helps!
