How do I run ONNX model on Simulink?
34 ビュー (過去 30 日間)
古いコメントを表示
I created a simple onnx model by Pytorch as below code.
import torch
import torch.nn as nn
class EmptyModel(nn.Module):
def __init__(self):
super(EmptyModel, self).__init__()
# No trainable parameters, but add a linear layer to match Simulink requirements
self.linear = nn.Linear(7, 2, bias=False)
with torch.no_grad():
self.linear.weight.fill_(0.0)
def forward(self, x):
# Returns the first two elements of the input as is, without any computation
return x[:, :2]
model = EmptyModel()
dummy_input = torch.randn(1, 7, dtype=torch.float32)
torch.onnx.export(
model,
dummy_input,
"empty_model.onnx",
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size"},
"output": {0: "batch_size"},
},
)
It is an ONNX model that performs model inference on 7 input data and returns 2 output data that are the results of the inference.
I would like to incorporate this ONNX model in Simulink and run the simulation.
Since I value performance aspects such as execution speed, I also used ONNX Model Predict( Deep Learing toolbox), but the performance aspect was not so good, so I am looking for a better way.
Do you have any ideas?
I know that there are ways such as using S-Function, but I would like to know if there are any other ways that can be used. If possible, I would appreciate it if you could also tell me how to implement it.
Actually, I just tried below on my own.
I build the below C++ code (File name: smex.cpp) with mex command into MEX file.
Then I tried to call that MEX file in S-Function like this, however MATLAB crashes.

This MEX file (smex.mexw64) is placed in the same folder as this .slx file, but It doesn't work.
I'm using Windows.
Also, I downloaded (onnxruntime-win-x64-1.20.1.zip) from this link https://github.com/microsoft/onnxruntime/releases
then set up ONNX Runtime C++ API environment with Visual Studio 2022.
I couldn't figure out why this error happened about this matter....
#define S_FUNCTION_NAME smex
#define S_FUNCTION_LEVEL 2
#include "simstruc.h"
#include <onnxruntime_cxx_api.h>
#include <vector>
#include <memory>
static std::unique_ptr<Ort::Env> g_env;
static std::unique_ptr<Ort::Session> g_session;
static std::vector<const char*> g_input_node_names;
static std::vector<const char*> g_output_node_names;
#ifdef __cplusplus
extern "C" {
#endif
static void mdlInitializeSizes(SimStruct *S)
{
if (!ssSetNumInputPorts(S, 1)) return;
ssSetInputPortWidth(S, 0, 7);
ssSetInputPortDataType(S, 0, SS_DOUBLE);
ssSetInputPortDirectFeedThrough(S, 0, 1);
ssSetInputPortRequiredContiguous(S, 0, 1);
if (!ssSetNumOutputPorts(S, 1)) return;
ssSetOutputPortWidth(S, 0, 2);
ssSetOutputPortDataType(S, 0, SS_DOUBLE);
ssSetNumSampleTimes(S, 1);
ssSetOptions(S, SS_OPTION_EXCEPTION_FREE_CODE);
}
static void mdlInitializeSampleTimes(SimStruct *S)
{
ssSetSampleTime(S, 0, INHERITED_SAMPLE_TIME);
ssSetOffsetTime(S, 0, 0.0);
}
#define MDL_START
static void mdlStart(SimStruct *S)
{
try {
g_env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
const wchar_t* model_path = L"empty_model.onnx";
g_session = std::make_unique<Ort::Session>(*g_env, model_path, session_options);
Ort::AllocatorWithDefaultOptions allocator;
size_t num_input_nodes = g_session->GetInputCount();
g_input_node_names.resize(num_input_nodes);
for (size_t i = 0; i < num_input_nodes; i++) {
auto input_name = g_session->GetInputNameAllocated(i, allocator);
g_input_node_names[i] = input_name.get();
}
size_t num_output_nodes = g_session->GetOutputCount();
g_output_node_names.resize(num_output_nodes);
for (size_t i = 0; i < num_output_nodes; i++) {
auto output_name = g_session->GetOutputNameAllocated(i, allocator);
g_output_node_names[i] = output_name.get();
}
}
catch (const Ort::Exception& ex) {
mexErrMsgIdAndTxt("myOnnxSfunc:InitError", " %s", ex.what());
}
}
static void mdlOutputs(SimStruct *S, int_T tid)
{
try {
const real_T *u = ssGetInputPortRealSignal(S, 0);
std::vector<float> input_data(7);
for (int i = 0; i < 7; i++) {
input_data[i] = static_cast<float>(u[i]);
}
std::vector<int64_t> input_shape = { 1, 7 };
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
memory_info,
input_data.data(),
input_data.size(),
input_shape.data(),
input_shape.size()
);
auto output_tensors = g_session->Run(
Ort::RunOptions{ nullptr },
g_input_node_names.data(),
&input_tensor,
1,
g_output_node_names.data(),
g_output_node_names.size()
);
float* output_data = output_tensors[0].GetTensorMutableData<float>();
auto type_info = output_tensors[0].GetTensorTypeAndShapeInfo();
auto output_shape = type_info.GetShape();
if (output_shape.size() != 2 || output_shape[0] != 1 || output_shape[1] < 2) {
mexErrMsgIdAndTxt("myOnnxSfunc:OutputError",
"(%lld, %lld)",
output_shape[0], output_shape[1]);
}
real_T *y = ssGetOutputPortRealSignal(S, 0);
y[0] = static_cast<double>(output_data[0]);
y[1] = static_cast<double>(output_data[1]);
}
catch (const Ort::Exception& ex) {
mexErrMsgIdAndTxt("myOnnxSfunc:RuntimeError", "%s", ex.what());
}
}
static void mdlTerminate(SimStruct *S)
{
g_session.reset();
g_env.reset();
g_input_node_names.clear();
g_output_node_names.clear();
}
#ifdef __cplusplus
}
#endif
#ifdef MATLAB_MEX_FILE
#include "simulink.c"
#else
#include "cg_sfun.h"
#endif
Best,
0 件のコメント
回答 (0 件)
参考
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!