SourceXtractorPlusPlus  0.16
Please provide a description of the project.
SegmentationConfig.cpp
Go to the documentation of this file.
1 
23 #include <iostream>
24 #include <fstream>
25 
26 #include <boost/regex.hpp>
27 using boost::regex;
28 using boost::regex_match;
29 using boost::smatch;
30 
31 #include <boost/algorithm/string.hpp>
32 
34 
36 
39 
42 
43 using namespace Euclid::Configuration;
44 namespace po = boost::program_options;
45 
46 namespace SourceXtractor {
47 
49 
// Names of the configuration / command-line options handled by SegmentationConfig.
static const std::string SEGMENTATION_ALGORITHM = "segmentation-algorithm";
static const std::string SEGMENTATION_DISABLE_FILTERING = "segmentation-disable-filtering";
static const std::string SEGMENTATION_FILTER = "segmentation-filter";
static const std::string SEGMENTATION_LUTZ_WINDOW_SIZE = "segmentation-lutz-window-size";
static const std::string SEGMENTATION_BFS_MAX_DELTA = "segmentation-bfs-max-delta";
static const std::string SEGMENTATION_ML_MODEL = "segmentation-ml-model";
static const std::string SEGMENTATION_ML_THRESHOLD = "segmentation-ml-threshold";
57 
58 SegmentationConfig::SegmentationConfig(long manager_id) : Configuration(manager_id), m_selected_algorithm(Algorithm::UNKNOWN)
59  , m_lutz_window_size(0)
60  , m_bfs_max_delta(1000)
61  , m_ml_threshold(0.9) {}
62 
64  return { {"Detection image", {
65  {SEGMENTATION_ALGORITHM.c_str(), po::value<std::string>()->default_value("LUTZ"),
66  "Segmentation algorithm to be used (LUTZ, TILES or ML (a ONNX-format model must be provided))"},
67  {SEGMENTATION_DISABLE_FILTERING.c_str(), po::bool_switch(),
68  "Disables filtering"},
69  {SEGMENTATION_FILTER.c_str(), po::value<std::string>()->default_value(""),
70  "Loads a filter"},
71  {SEGMENTATION_LUTZ_WINDOW_SIZE.c_str(), po::value<int>()->default_value(0),
72  "Lutz sliding window size (0=disable)"},
73  {SEGMENTATION_BFS_MAX_DELTA.c_str(), po::value<int>()->default_value(1000),
74  "BFS algorithm max source x/y size (default=1000)"},
75  {SEGMENTATION_ML_MODEL.c_str(), po::value<std::string>()->default_value(""),
76  "ONNX model to use with machine learning segmentation"},
77  {SEGMENTATION_ML_THRESHOLD.c_str(), po::value<double>()->default_value(0.9),
78  "Probability threshold for ML detection"},
79  }}};
80 }
81 
83  auto algorithm_name = boost::to_upper_copy(args.at(SEGMENTATION_ALGORITHM).as<std::string>());
84  if (algorithm_name == "LUTZ") {
86  } else if (algorithm_name == "BFS") {
88  } else if (algorithm_name == "ML") {
89 #ifdef WITH_ML_SEGMENTATION
91 #else
92  throw Elements::Exception() << "SourceXtractor++ has not been compiled with ONNX support";
93 #endif
94  } else {
95  throw Elements::Exception() << "Unknown segmentation algorithm : " << algorithm_name;
96  }
97 
98  if (args.at(SEGMENTATION_DISABLE_FILTERING).as<bool>()) {
99  m_filter = nullptr;
100  } else {
101  auto filter_filename = args.at(SEGMENTATION_FILTER).as<std::string>();
102  if (filter_filename != "") {
103  m_filter = loadFilter(filter_filename);
104  if (m_filter == nullptr)
105  throw Elements::Exception() << "Can not load filter: " << filter_filename;
106  } else {
108  }
109  }
110 
114  m_ml_threshold = args.at(SEGMENTATION_ML_THRESHOLD).as<double>();
115 
117  throw Elements::Exception() << "Machine learning segmentation requested but no ONNX model was provided";
118  }
119 }
120 
122 }
123 
125  segConfigLogger.info() << "Using the default segmentation (3x3) filter.";
126  auto convolution_kernel = VectorImage<SeFloat>::create(3, 3);
127  convolution_kernel->setValue(0,0, 1);
128  convolution_kernel->setValue(0,1, 2);
129  convolution_kernel->setValue(0,2, 1);
130 
131  convolution_kernel->setValue(1,0, 2);
132  convolution_kernel->setValue(1,1, 4);
133  convolution_kernel->setValue(1,2, 2);
134 
135  convolution_kernel->setValue(2,0, 1);
136  convolution_kernel->setValue(2,1, 2);
137  convolution_kernel->setValue(2,2, 1);
138 
139  return std::make_shared<BackgroundConvolution>(convolution_kernel, true);
140 }
141 
143  // check for the extension ".fits"
144  std::string fits_ending(".fits");
145  if (filename.length() >= fits_ending.length()
146  && filename.compare (filename.length() - fits_ending.length(), fits_ending.length(), fits_ending)==0) {
147  // load a FITS filter
148  return loadFITSFilter(filename);
149  }
150  else{
151  // load an ASCII filter
152  return loadASCIIFilter(filename);
153  }
154 }
155 
157 
158  // read in the FITS file
159  auto convolution_kernel = FitsReader<SeFloat>::readFile(filename);
160 
161  // give some feedback on the filter
162  segConfigLogger.info() << "Loaded segmentation filter: " << filename << " height: " << convolution_kernel->getHeight() << " width: " << convolution_kernel->getWidth();
163 
164  // return the correct object
165  return std::make_shared<BackgroundConvolution>(convolution_kernel, true);
166 }
167 
168 static bool getNormalization(std::istream& line_stream) {
169  std::string conv, norm_type;
170  line_stream >> conv >> norm_type;
171  if (conv != "CONV") {
172  throw Elements::Exception() << "Unexpected start for ASCII filter: " << conv;
173  }
174  if (norm_type == "NORM") {
175  return true;
176  }
177  else if (norm_type == "NONORM") {
178  return false;
179  }
180 
181  throw Elements::Exception() << "Unexpected normalization type: " << norm_type;
182 }
183 
/**
 * Appends every value of type T that can be read from line_stream to data.
 *
 * Reading stops at the end of the stream or at the first token that cannot be parsed as T.
 * Existing elements of data are kept; new values are appended after them.
 *
 * @param line_stream stream over one line of filter coefficients
 * @param data        container receiving the parsed values
 */
