/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2022 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "PatternAnalysis.h"

namespace sixte {

// True if e1 and e2 are direct (4-connected) neighbors: same RAWY and RAWX
// differing by exactly one, or same RAWX and RAWY differing by exactly one.
// Diagonal pixels are not neighbors.
static bool isNeighbor(const NewEvent& e1, const NewEvent& e2) {
  const bool same_row = (e1.rawy == e2.rawy);
  const bool same_column = (e1.rawx == e2.rawx);

  const bool adjacent_in_row =
      same_row && ((e1.rawx == e2.rawx + 1) || (e1.rawx == e2.rawx - 1));
  const bool adjacent_in_column =
      same_column && ((e1.rawy == e2.rawy + 1) || (e1.rawy == e2.rawy - 1));

  return adjacent_in_row || adjacent_in_column;
}

// Combine the single-pixel events of src_file into pattern events in dst_file.
//
// The input must have EVTYPE 'PIXEL'; the output EVTYPE is set to 'PATTERN'.
// Rows are read in order and collected while their FRAME value stays the
// same; each completed frame is passed to patternAnalysisOfNewframe(), which
// writes the recombined events. Finally the pattern statistics are stored as
// header keywords NVALID, NPVALID, NINVALID, NPINVALI, NGRAD<n> and NPGRA<n>
// of dst_file.
//
// Throws SixteException if EVTYPE cannot be read or is not 'PIXEL', or if a
// single frame contains maxnframelist (10000) or more events.
void PatternAnalysis::phpat(const Ebounds& ebounds, NewEventfile& src_file, NewEventfile& dst_file) {
  PatternStatistics statistics;

  // Size the per-grade counters and reset them to zero in one step.
  statistics.ngrade.assign(num_patterns, 0);
  statistics.npgrade.assign(num_patterns, 0);

  // List of all events belonging to the current frame.
  std::list<NewEvent> framelist;
  const size_t maxnframelist = 10000;

  // NOTE(review): patternAnalysisOfNewframe() receives this flag by value, so
  // it is never set to true here; suppressing repeated warnings has to be
  // handled inside patternAnalysisOfNewframe().
  static bool threshold_warning_printed = false;

  // Check if the input file contains single-pixel events.
  std::string evtype = src_file.getKeyStringCFITSIO("EVTYPE");
  if (evtype.empty()) throw SixteException("failed to read event type from event file");

  if (!boost::iequals(evtype, "PIXEL")) {
    throw SixteException("event type of input file is '" + evtype + "' but must be 'PIXEL'");
  }
  dst_file.updateKeyStringCFITSIO("EVTYPE", "PATTERN", "event type");

  // Particular instruments require a special pattern recombination scheme (e.g. eROSITA).
  const std::string telescope = src_file.getKeyStringCFITSIO("TELESCOP");
  const bool iseROSITA = boost::iequals(telescope, "EROSITA");

  const size_t rows = src_file.getRowNumCFITSIO("EVENTS");

  for (size_t row = 1; row <= rows; row++) {
    NewEvent event = src_file.getEventFromCFITSIOFile(row);

    // A file with a single zero-signal row is treated as an empty raw file.
    if (rows == 1 && event.signal == 0) {
      healog(5) << "Empty raw file detected\n";
      break;
    }

    // A change of the FRAME value closes the current frame: analyse it first.
    const bool newframe = !framelist.empty() && (event.frame != framelist.back().frame);
    if (newframe) {
      patternAnalysisOfNewframe(framelist,
                                iseROSITA,
                                threshold_warning_printed,
                                ebounds,
                                statistics,
                                dst_file);
    }

    framelist.push_back(event);
    if (framelist.size() >= maxnframelist) throw SixteException("too many events in the same frame");
  }

  // Analyse the last (still open) frame.
  if (!framelist.empty()) {
    patternAnalysisOfNewframe(framelist,
                              iseROSITA,
                              threshold_warning_printed,
                              ebounds,
                              statistics,
                              dst_file);
  }

  // Store pattern statistics in the output file.
  dst_file.updateKeyCFITSIO("NVALID", statistics.nvalids, "number of valid patterns");
  dst_file.updateKeyCFITSIO("NPVALID", statistics.npvalids, "number of piled up valid patterns");
  dst_file.updateKeyCFITSIO("NINVALID", statistics.ninvalids, "number of invalid patterns");
  dst_file.updateKeyCFITSIO("NPINVALI", statistics.npinvalids, "number of piled up invalid patterns");

  // Numbered grades.
  for (size_t jj = 0; jj < num_patterns; jj++) {
    std::string keyword = "NGRAD" + std::to_string(jj);
    std::string comment = "number of patterns with grade " + std::to_string(jj);
    dst_file.updateKeyCFITSIO(keyword, statistics.ngrade[jj], comment);
    keyword = "NPGRA" + std::to_string(jj);
    comment = "number of piled up patterns with grade " + std::to_string(jj);
    dst_file.updateKeyCFITSIO(keyword, statistics.npgrade[jj], comment);
  }
  emptyPixelList(framelist);
}

PatternAnalysis::PatternAnalysis(bool skip_invalids,
                                 XMLData& xml_data,
                                 size_t rawymin, size_t rawymax)
    : rawymin_{rawymin},
      rawymax_{rawymax} {
  // Read the pattern-analysis settings from the <detector> element of the
  // XML description. A threshold whose element is absent keeps its current
  // (default) value.
  auto detector = xml_data.child("detector");

  // Returns the "value" attribute of the optional child `name`, or
  // `fallback` if that child does not exist.
  auto optional_value = [&detector](const char* name, double fallback) -> double {
    auto element = detector.optionalChild(name);
    return element ? element->attributeAsDouble("value") : fallback;
  };

  threshold_event_lo_keV_ = optional_value("threshold_event_lo_keV", threshold_event_lo_keV_);
  threshold_split_lo_keV_ = optional_value("threshold_split_lo_keV", threshold_split_lo_keV_);
  threshold_split_lo_fraction_ = optional_value("threshold_split_lo_fraction", threshold_split_lo_fraction_);
  threshold_pattern_up_keV_ = optional_value("threshold_pattern_up_keV", threshold_pattern_up_keV_);

  // <dimensions> is taken from <geometry> if present; otherwise from
  // <detector> itself (backward compatibility).
  auto geometry_or_detector = detector.optionalChild("geometry").value_or(detector);
  xwidth_ = geometry_or_detector.child("dimensions").attributeAsInt("xwidth");
  skip_invalids_ = skip_invalids;
}

// Recombine the single-pixel events of one frame into pattern events.
//
// Events are taken from the front of framelist one after another. Every event
// with |signal| >= threshold_event_lo_keV_ seeds a pattern: neighbors whose
// signal is not below the split threshold are moved from framelist into the
// pattern, which is then merged into one event positioned on its brightest
// pixel, graded, counted in `statistics` and (subject to the upper threshold
// and skip_invalids_) written to dst_file. framelist is empty on return.
//
// threshold_warning_printed: if true, the split-threshold warning is
// suppressed. It is passed by value, so the caller never sees it change; a
// function-local static flag keeps the warning to a single print.
void PatternAnalysis::patternAnalysisOfNewframe(std::list<NewEvent>& framelist,
                                                bool iseROSITA,
                                                bool threshold_warning_printed,
                                                const Ebounds& ebounds,
                                                PatternStatistics& statistics,
                                                NewEventfile& dst_file) {
  assert(!framelist.empty());

  // Persists across calls/frames, unlike the by-value parameter above.
  static bool split_warning_printed = false;

  // List of all neighboring events in the current frame.
  std::list<NewEvent> neighborlist;
  const long maxnneighborlist = 2000;

  while (!framelist.empty()) {
    const size_t framelist_size = framelist.size();

    NewEvent frame = framelist.front();
    framelist.pop_front();

    if (framelist.size() >= framelist_size) throw SixteException("Failed to reduce framelist size");

    // Skip events below the event threshold (|signal| compared via squares).
    if ((frame.signal * frame.signal) < (threshold_event_lo_keV_ * threshold_event_lo_keV_)) {
      continue;
    }

    // Start a new neighbor list.
    neighborlist.push_back(frame);

    NewEvent maxsignalev = findNeighborMaxEvent(framelist, neighborlist);

    // NOTE(review): `frame` has already been removed from framelist, so it is
    // not considered as a neighbor of maxsignalev here -- confirm intended.
    double split_threshold = getSplitPatternThreshold(framelist, maxsignalev, iseROSITA);

    if ((split_threshold > threshold_event_lo_keV_) && !threshold_warning_printed && !split_warning_printed) {
      printWarning("split threshold (" + std::to_string(split_threshold * 1000.0)
                     + ") is above event threshold (" + std::to_string(threshold_event_lo_keV_ * 1000.0)
                     + ") (message is printed only once)" );
      threshold_warning_printed = true;
      split_warning_printed = true;
    }

    findNeighborEvents(framelist, neighborlist, split_threshold, maxnneighborlist);

    NewEvent max_ev = findMaxSignalPixel(neighborlist);
    NewEvent new_event(max_ev);

    new_event.ra = 0.;
    new_event.dec = 0.;
    new_event.npixels = static_cast<long>(neighborlist.size());
    new_event.signal = 0.;
    // Clear photon and source IDs together; setPHandSRCid() refills them as
    // pairs, so stale source IDs copied from max_ev must not survive.
    for (size_t ii = 0; ii < NEVENTPHOTONS; ii++) {
      new_event.ph_id[ii] = 0;
      new_event.src_id[ii] = 0;
    }

    if (new_event.pileup != 0) throw SixteException("Pileup flag nonzero before pileup is set");

    // Both flags must accumulate over ALL pixels of the pattern.
    bool touched_border = false;
    bool has_negative_contribution = false;

    for (const auto& neighbor : neighborlist) {
      new_event.signal += neighbor.signal;
      if (neighbor.signal < 0.) has_negative_contribution = true;

      savePixelContribution(neighbor, max_ev, new_event);
      setPHandSRCid(new_event, neighbor);
      if (checkForBorderPixels(neighbor)) touched_border = true;
    }

    // -2 flags a negative contribution (getEventType turns it into -1 without
    // grading); -1 is a provisional value that getEventType replaces.
    new_event.type = has_negative_contribution ? -2 : -1;

    // Determine the PHA channel corresponding to the total signal.
    new_event.pha = ebounds.sixteGetEBOUNDSChannel(new_event.signal);

    // More than one contributing photon means pile-up.
    if (new_event.ph_id[1]) new_event.pileup = 1;

    getEventType(new_event, neighborlist.size(), touched_border);

    emptyPixelList(neighborlist);

    if (checkUpperThreshold(new_event)) {
      updatePatternStatistics(new_event, statistics);
      addEvent2File(new_event, dst_file);
    }
  }
  // END of loop over all events in the frame list.
  emptyPixelList(framelist);

  assert(framelist.empty());
}


// Set new_event.type to the pattern grade for a pattern of `size` pixels.
//
// Patterns that touch the detector border or arrive with type -2 (negative
// contribution) become invalid (-1). Otherwise: 1 pixel -> grade 0, 2/3/4
// pixels -> getDoubleType/getTripleType/getQuadrupleType on the 3x3 signal
// array, anything larger -> -1. Throws SixteException if size is 0.
void PatternAnalysis::getEventType(NewEvent& new_event, size_t size, bool touched_border) {
  if (size == 0) throw SixteException("Empty neighborlist");

  const bool forced_invalid = touched_border || (new_event.type == -2);
  if (forced_invalid) {
    new_event.type = -1;
    return;
  }

  int grade = -1;
  if (size == 1) {
    grade = 0;  // Single event.
  } else if (size == 2) {
    grade = getDoubleType(new_event.signals);
  } else if (size == 3) {
    grade = getTripleType(new_event.signals);
  } else if (size == 4) {
    grade = getQuadrupleType(new_event.signals);
  }
  new_event.type = grade;
}

// Grade of a two-pixel pattern, given the 3x3 signal array around the
// brightest pixel (index 4). The first populated direct neighbor, checked in
// the order bottom (1), left (3), top (7), right (5), decides the grade:
// bottom -> 3, left -> 4, top -> 1, right -> 2. Returns -1 if none is set.
int PatternAnalysis::getDoubleType(const std::vector<double>& signals) {
  // {index in the 3x3 array, grade}, in evaluation order.
  static const int kDoubleGrades[4][2] = {
    {1, 3},  // bottom
    {3, 4},  // left
    {7, 1},  // top
    {5, 2},  // right
  };

  for (const auto& entry : kDoubleGrades) {
    if (signals[entry[0]] > 0.) return entry[1];
  }
  return -1;
}

// Grade of a three-pixel (L-shaped) pattern, given the 3x3 signal array
// around the brightest pixel. The vertical arm is bottom (1) if populated,
// otherwise top (7); the horizontal arm is left (3), otherwise right (5).
// bottom-left -> 7, bottom-right -> 6, top-left -> 8, top-right -> 5.
// Returns -1 if no such L with the brightest pixel at its corner exists.
int PatternAnalysis::getTripleType(const std::vector<double>& signals) {
  const bool has_bottom = signals[1] > 0.;
  const bool has_top = !has_bottom && (signals[7] > 0.);
  if (!has_bottom && !has_top) return -1;

  if (signals[3] > 0.) return has_bottom ? 7 : 8;  // ...-left
  if (signals[5] > 0.) return has_bottom ? 6 : 5;  // ...-right
  return -1;
}

// Grade of a four-pixel (2x2 square) pattern, given the 3x3 signal array
// around the brightest pixel. The first populated diagonal corner, checked in
// the order bottom-left (0), bottom-right (2), top-left (6), top-right (8),
// selects the square; it is valid only if both edge pixels adjacent to that
// corner carry more signal than the corner. Returns 11/10/12/9 respectively,
// or -1 otherwise.
int PatternAnalysis::getQuadrupleType(const std::vector<double>& signals) {
  // {corner, first edge, second edge, grade}, in evaluation order.
  static const int kSquares[4][4] = {
    {0, 1, 3, 11},  // bottom-left
    {2, 1, 5, 10},  // bottom-right
    {6, 7, 3, 12},  // top-left
    {8, 7, 5, 9},   // top-right
  };

  for (const auto& square : kSquares) {
    const double corner = signals[square[0]];
    if (!(corner > 0.)) continue;

    // Only the first populated corner is examined.
    const bool corner_is_weakest = (signals[square[1]] > corner) && (signals[square[2]] > corner);
    return corner_is_weakest ? square[3] : -1;
  }
  return -1;
}

// Store the signal and PHA of `neighbor` in the 3x3 arrays of `new_event`,
// at its position relative to the brightest pixel `max_ev`:
//
//   index = 3 * (dy + 1) + (dx + 1),   dx/dy = RAWX/RAWY offset
//
//     6 7 8     (dy = +1)
//     3 4 5     (dy =  0)
//     0 1 2     (dy = -1)
//
// Pixels more than one step away in either direction are ignored.
void PatternAnalysis::savePixelContribution (const NewEvent& neighbor, const NewEvent& max_ev, NewEvent& new_event) {
  const long dx = static_cast<long>(neighbor.rawx) - static_cast<long>(max_ev.rawx);
  const long dy = static_cast<long>(neighbor.rawy) - static_cast<long>(max_ev.rawy);

  const bool inside_3x3 = (dx >= -1 && dx <= 1) && (dy >= -1 && dy <= 1);
  if (!inside_3x3) return;

  const long index = 3 * (dy + 1) + (dx + 1);
  new_event.signals[index] = neighbor.signal;
  new_event.phas[index] = neighbor.pha;
}

// Return a copy of the event with the largest signal in neighborlist; on ties
// the earliest entry wins.
//
// Throws SixteException if neighborlist is empty (front() on an empty list
// would be undefined behavior) -- same guard as findNeighborEvents().
NewEvent PatternAnalysis::findMaxSignalPixel(const std::list<NewEvent>& neighborlist) {
  if (neighborlist.empty()) throw SixteException("List of neighboring pixels can not be empty");

  // Track the maximum by iterator to avoid copying the event on each update.
  auto max_it = neighborlist.begin();
  for (auto it = neighborlist.begin(); it != neighborlist.end(); ++it) {
    if (it->signal > max_it->signal) max_it = it;
  }
  return *max_it;
}

// Grow the pattern in neighborlist. Every event of framelist that is a direct
// neighbor of a pattern member and whose signal is not below split_threshold
// is moved from framelist to the end of neighborlist. Members appended during
// the scan are scanned as well, so the whole connected region is collected.
//
// Throws SixteException if neighborlist is empty on entry, or if adding an
// event would happen while the pattern already holds maxnneighborlist events.
void PatternAnalysis::findNeighborEvents(std::list<NewEvent>& framelist,
                                         std::list<NewEvent>& neighborlist,
                                         double split_threshold,
                                         size_t maxnneighborlist) {
  if (neighborlist.empty()) throw SixteException("List of neighboring pixels can not be empty");

  // std::list::push_back does not invalidate iterators and end() is re-read
  // on every step, so `member` also reaches events appended below.
  for (auto member = neighborlist.begin(); member != neighborlist.end(); ++member) {
    for (auto candidate = framelist.begin(); candidate != framelist.end();) {
      const bool joins = isNeighbor(*member, *candidate) && !(candidate->signal < split_threshold);
      if (!joins) {
        ++candidate;
        continue;
      }

      if (neighborlist.size() >= maxnneighborlist) throw SixteException("too many events in the same pattern");

      neighborlist.push_back(*candidate);
      candidate = framelist.erase(candidate);
    }
  }
}

// True if the pixel lies on the outermost column (RAWX 0 or xwidth_ - 1) or
// on the first/last row of the configured RAWY range (rawymin_/rawymax_).
bool PatternAnalysis::checkForBorderPixels(const NewEvent& neighbor) const {
  const bool on_x_edge = (neighbor.rawx == 0) || (neighbor.rawx == xwidth_ - 1);
  const bool on_y_edge = (neighbor.rawy == rawymin_) || (neighbor.rawy == rawymax_);
  return on_x_edge || on_y_edge;
}

double PatternAnalysis::getSplitPatternThreshold(const std::list<NewEvent>& framelist,
                                                 const NewEvent& maxsignalev,
                                                 bool iseROSITA) const {
  // set the default value
  double split_threshold = threshold_split_lo_keV_;
  
  // For eROSITA we need a special treatment (according to a prescription of K. Dennerl). // ToDo: document somewhere else
  if (threshold_split_lo_fraction_ > 0.) {
    
    if (iseROSITA) {
      double vertical = 0., horizontal = 0.;
      
      for (const NewEvent& event_frame : framelist) {
        if (isNeighbor(maxsignalev, event_frame)) {
          if (event_frame.rawx == maxsignalev.rawx) {
            if (event_frame.signal > horizontal) {
              horizontal = event_frame.signal;
            }
          } else {
            if (event_frame.signal > vertical) {
              vertical = event_frame.signal;
            }
          }
        }
      }
      split_threshold = threshold_split_lo_fraction_ * (maxsignalev.signal + horizontal + vertical);
    } else {
      // Split threshold for generic instruments.
      split_threshold = threshold_split_lo_fraction_ * maxsignalev.signal;
    }
  }
  return split_threshold;
}

// Hill-climb from the first entry of neighborlist to a local signal maximum.
//
// Starting at neighborlist.front(), repeatedly move to any direct neighbor in
// framelist with a strictly larger signal, until a full pass over framelist
// finds no improvement. Returns a copy of the final event. Terminates because
// every step strictly increases the signal.
//
// Throws SixteException if neighborlist is empty (front() on an empty list
// would be undefined behavior) -- same guard as findNeighborEvents().
NewEvent PatternAnalysis::findNeighborMaxEvent(const std::list<NewEvent>& framelist,
                                               const std::list<NewEvent>& neighborlist) {
  if (neighborlist.empty()) throw SixteException("List of neighboring pixels can not be empty");

  NewEvent maxsignalev = neighborlist.front();
  bool updated = true;
  while (updated) {
    updated = false;
    for (const NewEvent& event_frame : framelist) {
      // maxsignalev may change mid-pass; later comparisons use the new peak.
      if (isNeighbor(maxsignalev, event_frame) && (event_frame.signal > maxsignalev.signal)) {
        maxsignalev = event_frame;
        updated = true;
      }
    }
  }

  return maxsignalev;
}

// Merge the photon IDs of `neighbor`, together with their source IDs, into
// `new_event`. A photon ID of 0 marks an unused slot and ends the neighbor's
// list. IDs already present in new_event are skipped; new ones go into the
// first free slot; IDs that find no free slot among NEVENTPHOTONS are dropped.
void PatternAnalysis::setPHandSRCid(NewEvent& new_event, const NewEvent& neighbor) {
  for (long src = 0; src < NEVENTPHOTONS && neighbor.ph_id[src] != 0; ++src) {
    const auto photon = neighbor.ph_id[src];

    // Advance to the first slot that already holds this photon or is free.
    long slot = 0;
    while (slot < NEVENTPHOTONS
           && new_event.ph_id[slot] != photon
           && new_event.ph_id[slot] != 0) {
      ++slot;
    }

    if (slot < NEVENTPHOTONS && new_event.ph_id[slot] == 0) {
      new_event.ph_id[slot] = photon;
      new_event.src_id[slot] = neighbor.src_id[src];
    }
  }
}

// Count `event` in the pattern statistics: valid patterns (type >= 0) are
// counted overall and per grade (ngrade[type]); invalid ones (type < 0) only
// overall. Events with pileup > 0 are additionally counted in the
// corresponding piled-up counters.
void PatternAnalysis::updatePatternStatistics(const NewEvent& event, PatternStatistics& statistics) {
  const bool piled_up = event.pileup > 0;

  if (event.type >= 0) {
    ++statistics.nvalids;
    ++statistics.ngrade[event.type];
    if (piled_up) {
      ++statistics.npvalids;
      ++statistics.npgrade[event.type];
    }
    return;
  }

  ++statistics.ninvalids;
  if (piled_up) ++statistics.npinvalids;
}

// True if the pattern signal does not exceed threshold_pattern_up_keV_.
// An upper threshold of exactly 0 disables the check.
bool PatternAnalysis::checkUpperThreshold(const NewEvent& event) const {
  const bool check_disabled = (threshold_pattern_up_keV_ == 0.);
  return check_disabled || (event.signal <= threshold_pattern_up_keV_);
}

void PatternAnalysis::addEvent2File(const NewEvent& event, NewEventfile& dst_file) {
  if ((!skip_invalids_) || (event.type >= 0)) {
    dst_file.addEvent2CFITSIOFile(event);
  }
}

void PatternAnalysis::emptyPixelList(std::list<NewEvent>& event_list) {
  event_list.clear();
  
  if (!event_list.empty()) throw SixteException("Failed to empty pixel event list");
}

}

