Initializing LSTM which is imported using ONNX

Andreas on 17 Jul 2024
Answered: Andreas on 23 Jul 2024
Hi,
I am training an LSTM for RL using Ray in Python. I would like to export this model using ONNX and then import it into MATLAB. As far as I understand, I need to initialize the model in MATLAB after importing it. However, I cannot work out the correct input shapes and formats in MATLAB to make this work.
Minimum working example:
Python code to train LSTM:
import numpy as np
import torch
from ray.rllib.algorithms.ppo import PPOConfig

# Configure the algorithm
algo = (
    PPOConfig()
    .env_runners(num_env_runners=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .training(model={"use_lstm": True})
    .build()
)

# Train for 2 iterations
for i in range(2):
    result = algo.train()

# Get the policy
ppo_policy = algo.get_policy()

# Batch size
B = 1

# Build example LSTM inputs
input_dict = {"obs": torch.tensor(np.random.uniform(0, 1.0, size=(B, 4)).astype(np.float32))}
state_batches = [torch.zeros((B, 256), dtype=torch.float32), torch.zeros((B, 256), dtype=torch.float32)]
seq_lens = torch.ones([B], dtype=int)

# Apply the LSTM to the inputs
model = ppo_policy.model
print(model(input_dict, state=state_batches, seq_lens=seq_lens))

# Save the model to ONNX
ppo_policy.export_model('onnx14', onnx=14)
Code in MATLAB:
% Import model from where I saved it
net = importNetworkFromONNX('path/to/onnx-model');
% input shapes
obs_size = [1,4];
state_size=[2,1,256];
seq_lens_size=[1];
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
Error message:
I appreciate any help!
Best,
Andreas
  2 comments
Nilesh on 17 Jul 2024
Edited: Nilesh on 17 Jul 2024
Hello Andreas,
Have you tried asking ChatGPT about your issue?
Andreas on 17 Jul 2024
Yes, but without success so far.


Answers (3)

Joss Knight on 18 Jul 2024
This code is suspect:
% initialize input arrays
obs = dlarray(rand(obs_size),"BS");
state = dlarray(rand(state_size),"SBS");
seq_len = dlarray(rand(seq_lens_size),"SB");
% initialize net
net = initialize(net,obs,state,seq_len);
I think your network has a single input, so you need to pass a single input to initialize (along with the network): some example input exactly like the one you want to pass to predict. I think you have two channels and a sequence length of 256? One of your dimensions is time, so you need a T dimension. And I don't think you have any spatial dimensions, so no S labels. So you need something like:
exampleInput = dlarray(rand(2,1,256),'CBT');
net = initialize(net, exampleInput);
Or, if you prefer, a permutation of that, like:
exampleInput = dlarray(rand(256,2,1),'TCB');
net = initialize(net, exampleInput);
If this doesn't work, try running analyzeNetwork(net) to see where your inputs are and we can work out what to expect.
  1 comment
Andreas on 23 Jul 2024
Hi,
The network does not have a single input. I managed to solve the issue; see my response below. Thank you for your help anyway!



Kaustab Pal on 19 Jul 2024
It seems you want to determine the input dimension of your imported network. You can easily find this information using the analyzeNetwork function. This function provides an interactive visualization of the network architecture and detailed information, including:
  • Layer types
  • Sizes and formats of layer learnable parameters
  • States and activations
  • Total number of learnable parameters
The activation size of the topmost layer will give you the input dimension.
Additionally, when creating dlarray objects in MATLAB, you need to specify the format, which must follow this order:
  • "S" (Spatial)
  • "C" (Channel)
  • "B" (Batch)
  • "T" (Time)
  • "U" (Unspecified)
For more details, you can refer to the following links:
  1. analyzeNetwork Documentation: https://www.mathworks.com/help/deeplearning/ref/analyzenetwork.html#mw_bdd24886-fa03-4540-a111-391541a0a684
  2. dlarray Documentation: https://www.mathworks.com/help/deeplearning/ref/dlarray.html#d126e57736:~:text=When%20you%20create%20a%20formatted%20dlarray%20object%2C%20the%20software%20automatically%20permutes%20the%20dimensions%20such%20that%20the%20format%20has%20dimensions%20in%20this%20order%3A
Hope this helps.
  1 comment
Joss Knight on 19 Jul 2024
Just FYI, the formats do not have to follow that order.



Andreas on 23 Jul 2024
Hello everyone,
thank you for your help. I ended up working around the issue rather than fixing it directly, but it is solved now. I believe the reason MATLAB struggles is that models in Ray's RLlib carry a lot of complicated overhead. In particular, the inputs to the network are lists, dicts, and so on, which undergo quite a bit of reformatting, and that seemed to cause the problems. In the end, I extracted the relevant torch models from the trained RLlib object and joined them in a new torch.nn.Module object. For this object, torch.onnx.export worked just fine.
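A minimal sketch of what such a wrapper could look like is given below. It is not the exact code used here: the sub-modules are freshly created stand-ins (a linear encoder, an `nn.LSTM`, and a linear action head), the class and function names are hypothetical, and the layer sizes are assumptions based on CartPole (4 observations, 2 actions) and the 256-unit state from the question. In practice, the trained sub-modules would be pulled out of `policy.model` instead, and their attribute names depend on the RLlib version.

```python
import torch
from torch import nn

# Assumed sizes: CartPole observation, LSTM state size, number of actions.
OBS_DIM, HIDDEN, NUM_ACTIONS = 4, 256, 2


class ExportablePolicy(nn.Module):
    """Plain-tensor wrapper: (obs, h, c) -> (logits, h_next, c_next)."""

    def __init__(self, encoder: nn.Module, lstm: nn.LSTM, head: nn.Module):
        super().__init__()
        self.encoder = encoder
        self.lstm = lstm
        self.head = head

    def forward(self, obs, h, c):
        # obs: [B, T, OBS_DIM]; h and c: [1, B, HIDDEN]
        features = self.encoder(obs)
        out, (h_next, c_next) = self.lstm(features, (h, c))
        return self.head(out), h_next, c_next


def build_standin_policy() -> ExportablePolicy:
    # Stand-ins for the trained modules extracted from the RLlib model.
    encoder = nn.Sequential(nn.Linear(OBS_DIM, HIDDEN), nn.Tanh())
    lstm = nn.LSTM(HIDDEN, HIDDEN, batch_first=True)
    head = nn.Linear(HIDDEN, NUM_ACTIONS)
    return ExportablePolicy(encoder, lstm, head)


def export_policy(model: nn.Module, path: str) -> None:
    # Every input is a plain tensor, so the exporter sees no lists or dicts.
    example = (
        torch.zeros(1, 1, OBS_DIM),
        torch.zeros(1, 1, HIDDEN),
        torch.zeros(1, 1, HIDDEN),
    )
    torch.onnx.export(
        model,
        example,
        path,
        input_names=["obs", "h_in", "c_in"],
        output_names=["logits", "h_out", "c_out"],
        opset_version=14,
    )


# Forward pass with batch size 1 and sequence length 1 as a sanity check.
model = build_standin_policy().eval()
with torch.no_grad():
    logits, h_next, c_next = model(
        torch.zeros(1, 1, OBS_DIM),
        torch.zeros(1, 1, HIDDEN),
        torch.zeros(1, 1, HIDDEN),
    )
print(logits.shape, h_next.shape, c_next.shape)
```

Calling `export_policy(model, "policy.onnx")` (the file name is only an example) would then write a graph with three named tensor inputs. Exposing the hidden and cell states as explicit inputs and outputs keeps Python containers out of the exported graph, which is the property this workaround relies on.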
Thank you all for your help.
Best, Andreas

Release: R2024a
