SourceXtractorPlusPlus  0.16
Please provide a description of the project.
OnnxTaskFactory.cpp
Go to the documentation of this file.
1 
18 #include <onnxruntime_cxx_api.h>
19 
21 #include <NdArray/NdArray.h>
22 
24 
29 
31 
32 namespace SourceXtractor {
33 
38  std::stringstream prop_name;
39 
40  std::string domain = model.getDomain();
41  if (!domain.empty()) {
42  prop_name << domain << '.';
43  }
44 
45  std::string graph_name = model.getGraphName();
46  if (!graph_name.empty()) {
47  prop_name << graph_name << '.';
48  }
49 
50  prop_name << model.getOutputName();
51 
52  return prop_name.str();
53 }
54 
/**
 * Renders a tensor shape as a human readable string, e.g. {1, 3, 64, 64} -> "1 x 3 x 64 x 64".
 *
 * Each dimension value is streamed verbatim (negative values included), separated by " x ".
 *
 * @param shape Dimension sizes of a tensor.
 * @return The formatted shape, or an empty string when @p shape has no dimensions.
 */
static std::string formatShape(const std::vector<int64_t>& shape) {
  // Guard: for an empty vector, back() is undefined and `end() - 1` would point
  // before begin(), so the original loop/back() pair had undefined behaviour here.
  if (shape.empty()) {
    return {};
  }

  std::ostringstream stream;
  stream << shape.front();
  for (std::size_t i = 1; i < shape.size(); ++i) {
    stream << " x " << shape[i];
  }
  return stream.str();
}
66 
68 
70  if (property_id == PropertyId::create<OnnxProperty>()) {
71  return std::make_shared<OnnxSourceTask>(m_model_infos);
72  }
73  return nullptr;
74 }
75 
78 }
79 
81  const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
82  const auto& models = onnx_config.getModels();
83 
84  for (auto model_path : models) {
85  auto model = std::make_shared<OnnxModel>(model_path);
86 
87  if (model->getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
88  throw Elements::Exception() << "Only ONNX models with float input are supported";
89  }
90 
91  if (model->getInputShape().size() != 4) {
92  throw Elements::Exception() << "Expected 4 axes for the input layer, got " << model->getInputShape().size();
93  }
94 
95  auto prop_name = generatePropertyName(*model);
96  onnx_logger.info() << "Output name will be " << prop_name;
97 
98  m_model_infos.emplace_back(OnnxSourceTask::OnnxModelInfo {model, prop_name});
99 
100  }
101 }
102 
// Registers one output column for a model whose output element type is T.
// The column's values are read from an OnnxProperty via getData<T>(key), where `key`
// is a copy of model_info.prop_name captured by value, so the converter lambda does
// not hold a reference to model_info.
// The trailing arguments are: column name = model_info.prop_name, unit = "",
// and (presumably, if the call on the missing line is
// OutputRegistry::registerColumnConverter(column_name, converter, column_unit,
// column_description)) description = the model file path.
// NOTE(review): original line 107, the call that receives these arguments, is absent
// from this listing — confirm against the real source.
103 template<typename T>
104 static void registerColumnConverter(OutputRegistry& registry, const OnnxSourceTask::OnnxModelInfo& model_info) {
105  auto key = model_info.prop_name;
106 
108  model_info.prop_name, [key](const OnnxProperty& prop) {
109  return prop.getData<T>(key);
110  }, "", model_info.model->getModelPath()
111  );
112 }
113 
115  for (const auto& model_info : m_model_infos) {
116  switch (model_info.model->getOutputType()) {
117  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
118  registerColumnConverter<float>(registry, model_info);
119  break;
120  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
121  registerColumnConverter<int32_t>(registry, model_info);
122  break;
123  default:
124  throw Elements::Exception() << "Unsupported output type: " << model_info.model->getOutputType();
125  }
126  }
127 }
128 
129 } // end of namespace SourceXtractor
T back(T... args)
T begin(T... args)
void info(const std::string &logMessage)
const std::vector< std::string > & getModels() const
Definition: OnnxConfig.h:44
std::string getGraphName() const
Definition: OnnxModel.h:112
std::string getDomain() const
Definition: OnnxModel.h:108
std::string getOutputName() const
Definition: OnnxModel.h:120
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
void registerPropertyInstances(OutputRegistry &registry) override
std::vector< OnnxSourceTask::OnnxModelInfo > m_model_infos
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
void registerColumnConverter(std::string column_name, ColumnConverter< PropertyType, OutType > converter, std::string column_unit="", std::string column_description="")
Identifier used to set and retrieve properties.
Definition: PropertyId.h:40
T empty(T... args)
T end(T... args)
static void registerColumnConverter(OutputRegistry &registry, const OnnxSourceTask::OnnxModelInfo &model_info)
static std::string generatePropertyName(const OnnxModel &model)
static std::string formatShape(const std::vector< int64_t > &shape)
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition: OnnxPlugin.cpp:26
T str(T... args)