SourceXtractorPlusPlus  0.16
Please provide a description of the project.
OnnxCompactModel.h
Go to the documentation of this file.
1 /*
2  * OnnxCompactModel.h
3  *
4  * Created on: Sep 3, 2021
5  * Author: mschefer
6  */
7 
8 #ifndef _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
9 #define _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
10 
11 #include <numeric>
12 
13 #include <ElementsKernel/Logging.h>
14 
16 
19 
20 namespace ModelFitting {
21 
// Logger used by the code in this header; obtained from Elements under the
// name "FlexibleModelFitting".
// NOTE(review): the variable is `static` at namespace scope inside a header, so
// every translation unit that includes this file gets its own `logger` object.
22 static auto logger = Elements::Logging::getLogger("FlexibleModelFitting");
23 
// Compact model whose rasterized image is produced by running an ONNX model.
//
// The class derives from CompactModelBase<ImageType>, keeps a list of candidate
// ONNX models (m_models), a flux parameter (m_flux) and a map of named extra
// parameters (m_params) that are passed to the ONNX model as scalar inputs.
//
// NOTE(review): this listing is incomplete. The embedded line numbers skip
// original lines 27-28, 30-33, 46, 51, 74, 84, 112-117, 119 and 122-123, so the
// head of the constructor, several local declarations and the data member
// declarations are not visible here.
24 template <typename ImageType>
25 class OnnxCompactModel : public CompactModelBase<ImageType> {
26 public:
    // Constructor (its first lines are missing from this listing).
    // Forwards x_scale, y_scale, rotation, width, height, x, y and transform to
    // CompactModelBase and stores models, flux and params in the members.
29  std::shared_ptr<BasicParameter> rotation, double width, double height,
34  : CompactModelBase<ImageType>(x_scale, y_scale, rotation, width, height, x, y, transform),
35  m_models(models), m_flux(flux), m_params(params)
36  {
37  }
38 
39  virtual ~OnnxCompactModel() = default;
40 
    // Point evaluation is not implemented for this model: always returns 0.0.
41  double getValue(double, double) const override {
42  return 0.0; // unused
43  }
44 
    // Renders the model into a size_x by size_y image.
    //
    // Steps visible in this listing:
    //   1. pick the first entry of m_models whose output shape[2] is strictly
    //      greater than max(size_x, size_y);
    //   2. build the inputs: one single-element float vector per m_params entry,
    //      plus "x" and "y" coordinate arrays of render_size * render_size floats;
    //   3. run the model and copy the top-left size_x by size_y region of the
    //      output into the image;
    //   4. renormalize the image with the current value of m_flux.
    //
    // Returns the image unchanged from Traits::factory (after logging a warning)
    // when no model is large enough.
    // Throws Elements::Exception when the selected model's output type is not
    // ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT.
45  ImageType getRasterizedImage(double pixel_scale, std::size_t size_x, std::size_t size_y) const override {
    // NOTE(review): `Traits` is declared on a line missing from this listing
    // (original line 46).
47  ImageType image = Traits::factory(size_x, size_y);
48 
    // NOTE(review): std::max returns std::size_t here; storing it in an int
    // narrows the value.
49  int largest_size = std::max(size_x, size_y);
50 
    // `selected_model` is declared on a line missing from this listing
    // (original line 51); it is compared against nullptr below.
    // The first model that satisfies the test wins, so this presumably relies
    // on m_models being ordered by increasing output size -- TODO confirm.
    // NOTE(review): `auto model` copies each element of m_models per iteration;
    // shape[2] is read without checking that the shape has at least 3 entries.
52  for (auto model : m_models) {
53  auto shape = model->getOutputShape();
54  if (largest_size < shape[2]) {
55  selected_model = model;
56  break;
57  }
58  }
59 
60  if (selected_model == nullptr) {
61  logger.warn() << "No large enough ONNX model could be found, skipping...";
62  return image;
63  }
64 
    // NOTE(review): input_shape is not used anywhere in the visible code.
65  auto input_shape = selected_model->getInputShape();
66  auto output_shape = selected_model->getOutputShape();
    // render_size is used below as the row stride of both the coordinate inputs
    // and the output, i.e. the output is assumed to be render_size by
    // render_size -- TODO confirm that output_shape[3] == output_shape[2].
67  int render_size = output_shape[2];
68 
69  if (selected_model->getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
70  throw Elements::Exception() << "Only ONNX models with float output are supported";
71  }
72 
73  // allocate memory
    // `input_data_arrays` is declared on a line missing from this listing
    // (original line 74); it is indexed by string keys and holds
    // std::vector<float> values.
75  std::vector<float> output_data(render_size * render_size);
76 
    // Each named parameter becomes a one-element float input keyed by its name.
77  for (auto const& it : m_params) {
78  input_data_arrays[it.first] = std::vector<float>( { static_cast<float>(it.second->getValue()) } );
79  }
80 
    // Coordinate inputs; the vectors are value-initialized, so every element
    // starts at 0 and entries outside the size_x by size_y region stay 0.
81  input_data_arrays["x"] = std::vector<float>(render_size * render_size);
82  input_data_arrays["y"] = std::vector<float>(render_size * render_size);
83 
    // `transform` is declared on a line missing from this listing (original
    // line 84); presumably the result of getCombinedTransform(pixel_scale) --
    // TODO confirm. It is read below as four coefficients of a 2x2 matrix.
85 
    // Fill the coordinate inputs: pixel offsets from the image centre
    // (size / 2) are multiplied by the 2x2 transform.
    // NOTE(review): size_x / size_y are std::size_t, so `y - size_y / 2` and
    // `x - size_x / 2` are evaluated in unsigned arithmetic and then converted
    // to int; this presumably yields the intended negative offsets on the
    // supported platforms -- confirm.
86  for (int y=0; y<(int)size_y; ++y) {
87  int dy = y - size_y / 2;
88  for (int x=0; x<(int)size_x; ++x) {
89  int dx = x - size_x / 2;
90 
91  float x2 = dx * transform[0] + dy * transform[1];
92  float y2 = dx * transform[2] + dy * transform[3];
93 
    // Row stride is render_size, not size_x.
94  input_data_arrays["x"][x + y * render_size] = x2;
95  input_data_arrays["y"][x + y * render_size] = y2;
96  }
97  }
98 
99  selected_model->runMultiInput<float, float>(input_data_arrays, output_data);
100 
    // Copy the top-left size_x by size_y region of the model output into the
    // image, using the same render_size row stride as the inputs.
101  for (int y = 0; y < (int) size_y; ++y) {
102  for (int x = 0; x < (int) size_x; ++x) {
103  Traits::at(image, x, y) = output_data[x + y * render_size];
104  }
105  }
106 
    // Rescale the image using the current flux value.
107  renormalize(image, m_flux->getValue());
108  return image;
109  }
110 
111 private:
    // NOTE(review): the declarations in this section (original lines 112-117,
    // 119, 122-123) are missing from this listing. The members initialized by
    // the constructor are m_models, m_flux and m_params.
118 
120 
121  // parameters
124 };
125 
126 }
127 
128 
129 
130 #endif /* _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_ */
std::shared_ptr< EngineParameter > dx
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
std::shared_ptr< EngineParameter > dy
const double pixel_scale
Definition: TestImage.cpp:74
static Logging getLogger(const std::string &name="")
void warn(const std::string &logMessage)
void renormalize(ImageType &image, double flux) const
Mat22 getCombinedTransform(double pixel_scale) const
ImageType getRasterizedImage(double pixel_scale, std::size_t size_x, std::size_t size_y) const override
std::map< std::string, std::shared_ptr< BasicParameter > > m_params
OnnxCompactModel(std::vector< std::shared_ptr< SourceXtractor::OnnxModel >> models, std::shared_ptr< BasicParameter > x_scale, std::shared_ptr< BasicParameter > y_scale, std::shared_ptr< BasicParameter > rotation, double width, double height, std::shared_ptr< BasicParameter > x, std::shared_ptr< BasicParameter > y, std::shared_ptr< BasicParameter > flux, std::map< std::string, std::shared_ptr< BasicParameter >> params, std::tuple< double, double, double, double > transform)
std::vector< std::shared_ptr< SourceXtractor::OnnxModel > > m_models
virtual ~OnnxCompactModel()=default
double getValue(double, double) const override
std::shared_ptr< BasicParameter > m_flux
ONNXTensorElementDataType getOutputType() const
Definition: OnnxModel.h:96
const std::vector< std::int64_t > & getInputShape() const
Definition: OnnxModel.h:100
void runMultiInput(std::map< std::string, std::vector< T >> &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:56
const std::vector< std::int64_t > & getOutputShape() const
Definition: OnnxModel.h:104
T max(T... args)
static Elements::Logging logger
std::pair< double, double > transform(int x, int y, const std::array< double, 4 > &t)
ModelFitting::ImageTraits< ImageInterfaceTypePtr > Traits