8 #ifndef _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
9 #define _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
24 template <
typename ImageType>
41 double getValue(
double,
double)
const override {
47 ImageType image = Traits::factory(size_x, size_y);
49 int largest_size =
std::max(size_x, size_y);
54 if (largest_size < shape[2]) {
55 selected_model = model;
60 if (selected_model ==
nullptr) {
61 logger.
warn() <<
"No large enough ONNX model could be found, skipping...";
67 int render_size = output_shape[2];
69 if (selected_model->
getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
78 input_data_arrays[it.first] =
std::vector<float>( {
static_cast<float>(it.second->getValue()) } );
86 for (
int y=0;
y<(int)size_y; ++
y) {
87 int dy =
y - size_y / 2;
88 for (
int x=0;
x<(int)size_x; ++
x) {
89 int dx =
x - size_x / 2;
94 input_data_arrays[
"x"][
x +
y * render_size] = x2;
95 input_data_arrays[
"y"][
x +
y * render_size] = y2;
99 selected_model->
runMultiInput<float,
float>(input_data_arrays, output_data);
101 for (
int y = 0;
y < (int) size_y; ++
y) {
102 for (
int x = 0;
x < (int) size_x; ++
x) {
103 Traits::at(image,
x,
y) = output_data[
x +
y * render_size];
std::shared_ptr< EngineParameter > dx
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
std::shared_ptr< EngineParameter > dy
static Logging getLogger(const std::string &name="")
void warn(const std::string &logMessage)
void renormalize(ImageType &image, double flux) const
Mat22 getCombinedTransform(double pixel_scale) const
ImageType getRasterizedImage(double pixel_scale, std::size_t size_x, std::size_t size_y) const override
std::map< std::string, std::shared_ptr< BasicParameter > > m_params
OnnxCompactModel(std::vector< std::shared_ptr< SourceXtractor::OnnxModel >> models, std::shared_ptr< BasicParameter > x_scale, std::shared_ptr< BasicParameter > y_scale, std::shared_ptr< BasicParameter > rotation, double width, double height, std::shared_ptr< BasicParameter > x, std::shared_ptr< BasicParameter > y, std::shared_ptr< BasicParameter > flux, std::map< std::string, std::shared_ptr< BasicParameter >> params, std::tuple< double, double, double, double > transform)
std::vector< std::shared_ptr< SourceXtractor::OnnxModel > > m_models
virtual ~OnnxCompactModel()=default
double getValue(double, double) const override
std::shared_ptr< BasicParameter > m_flux
static Elements::Logging logger