SourceXtractorPlusPlus  0.16
SourceXtractor++: software for extracting a catalog of sources from astronomical images.
OnnxModel.h
Go to the documentation of this file.
1 /*
2  * OnnxModel.h
3  *
4  * Created on: Feb 16, 2021
5  * Author: mschefer
6  */
7 
8 #ifndef _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
9 #define _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_
10 
#include <cstdint>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>
18 
19 namespace SourceXtractor {
20 
21 class OnnxModel {
22 public:
23 
24  explicit OnnxModel(const std::string& model_path);
25 
26  template<typename T, typename U>
27  void run(std::vector<T>& input_data, std::vector<U>& output_data) const {
28  Ort::RunOptions run_options;
29  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
30 
31  // Allocate memory
33  input_shape[0] = 1;
34  size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
35 
37  output_shape[0] = 1;
38  size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
39 
40  // FIXME check input and output size are OK
41 
42  // Setup input/output tensors
43  auto input_tensor = Ort::Value::CreateTensor<T>(
44  mem_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size());
45  auto output_tensor = Ort::Value::CreateTensor<U>(
46  mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
47 
48  // Run the model
49  const char *input_name = m_input_names[0].c_str();
50  const char *output_name = m_output_name.c_str();
51 
52  m_session->Run(run_options, &input_name, &input_tensor, 1, &output_name, &output_tensor, 1);
53  }
54 
55  template<typename T, typename U>
56  void runMultiInput(std::map<std::string, std::vector<T>>& input_data, std::vector<U>& output_data) const {
57  Ort::RunOptions run_options;
58  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
59 
60  std::vector<const char *> input_names;
61  std::vector<Ort::Value> input_tensors;
62 
63  int inputs_nb = m_input_names.size();
64  for (int i=0; i<inputs_nb; i++) {
65  input_names.emplace_back(m_input_names[i].c_str());
66 
67  // Allocate memory
69  input_shape[0] = 1;
70  size_t input_size = std::accumulate(input_shape.begin(), input_shape.end(), 1u, std::multiplies<size_t>());
71 
72  input_tensors.emplace_back(Ort::Value::CreateTensor<T>(
73  mem_info, input_data[m_input_names[i]].data(), input_data[m_input_names[i]].size(),
74  input_shape.data(), input_shape.size()));
75  }
76 
77  // Output name and shape
78  const char *output_name = m_output_name.c_str();
80  output_shape[0] = 1;
81 
82  // Setup output tensor
83  size_t output_size = std::accumulate(output_shape.begin(), output_shape.end(), 1u, std::multiplies<size_t>());
84  auto output_tensor = Ort::Value::CreateTensor<U>(
85  mem_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
86 
87  // Run the model
88  m_session->Run(run_options, &input_names[0], &input_tensors[0], inputs_nb, &output_name, &output_tensor, 1);
89  }
90 
91 
92  ONNXTensorElementDataType getInputType() const {
93  return m_input_types[0];
94  }
95 
96  ONNXTensorElementDataType getOutputType() const {
97  return m_output_type;
98  }
99 
101  return m_input_shapes[0];
102  }
103 
105  return m_output_shape;
106  }
107 
109  return m_domain_name;
110  }
111 
113  return m_graph_name;
114  }
115 
117  return m_input_names[0];
118  }
119 
121  return m_output_name;
122  }
123 
125  return m_model_path;
126  }
127 
128  size_t getInputNb() const {
129  return m_input_names.size();
130  }
131 
132  size_t getOutputNb() const {
133  return 1U;
134  }
135 
136 private:
142  ONNXTensorElementDataType m_output_type;
147 };
148 
149 }
150 
151 
152 #endif /* _SEIMPLEMENTATION_COMMON_ONNXMODEL_H_ */
T accumulate(T... args)
T begin(T... args)
T c_str(T... args)
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:27
ONNXTensorElementDataType getInputType() const
Definition: OnnxModel.h:92
ONNXTensorElementDataType getOutputType() const
Definition: OnnxModel.h:96
std::vector< ONNXTensorElementDataType > m_input_types
Input type.
Definition: OnnxModel.h:141
std::unique_ptr< Ort::Session > m_session
Session, one per model. In theory, it is thread-safe.
Definition: OnnxModel.h:146
std::string getGraphName() const
Definition: OnnxModel.h:112
std::string getDomain() const
Definition: OnnxModel.h:108
std::string m_output_name
Output tensor name.
Definition: OnnxModel.h:140
size_t getOutputNb() const
Definition: OnnxModel.h:132
const std::vector< std::int64_t > & getInputShape() const
Definition: OnnxModel.h:100
std::string getOutputName() const
Definition: OnnxModel.h:120
ONNXTensorElementDataType m_output_type
Output type.
Definition: OnnxModel.h:142
std::vector< std::string > m_input_names
Input tensor name.
Definition: OnnxModel.h:139
void runMultiInput(std::map< std::string, std::vector< T >> &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:56
OnnxModel(const std::string &model_path)
Definition: OnnxModel.cpp:32
std::string getInputName() const
Definition: OnnxModel.h:116
std::vector< std::int64_t > m_output_shape
Output tensor shape.
Definition: OnnxModel.h:144
std::string m_graph_name
graph name
Definition: OnnxModel.h:138
std::string m_domain_name
domain name
Definition: OnnxModel.h:137
const std::vector< std::int64_t > & getOutputShape() const
Definition: OnnxModel.h:104
std::string m_model_path
Path to the ONNX model.
Definition: OnnxModel.h:145
size_t getInputNb() const
Definition: OnnxModel.h:128
std::vector< std::vector< std::int64_t > > m_input_shapes
Input tensor shape.
Definition: OnnxModel.h:143
std::string getModelPath() const
Definition: OnnxModel.h:124
T data(T... args)
T emplace_back(T... args)
T end(T... args)
T size(T... args)