SourceXtractorPlusPlus  0.16
Please provide a description of the project.
MLSegmentation.cpp
Go to the documentation of this file.
1 
18 #include <memory>
19 #include <vector>
20 #include <list>
21 #include <iostream>
22 
23 #include <onnxruntime_cxx_api.h>
24 
26 #include "SEUtils/HilbertCurve.h"
27 
29 
32 
35 
37 
40 
42 
43 
46 
49 
51 
52 
53 
54 namespace SourceXtractor {
55 
56 namespace {
// Adapter between the Lutz labelling algorithm and a Segmentation::LabellingListener:
// every pixel group reported by Lutz is turned into a source and published, and
// progress notifications are forwarded to the wrapped listener.
57 class LutzLabellingListener : public Lutz::LutzListener {
58 public:
   // listener:       receiver of the published sources and progress updates. It is held
   //                 by reference, so it must outlive this object.
   // source_factory: used to create one source per published pixel group.
   // window_size:    when > 0, processing is requested for lines that are more than
   //                 window_size behind the current line (see notifyProgress).
59  LutzLabellingListener(Segmentation::LabellingListener& listener, std::shared_ptr<SourceFactory> source_factory,
60  int window_size) :
61  m_listener(listener),
62  m_source_factory(source_factory),
63  m_window_size(window_size) {}
64 
65  virtual ~LutzLabellingListener() = default;
66 
   // Creates a source from the group's pixel list, tags it with a SourceId
   // property and hands it to the wrapped listener.
67  void publishGroup(Lutz::PixelGroup& pixel_group) override {
68  auto source = m_source_factory->createSource();
69  source->setProperty<PixelCoordinateList>(pixel_group.pixel_list);
70  source->setProperty<SourceId>();
71  m_listener.publishSource(source);
72  }
73 
   // Forwards the progress to the wrapped listener. With a positive window size,
   // it additionally requests processing using a LineSelectionCriteria built from
   // (line - m_window_size) once the current line is past the window.
74  void notifyProgress(int line, int total) override {
75  m_listener.notifyProgress(line, total);
76 
77  if (m_window_size > 0 && line > m_window_size) {
78  m_listener.requestProcessing(
79  ProcessSourcesEvent(std::make_shared<LineSelectionCriteria>(line - m_window_size))
80  );
81  }
82  }
83 
84 private:
   // Members are declared in the same order as the constructor's initializer list.
85  Segmentation::LabellingListener& m_listener;
86  std::shared_ptr<SourceFactory> m_source_factory;
87  int m_window_size;
88 };
89 
90 
91 }
92 
95 
// Body of the ML segmentation labelling routine. It runs an ONNX model over the
// background-subtracted detection image in overlapping square tiles, stitches the
// per-plane model output into full-size temporary images (with the detection
// threshold subtracted), and then runs the Lutz labelling on each of those images.
// NOTE(review): the function signature is not part of this listing (listing lines
// 93-94 are missing); `listener` and `frame` are its parameters, and m_model_path,
// m_ml_threshold and m_source_factory are members of the enclosing class.
96  OnnxModel model(m_model_path);
97 
   // NOTE(review): input_shape is fetched but never used in the visible code.
98  auto input_shape = model.getInputShape();
99  auto output_shape = model.getOutputShape();
100 
101  // TODO add sanity check
   // NOTE(review): output_shape[1] and output_shape[3] are read below without checking
   // the rank of the shape, and the int64 values are narrowed to int.
102 
   // Index 1 of the output shape is used as the tile edge length and index 3 as the
   // number of output planes.
   // assumes the output is laid out as (batch, height, width, planes) with square tiles — TODO confirm
103  int tile_size = output_shape[1];
104  int data_planes = output_shape[3];
105  float average_rms = frame->getBackgroundMedianRms();
106  float detection_threshold = m_ml_threshold;
107 
108  onnx_logger.info() << "Onnx tile size: " << tile_size << " Data planes: " << data_planes << " RMS: " << average_rms;
109 
   // Only float-in / float-out models with exactly one input tensor are accepted;
   // anything else raises Elements::Exception.
110  if (model.getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
111  throw Elements::Exception() << "Only ONNX models with float input are supported";
112  }
113 
114  if (model.getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
115  throw Elements::Exception() << "Only ONNX models with float output are supported";
116  }
117 
118  if (model.getInputNb() != 1) {
119  throw Elements::Exception() << "Only ONNX models with a single input tensor are supported";
120  }
121 
122  // allocate memory
   // One input value per tile pixel; data_planes output values per tile pixel.
   // Both buffers are reused for every tile.
123  std::vector<float> input_data(tile_size * tile_size);
124  std::vector<float> output_data(tile_size * tile_size * data_planes);
125 
126  auto image = frame->getSubtractedImage();
127  ImageAccessor<SeFloat> image_acc(image);
128 
   // NOTE(review): the declarations of tmp_images and check_images are not part of this
   // listing (listing lines 129-130 are missing).
   // For every output plane: a full-size temporary image that receives the thresholded
   // model output, and the matching check image (tested against nullptr before use below).
131  for (int i=0; i < data_planes; i++) {
132  tmp_images.emplace_back(FitsWriter::newTemporaryImage<float>("_tmp_ml_seg%%%%%%.fits", image->getWidth(), image->getHeight()));
133  check_images.emplace_back(CheckImages::getInstance().getMLDetectionImage(i));
134  }
135 
   // A window size of 0 means the listener's notifyProgress never requests
   // line-based processing.
136  Lutz lutz;
137  LutzLabellingListener lutz_listener(listener, m_source_factory, 0);
138 
   // Tiles advance by half a tile in each direction, so consecutive tiles overlap by
   // 50%. A tile is processed as long as its first three quarters start inside the image.
139  for (int ox = 0; ox + tile_size * 3 / 4 < image->getWidth(); ox += tile_size / 2) {
140  for (int oy = 0; oy + tile_size * 3 / 4 < image->getHeight(); oy += tile_size / 2) {
141 
   // Fill the model input (row-major: x + y * tile_size) with pixel values divided by
   // the background RMS; tile pixels that fall outside the image are set to 0.
   // NOTE(review): there is no guard against average_rms being 0.
142  for (int x = 0; x < tile_size; x++) {
143  for (int y = 0; y < tile_size; y++) {
144  if (ox+x < image->getWidth() && oy+y < image->getHeight()) {
145  input_data[x+y*tile_size] = image_acc.getValue(ox+x, oy+y) / average_rms;
146  } else {
147  input_data[x+y*tile_size] = 0;
148  }
149  }
150  }
151 
152  model.run<float, float>(input_data, output_data);
153 
   // Only the central half of each tile is kept, [tile_size/4, tile_size*3/4), so the
   // kept regions of neighbouring tiles abut without overlapping. The first tile in a
   // direction also keeps its leading quarter.
   // assumes tile_size is a multiple of 4 so the kept regions line up exactly — TODO confirm
154  int start_x = (ox == 0) ? 0 : tile_size / 4;
155  int start_y = (oy == 0) ? 0 : tile_size / 4;
156 
   // (offset + tile_size * 5 / 4 < extent) is exactly the loop condition for the next
   // tile; when there is no next tile, this one keeps its trailing quarter as well.
157  int end_x = (ox + tile_size * 5 / 4 < image->getWidth()) ? tile_size * 3 / 4 : tile_size ;
158  int end_y = (oy + tile_size * 5 / 4 < image->getHeight()) ? tile_size * 3 / 4 : tile_size;
159 
   // Copy the kept region into the per-plane images. Planes are interleaved per pixel
   // in output_data: (x + y * tile_size) * data_planes + i.
160  for (int x = start_x; x < end_x; x++) {
161  for (int y = start_y; y < end_y; y++) {
162  if (ox+x < image->getWidth() && oy+y < image->getHeight()) {
163  for (int i=0; i<data_planes; i++) {
   // The temporary image stores the output minus the detection threshold
   // (presumably so the Lutz pass detects values above zero — verify against Lutz).
   // The check image stores the raw model output.
164  tmp_images[i]->setValue(ox + x, oy + y, output_data[(x+y*tile_size) * data_planes + i] - detection_threshold);
165  if (check_images[i] != nullptr) {
166  check_images[i]->setValue(ox+x, oy+y, output_data[(x+y*tile_size) * data_planes + i]);
167  }
168  }
169  }
170  }
171  }
172  }
173  }
   // Label each plane independently; groups found by Lutz are published as sources
   // through lutz_listener.
174  for (int i=0; i<data_planes; i++) {
175  lutz.labelImage(lutz_listener, *tmp_images[i]);
176  }
177 }
178 
179 }
Segmentation::LabellingListener & m_listener
int m_window_size
std::shared_ptr< SourceFactory > m_source_factory
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
static Logging getLogger(const std::string &name="")
void info(const std::string &logMessage)
static CheckImages & getInstance()
Definition: CheckImages.h:138
Implements a Segmentation based on the Lutz algorithm.
Definition: Lutz.h:37
void labelImage(LutzListener &listener, const DetectionImage &image, PixelCoordinate offset=PixelCoordinate(0, 0))
Definition: Lutz.cpp:59
void labelImage(Segmentation::LabellingListener &listener, std::shared_ptr< const DetectionImageFrame > frame) override
std::shared_ptr< SourceFactory > m_source_factory
void run(std::vector< T > &input_data, std::vector< U > &output_data) const
Definition: OnnxModel.h:27
ONNXTensorElementDataType getInputType() const
Definition: OnnxModel.h:92
ONNXTensorElementDataType getOutputType() const
Definition: OnnxModel.h:96
const std::vector< std::int64_t > & getInputShape() const
Definition: OnnxModel.h:100
const std::vector< std::int64_t > & getOutputShape() const
Definition: OnnxModel.h:104
size_t getInputNb() const
Definition: OnnxModel.h:128
T emplace_back(T... args)
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition: OnnxPlugin.cpp:26