template <typename T>
static void extractValues(std::istream& line_stream, std::vector<T>& data) {
  T value{};
  // Test the extraction itself rather than good() beforehand. Checking good() first would
  // append a bogus value whenever the next read fails: empty input, trailing whitespace,
  // or a non-numeric token.
  while (line_stream >> value) {
    data.push_back(value);
  }
}
192 
194  std::ifstream file;
195 
196  // open the file and check
197  file.open(filename);
198  if (!file.good() || !file.is_open()){
199  throw Elements::Exception() << "Can not load filter: " << filename;
200  }
201 
202  enum class LoadState {
203  STATE_START,
204  STATE_FIRST_LINE,
205  STATE_OTHER_LINES
206  };
207 
208  LoadState state = LoadState::STATE_START;
209  bool normalize = false;
210  std::vector<SeFloat> kernel_data;
211  unsigned int kernel_width = 0;
212 
213  while (file.good()) {
214  std::string line;
215  std::getline(file, line);
216  line = regex_replace(line, regex("\\s*#.*"), std::string(""));
217  line = regex_replace(line, regex("\\s*$"), std::string(""));
218  if (line.size() == 0) {
219  continue;
220  }
221 
222  std::stringstream line_stream(line);
223 
224  switch (state) {
225  case LoadState::STATE_START:
226  normalize = getNormalization(line_stream);
227  state = LoadState::STATE_FIRST_LINE;
228  break;
229  case LoadState::STATE_FIRST_LINE:
230  extractValues(line_stream, kernel_data);
231  kernel_width = kernel_data.size();
232  state = LoadState::STATE_OTHER_LINES;
233  break;
234  case LoadState::STATE_OTHER_LINES:
235  extractValues(line_stream, kernel_data);
236  break;
237  }
238  }
239 
240  // compute the dimensions and create the kernel
241  auto kernel_height = kernel_data.size() / kernel_width;
242  auto convolution_kernel = VectorImage<SeFloat>::create(kernel_width, kernel_height, kernel_data);
243 
244  // give some feedback on the filter
245  segConfigLogger.info() << "Loaded segmentation filter: " << filename << " width: " << convolution_kernel->getWidth() << " height: " << convolution_kernel->getHeight();
246 
247  // return the correct object
248  return std::make_shared<BackgroundConvolution>(convolution_kernel, normalize);
249 }
250 
251 } // SourceXtractor namespace
T at(T... args)
T c_str(T... args)
static Logging getLogger(const std::string &name="")
void info(const std::string &logMessage)
std::shared_ptr< DetectionImageFrame::ImageFilter > m_filter
std::map< std::string, Configuration::OptionDescriptionList > getProgramOptions() override
void preInitialize(const UserValues &args) override
std::shared_ptr< DetectionImageFrame::ImageFilter > getDefaultFilter() const
void initialize(const UserValues &args) override
std::shared_ptr< DetectionImageFrame::ImageFilter > loadFITSFilter(const std::string &filename) const
std::shared_ptr< DetectionImageFrame::ImageFilter > loadASCIIFilter(const std::string &filename) const
std::shared_ptr< DetectionImageFrame::ImageFilter > loadFilter(const std::string &filename) const
static std::shared_ptr< VectorImage< T > > create(Args &&... args)
Definition: VectorImage.h:100
T getline(T... args)
T good(T... args)
T is_open(T... args)
static void extractValues(std::istream &line_stream, std::vector< T > &data)
static const std::string SEGMENTATION_ML_THRESHOLD
static const std::string SEGMENTATION_ALGORITHM
static const std::string SEGMENTATION_FILTER
static bool getNormalization(std::istream &line_stream)
static const std::string SEGMENTATION_LUTZ_WINDOW_SIZE
static const std::string SEGMENTATION_DISABLE_FILTERING
static const std::string SEGMENTATION_ML_MODEL
static Elements::Logging segConfigLogger
static const std::string SEGMENTATION_BFS_MAX_DELTA
string filename
Definition: conf.py:63
T open(T... args)
T push_back(T... args)
T regex_replace(T... args)
T length(T... args)