23 #include <onnxruntime_cxx_api.h>
57 class LutzLabellingListener :
public Lutz::LutzListener {
65 virtual ~LutzLabellingListener() =
default;
67 void publishGroup(Lutz::PixelGroup& pixel_group)
override {
69 source->setProperty<PixelCoordinateList>(pixel_group.pixel_list);
70 source->setProperty<SourceId>();
74 void notifyProgress(
int line,
int total)
override {
79 ProcessSourcesEvent(std::make_shared<LineSelectionCriteria>(line -
m_window_size))
103 int tile_size = output_shape[1];
104 int data_planes = output_shape[3];
105 float average_rms = frame->getBackgroundMedianRms();
108 onnx_logger.
info() <<
"Onnx tile size: " << tile_size <<
" Data planes: " << data_planes <<
" RMS: " << average_rms;
110 if (model.
getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
114 if (model.
getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
119 throw Elements::Exception() <<
"Only ONNX models with a single input tensor are supported";
126 auto image = frame->getSubtractedImage();
131 for (
int i=0; i < data_planes; i++) {
132 tmp_images.
emplace_back(FitsWriter::newTemporaryImage<float>(
"_tmp_ml_seg%%%%%%.fits", image->getWidth(), image->getHeight()));
139 for (
int ox = 0; ox + tile_size * 3 / 4 < image->getWidth(); ox += tile_size / 2) {
140 for (
int oy = 0; oy + tile_size * 3 / 4 < image->getHeight(); oy += tile_size / 2) {
142 for (
int x = 0;
x < tile_size;
x++) {
143 for (
int y = 0;
y < tile_size;
y++) {
144 if (ox+x < image->getWidth() && oy+y < image->getHeight()) {
145 input_data[
x+
y*tile_size] = image_acc.
getValue(ox+
x, oy+
y) / average_rms;
147 input_data[
x+
y*tile_size] = 0;
152 model.
run<float,
float>(input_data, output_data);
154 int start_x = (ox == 0) ? 0 : tile_size / 4;
155 int start_y = (oy == 0) ? 0 : tile_size / 4;
157 int end_x = (ox + tile_size * 5 / 4 < image->getWidth()) ? tile_size * 3 / 4 : tile_size ;
158 int end_y = (oy + tile_size * 5 / 4 < image->getHeight()) ? tile_size * 3 / 4 : tile_size;
160 for (
int x = start_x;
x < end_x;
x++) {
161 for (
int y = start_y;
y < end_y;
y++) {
162 if (ox+x < image->getWidth() && oy+
y < image->getHeight()) {
163 for (
int i=0; i<data_planes; i++) {
164 tmp_images[i]->setValue(ox +
x, oy +
y, output_data[(
x+
y*tile_size) * data_planes + i] - detection_threshold);
165 if (check_images[i] !=
nullptr) {
166 check_images[i]->setValue(ox+
x, oy+
y, output_data[(
x+
y*tile_size) * data_planes + i]);
174 for (
int i=0; i<data_planes; i++) {
175 lutz.
labelImage(lutz_listener, *tmp_images[i]);
Segmentation::LabellingListener & m_listener
std::shared_ptr< SourceFactory > m_source_factory
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > x
std::shared_ptr< DependentParameter< std::shared_ptr< EngineParameter > > > y
static Logging getLogger(const std::string &name="")
void info(const std::string &logMessage)
T emplace_back(T... args)