SourceXtractorPlusPlus  0.16
Please provide a description of the project.
OnnxModel.cpp
Go to the documentation of this file.
1 /*
2  * OnnxModel.cpp
3  *
4  * Created on: Feb 16, 2021
5  * Author: mschefer
6  */
7 
11 
14 
15 namespace SourceXtractor {
16 
17 namespace {
/**
 * Render a tensor shape as a human-readable string such as "1 x 3 x 224".
 *
 * @param shape  Dimension sizes, in order.
 * @return The dimensions joined by " x ". A shape without dimensions yields an
 *         empty string instead of reading past the bounds of the vector.
 */
static std::string formatShape(const std::vector<int64_t>& shape) {
  std::ostringstream stream;
  // The separator is empty before the first dimension, so no trailing or leading " x " appears
  const char* separator = "";
  for (const auto dim : shape) {
    stream << separator << dim;
    separator = " x ";
  }
  return stream.str();
}
29 
30 }
31 
32 OnnxModel::OnnxModel(const std::string& model_path) {
33  m_model_path = model_path;
34 
36  auto allocator = Ort::AllocatorWithDefaultOptions();
37 
38  onnx_logger.info() << "Loading ONNX model " << model_path;
39  m_session = Euclid::make_unique<Ort::Session>(ORT_ENV, model_path.c_str(), Ort::SessionOptions{nullptr});
40 
41  if (m_session->GetOutputCount() != 1) {
42  throw Elements::Exception() << "Only ONNX models with a single output tensor are supported";
43  }
44 
45  for (int i=0; i<m_session->GetInputCount(); i++) {
46  auto input_type = m_session->GetInputTypeInfo(i);
47 
48  m_input_names.emplace_back(m_session->GetInputName(i, allocator));
49  m_input_shapes.emplace_back(input_type.GetTensorTypeAndShapeInfo().GetShape());
50  m_input_types.emplace_back(input_type.GetTensorTypeAndShapeInfo().GetElementType());
51  }
52 
53  m_output_name = m_session->GetOutputName(0, allocator);
54  m_domain_name = m_session->GetModelMetadata().GetDomain(allocator);
55  m_graph_name = m_session->GetModelMetadata().GetGraphName(allocator);
56 
57  auto output_type = m_session->GetOutputTypeInfo(0);
58 
59  m_output_shape = output_type.GetTensorTypeAndShapeInfo().GetShape();
60  m_output_type = output_type.GetTensorTypeAndShapeInfo().GetElementType();
61 
62 // onnx_logger.info() << "ONNX model with input of " << formatShape(m_input_shapes[0]);
63 // onnx_logger.info() << "ONNX model with output of " << formatShape(m_output_shape);
64 }
65 
66 }
T back(T... args)
T begin(T... args)
T c_str(T... args)
static Logging getLogger(const std::string &name="")
void info(const std::string &logMessage)
std::vector< ONNXTensorElementDataType > m_input_types
Input tensor element types, one per input.
Definition: OnnxModel.h:141
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:146
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:140
ONNXTensorElementDataType m_output_type
Output type.
Definition: OnnxModel.h:142
std::vector< std::string > m_input_names
Input tensor names, one per input.
Definition: OnnxModel.h:139
OnnxModel(const std::string &model_path)
Definition: OnnxModel.cpp:32
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:144
std::string m_graph_name
Graph name.
Definition: OnnxModel.h:138
std::string m_domain_name
Domain name.
Definition: OnnxModel.h:137
std::string m_model_path
Path to the ONNX model.
Definition: OnnxModel.h:145
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shapes, one per input.
Definition: OnnxModel.h:143
T emplace_back(T... args)
T end(T... args)
static std::string formatShape(const std::vector< int64_t > &shape)
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition: OnnxPlugin.cpp:26
Ort::Env ORT_ENV
Definition: OnnxCommon.cpp:25
T str(T... args)