/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2022 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "Sensor.h"
#include "OperationStrategy.h"
#include "paraminput.h"
#include "Absorber.h"
#include <algorithm>
#include <numeric>
#include <set>
#include <utility>
#include <vector>


namespace sixte {

/**
 * Set up a sensor whose charge is moved line by line towards line 0
 * (see shiftLines()).
 *
 * Builds the rectangular geometry from the XML, installs the
 * SimpleShiftArrayReadout strategy, reads the optional charge transfer
 * efficiency and precomputes the per-pixel lookup tables:
 *  - pix_ids_in_lines_:     pixel IDs belonging to each line,
 *  - lines_of_pix_ids_:     line index of each pixel ID,
 *  - shift_target_pix_ids_: pixel one line lower (same x), or -1 for line 0.
 * Finally one pixel object per geometric pixel is created via createPixel().
 */
ShiftArray::ShiftArray(XMLData& xml_data, const ObsInfo& obs_info, const OutputFiles& outfiles,
                       bool skip_invalids, size_t xml_id, std::shared_ptr<const Ebounds> canonical_ebounds)
      : operation_strategy_(xml_data, outfiles, obs_info, skip_invalids, xml_id, canonical_ebounds) {
  array_geometry_ = std::make_shared<RectangularArray>(xml_data); // TODO: Use createGeometry?

  operation_strategy_.setReadoutStrategy(
          std::make_unique<SimpleShiftArrayReadout>(xml_data, outfiles.raw_datas[xml_id],
                                                    outfiles.clobber, outfiles.delete_rawdata,
                                                    obs_info, canonical_ebounds));

  // Optional <cte value="..."/>: scale factor applied to a signal on every
  // line shift (see shiftLines()).
  if (auto cte_elem = xml_data.child("detector").optionalChild("cte")) {
    charge_transfer_efficiency_ = cte_elem->attributeAsDouble("value");
  }

  num_lines_ = array_geometry_->getYWidth();
  num_pix_in_line_ = array_geometry_->getXWidth();
  active_count_in_line_.assign(num_lines_, 0);

  // Initialize pix_ids_in_lines_ and lines_of_pix_ids_. Pixel IDs are laid
  // out line by line: line L holds IDs [L * num_pix_in_line_, (L + 1) * num_pix_in_line_).
  const size_t total_pixels = array_geometry_->numpix();
  lines_of_pix_ids_.reserve(total_pixels);
  shift_target_pix_ids_.reserve(total_pixels);

  for (int line = 0; line < num_lines_; line++) {
    std::vector<T_PixId> pix_ids_in_line(num_pix_in_line_);
    T_PixId first_pix_id_in_line = line * num_pix_in_line_;
    std::iota(std::begin(pix_ids_in_line), std::end(pix_ids_in_line),
              first_pix_id_in_line);
    pix_ids_in_lines_.push_back(std::move(pix_ids_in_line));

    // Every pixel of this line maps back to the line index. Appending the
    // value in bulk avoids a counting loop whose variable would be unused.
    lines_of_pix_ids_.insert(lines_of_pix_ids_.end(),
                             static_cast<size_t>(num_pix_in_line_), line);
  }

  // Precompute shift target pixel IDs using array_geometry_
  for (T_PixId pix_id = 0; pix_id < total_pixels; pix_id++) {
    auto xy = array_geometry_->PixId2xy(pix_id);
    if (xy.second > 0) {
      // Target is the pixel in the same column, one line lower.
      T_PixId target_id = array_geometry_->xy2PixId(xy.first, xy.second - 1);
      shift_target_pix_ids_.push_back(target_id);
    } else {
      // First line pixels have no target (invalid ID)
      shift_target_pix_ids_.push_back(static_cast<T_PixId>(-1));
    }
  }

  line_info_ = std::make_shared<LineInfo>(num_lines_, lines_of_pix_ids_); // TODO: Initialize lines in LineInfo?

  //TODO: Use algorithm
  for (size_t ii = 0; ii < total_pixels; ii++) {
    // TODO: Is there a better place for createPixel?
    pixels_.push_back(createPixel(xml_data, operation_strategy_.frameDuration(),
                                  ii, line_info_));
  }
}

void ShiftArray::addSignal(const Signal& signal, T_PixId pix_id) {
  pixels_[pix_id]->addSignal(signal);
  addActiveID(pix_id);
}

// Move `signal` into pixel `pix_id` and register that pixel as active.
void ShiftArray::addSignal(Signal&& signal, T_PixId pix_id) {
  const auto& target_pixel = pixels_[pix_id];
  target_pixel->addSignal(std::move(signal));
  this->addActiveID(pix_id);
}

void ShiftArray::addCarrier(const Carrier& carrier) {
  // Get all affected pixels
  std::vector<T_PixId> pix_ids = this->array_geometry_->getPixIds(carrier.bounding_box());

  for (const auto& id: pix_ids) {
    auto signal = carrier.charge(array_geometry_->getPolygon(id));
    if (signal < 1e-9) {
      continue;
    }

    addSignal(Signal(signal,
                     std::vector<PhotonMetainfo>{carrier.photon_metainfo()},
                     carrier.creationTime()), id);
  }
}

// Delegates to the operation strategy, passing this array and the stop time.
void ShiftArray::propagateReadout(double tstop) {
  auto& strategy = operation_strategy_;
  strategy.propagate(*this, tstop);
}

void ShiftArray::finishReadout(double tstop) {
  propagateReadout(tstop);
}

// Hand post-processing of the collected data to the operation strategy.
void ShiftArray::postprocessing(Absorber& absorber, NewAttitude& attitude, const GTICollection& gti) {
  auto& strategy = operation_strategy_;
  strategy.postprocessing(*this, absorber, attitude, gti);
}

std::shared_ptr<ArrayGeometry> ShiftArray::arrayGeometry() {
  return array_geometry_;
}

std::shared_ptr<RectangularArray> ShiftArray::rectangularArrayGeometry() {
  return array_geometry_;
}

std::shared_ptr<LineInfo> ShiftArray::lineInfo() {
  return line_info_;
}

// Copy of the signal currently held by pixel `pix_id` (empty if none).
std::optional<Signal> ShiftArray::getSignal(T_PixId pix_id) const {
  const auto& pixel = pixels_[pix_id];
  return pixel->getSignal();
}

// Take the signal out of pixel `pix_id`. The pixel stops counting as active
// once it holds no signal at all; returns an empty optional if nothing was released.
std::optional<Signal> ShiftArray::releaseSignal(T_PixId pix_id) {
  const auto& pixel = pixels_[pix_id];
  auto released = pixel->releaseSignal();

  if (released) {
    if (!pixel->anySignal()) {
      removeActiveID(pix_id);
    }
    return released;
  }

  return std::nullopt;
}


// Record the last readout time of line `lineindex` in the shared LineInfo.
void ShiftArray::setLastReadoutTime(unsigned int lineindex, double current_clock_time) {
  auto& info = *line_info_;
  info.setLastReadoutTime(lineindex, current_clock_time);
}

// Forward `time` to LineInfo::addLastReadoutTime for line `lineindex`.
void ShiftArray::addLastReadoutTime(unsigned int lineindex, double time) {
  auto& info = *line_info_;
  info.addLastReadoutTime(lineindex, time);
}

// Mark pixel `id` as active (no-op if it already is) and bump its line's counter.
void ShiftArray::addActiveID(T_PixId id) {
  const auto first = active_pixels_.begin();
  const auto last = active_pixels_.end();
  if (std::find(first, last, id) != last) {
    return;  // already tracked
  }

  active_pixels_.emplace_back(id);
  const unsigned int line_index = lines_of_pix_ids_[id];
  active_count_in_line_[line_index] += 1;
}

// Stop tracking pixel `id` as active (no-op if it is not tracked) and
// decrement its line's counter.
void ShiftArray::removeActiveID(T_PixId id) {
  const auto pos = std::find(active_pixels_.begin(), active_pixels_.end(), id);
  if (pos == active_pixels_.end()) {
    return;  // nothing to undo
  }

  const unsigned int line_index = lines_of_pix_ids_[id];
  assert(active_count_in_line_[line_index] > 0);
  active_count_in_line_[line_index] -= 1;
  active_pixels_.erase(pos);
}

void ShiftArray::clearPixels(const std::vector<T_PixId>& pix_ids) {
  for (auto id: pix_ids) {
    pixels_[id]->clearSignal();
    if (!pixels_[id]->anySignal()) {
      removeActiveID(id);
    }
  }
}

/**
 * Shift the charge of every active pixel one line towards line 0.
 *
 * Pixels in line 0 stay where they are (they have no shift target). Every
 * other active pixel has its signal released, scaled by
 * charge_transfer_efficiency_, and added to the pixel stored in
 * shift_target_pix_ids_ (same column, one line lower).
 *
 * Throws SixteException if the bookkeeping is inconsistent: an active pixel
 * without a releasable signal, or an active line-0 pixel whose line counter is 0.
 */
void ShiftArray::shiftLines() {
  // With fewer than two lines there is no target line; without active pixels
  // there is nothing to move.
  if (num_lines_ < 2 || active_pixels_.empty())
    return;

  // Snapshot of the active IDs: releaseSignal()/addSignal() below modify
  // active_pixels_, so the loop must not iterate that container directly.
  // NOTE(review): this buffer is a function-local static, i.e. shared by every
  // ShiftArray instance and every thread — presumably to reuse its capacity
  // across calls. Confirm shiftLines() is never called concurrently.
  static std::vector<T_PixId> pixel_id;
  pixel_id.assign(active_pixels_.begin(), active_pixels_.end());

  // Ascending IDs visit lower lines first, assuming IDs are line-major (as the
  // constructor's pix_ids_in_lines_ layout assumes). A target pixel in line-1
  // is therefore processed before charge from the line above is added to it.
  std::sort(pixel_id.begin(), pixel_id.end());

  for (const auto& id: pixel_id) {
    unsigned int src_line = lines_of_pix_ids_[id];
    
    if (src_line == 0) {
      // Consistency check: an active pixel in line 0 implies a non-zero count.
      if (active_count_in_line_[src_line] == 0)
        throw SixteException("No signal in active line");

      // Line-0 pixels have no shift target; their charge stays in place.
      continue;
    }

    auto src_pixel_signal = releaseSignal(id);
    if (!src_pixel_signal) throw SixteException("No signal in active pixel");

    // Apply the charge transfer loss for this single line shift.
    src_pixel_signal->scale(charge_transfer_efficiency_);

    auto target_pixel_id = shift_target_pix_ids_[id];
    addSignal(std::move(*src_pixel_signal), target_pixel_id);
  }
}

// Pixel IDs belonging to line `lineindex` (reference into the precomputed table).
const std::vector<T_PixId>& ShiftArray::getPixIdsInLine(unsigned int lineindex) const {
  const auto& ids = pix_ids_in_lines_[lineindex];
  return ids;
}

void ShiftArray::clearLine(unsigned int lineindex) {
  if (!anySignalInLine(lineindex)) {
    return;
  }

  const auto& pixels_to_clear = pix_ids_in_lines_[lineindex];
  this->clearPixels(pixels_to_clear);
}

// True while at least one pixel is tracked as active.
bool ShiftArray::anySignal() const {
  const bool nothing_active = active_pixels_.empty();
  return !nothing_active;
}

// True if line `lineindex` currently has at least one active pixel.
bool ShiftArray::anySignalInLine(unsigned int lineindex) const {
  return 0 < active_count_in_line_[lineindex];
}

size_t ShiftArray::numpix() {
  return array_geometry_->numpix();
}

// Convert a pixel ID to its (x, y) grid position via the geometry.
std::pair<unsigned int, unsigned int> ShiftArray::PixId2xy(T_PixId pix_id) {
  return this->array_geometry_->PixId2xy(pix_id);
}

int ShiftArray::numLines() const {
  return num_lines_;
}

// Wipe the array by clearing each line in turn, starting at line 0.
void ShiftArray::clear() {
  int line = 0;
  while (line < num_lines_) {
    clearLine(line);
    ++line;
  }
}

// Forward the simulation start time to the operation strategy.
void ShiftArray::setStartTime(double tstart) {
  auto& strategy = operation_strategy_;
  strategy.setStartTime(tstart);
}

// Independently read-out pixel array. Only the "rectarray" geometry is
// accepted; any other geometry type throws SixteException.
PixArray::PixArray(XMLData& xml_data, const ObsInfo& obs_info, const OutputFiles& outfiles, bool skip_invalids, size_t xml_id, std::shared_ptr<const Ebounds> canonical_ebounds)
    : operation_strategy_(outfiles.evt_files[xml_id], outfiles.clobber, skip_invalids, xml_data, obs_info, canonical_ebounds) {
  const auto geometry_type = xml_data.child("detector").child("geometry").attributeAsString("type");
  if (geometry_type != "rectarray") {
    throw SixteException("PixArray only supports rectarray geometry");
  }

  array_geometry_ = std::make_shared<RectangularArray>(xml_data);

  // One pixel object per geometric pixel, all configured from the same XML.
  const size_t pixel_count = array_geometry_->numpix();
  for (size_t pix = 0; pix < pixel_count; ++pix) {
    pixels_.emplace_back(xml_data);
  }

  auto readout = std::make_unique<SimplePixArrayReadout>(xml_data, outfiles.raw_datas[xml_id],
                                                         outfiles.clobber, outfiles.delete_rawdata,
                                                         obs_info, canonical_ebounds);
  operation_strategy_.setReadoutStrategy(std::move(readout));
}

void PixArray::addCarrier(const Carrier& carrier) {
  // Get all affected pixels
  std::vector<T_PixId> pix_ids = array_geometry_->getPixIds(carrier.bounding_box());

  for (const auto& id: pix_ids) {
    auto signal = carrier.charge(array_geometry_->getPolygon(id));
    if (signal < 1e-9) {
      continue;
    }

    // Get the charge in this pixel, create signal and add it to respective pixel.
    pixels_[id].addSignal(Signal(signal,
          std::vector<PhotonMetainfo>{carrier.photon_metainfo()},
          carrier.creationTime()));

    if (pixels_[id].getSignal()) {
      // Remember that this pixel is active.
      // But only if it has signal (may not have happened due to deadtime!)
      active_pixels_.insert(id);
    }
  }
}

// Returns a *copy* of the active-pixel set, so callers (e.g. clear()) may
// modify the array while iterating over the result.
std::set<T_PixId> PixArray::getActiveInArray() const {
  std::set<T_PixId> snapshot = active_pixels_;
  return snapshot;
}

// True while at least one pixel is tracked as active.
bool PixArray::anySignal() const {
  const bool nothing_active = active_pixels_.empty();
  return !nothing_active;
}

// Read-only view of the signal stored in pixel `pix_id`.
const std::optional<Signal>& PixArray::getSignal(T_PixId pix_id) const {
  const auto& pixel = pixels_[pix_id];
  return pixel.getSignal();
}

// Take the signal out of pixel `pix_id`. The pixel leaves the active set once
// it is completely empty; returns an empty optional if nothing was released.
std::optional<Signal> PixArray::releaseSignal(T_PixId pix_id) {
  auto& pixel = pixels_[pix_id];
  auto released = pixel.releaseSignal();

  if (released) {
    if (!pixel.anySignal()) {
      active_pixels_.erase(pix_id);
    }
    return released;
  }

  return std::nullopt;
}

void PixArray::clearPixels(const std::set<T_PixId>& pix_ids) {
  for (auto id: pix_ids) {
    pixels_[id].clearSignal();

    // Some pixel types may still have signal after being cleared.
    if (!pixels_[id].anySignal()) {
      active_pixels_.erase(id);
    }
  }
}

// Delegates to the operation strategy, passing this array and the stop time.
void PixArray::propagateReadout(double tstop) {
  auto& strategy = operation_strategy_;
  strategy.propagate(*this, tstop);
}

void PixArray::finishReadout(double tstop) {
  propagateReadout(tstop);
}

// Hand post-processing of the collected data to the operation strategy.
void PixArray::postprocessing(Absorber& absorber, NewAttitude& attitude, const GTICollection& gti) {
  auto& strategy = operation_strategy_;
  strategy.postprocessing(*this, absorber, attitude, gti);
}

std::shared_ptr<ArrayGeometry> PixArray::arrayGeometry() {
  return array_geometry_;
}

std::shared_ptr<RectangularArray> PixArray::rectangularArrayGeometry() {
  return array_geometry_;
}

void PixArray::clear() {
  clearPixels(getActiveInArray());
}

// PixArray keeps no start-time state, so this is intentionally a no-op.
void PixArray::setStartTime(double /*tstart*/) {
}

// Microcalorimeter array. The geometry comes from createGeometry(), the trigger
// threshold from <threshold_event_lo_keV>, and crosstalk handling is enabled
// only if the XML has <reconstruction><crosstalk> AND the "doCrosstalk"
// parameter is not "none".
MicroCal::MicroCal(XMLData& xml_data, const ObsInfo& obs_info, const OutputFiles& outfiles, bool, size_t num_tels,
                   std::shared_ptr<RmfRegistry> rmf_registry)
: operation_strategy_(xml_data, outfiles.evt_files[num_tels], outfiles.clobber, obs_info, std::move(rmf_registry)) {
  array_geometry_ = createGeometry(xml_data);

  // One pixel object per geometric pixel, all configured from the same XML.
  const size_t pixel_count = array_geometry_->numpix();
  for (size_t pix = 0; pix < pixel_count; ++pix) {
    pixels_.emplace_back(xml_data);
  }

  // Deposits below this value are not added as signal (see addCarrier()).
  threshold_event_lo_keV_ =
      xml_data.child("detector").child("threshold_event_lo_keV").attributeAsDouble("value");

  std::string do_crosstalk = queryParameterString("doCrosstalk");

  if (auto reconstruction = xml_data.child("detector").optionalChild("reconstruction");
      reconstruction && reconstruction->hasChild("crosstalk") && do_crosstalk != "none") {
    xt_handler_.emplace(xml_data, do_crosstalk, array_geometry_.get());
  }

  // No pixel has been read out yet.
  last_readout_time_.assign(pixels_.size(), std::nullopt);
}

/**
 * Deposit a charge carrier into the microcalorimeter array.
 *
 * For every pixel overlapped by the carrier's bounding box, the deposited charge
 * (ignored below 1e-9) becomes a Signal. Deposits with val() at or above
 * threshold_event_lo_keV_ are added to the pixel; smaller ones are only counted
 * in num_evt_below_trigger_. num_events_ is incremented once per carrier.
 *
 * With crosstalk enabled, every deposit (triggering or not) also creates
 * crosstalk proxies, and every pixel reported by checkProxyTriggers() receives
 * a zero-valued "fake" signal. Every PROXY_CLEANUP_PERIOD deposits,
 * cleanProxiesGlobal() is called with the creation time of the oldest pending
 * signal (presumably to drop proxies that can no longer matter).
 */
void MicroCal::addCarrier(const Carrier& carrier) {
  // Get all affected pixels
  std::vector<T_PixId> pix_ids =
    array_geometry_->getPixIds(carrier.bounding_box());

  num_events_++;

  for (const auto& id: pix_ids) {
    auto deposition = carrier.charge(array_geometry_->getPolygon(id));
    if (deposition < 1e-9) {
      continue;
    }

    Signal new_sig(deposition,
        std::vector<PhotonMetainfo>{carrier.photon_metainfo()},
        carrier.creationTime());

    // handle triggering
    if (new_sig.val() >= threshold_event_lo_keV_) {
      // Copy, not move: new_sig is still needed for the crosstalk proxies below.
      addSignal(id, new_sig);
    } else {
      // small energy photons cause crosstalk,
      // but are otherwise ignored in terms of readout
      num_evt_below_trigger_++;
    }

    if (xt_handler_.has_value()) {
      xt_handler_->createProxies(id, new_sig);

      // handle the case of crosstalk causing "fake" triggers
      auto xt_triggered = xt_handler_->checkProxyTriggers(id);
      for (auto vic_id: xt_triggered) {
        Signal fake_sig(0.,
            std::vector<PhotonMetainfo>{carrier.photon_metainfo()},
            carrier.creationTime());

        addSignal(vic_id, std::move(fake_sig));
      }

      // regularly clear proxies
      proxy_counter++;

      if (proxy_counter >= PROXY_CLEANUP_PERIOD) {
        double oldest_signal_time = carrier.creationTime();
        for (auto id2: active_pixels_) {
          // Active pixels are expected to hold a signal; skip defensively
          // instead of dereferencing an empty optional (undefined behavior).
          const auto& pending = getSignal(id2);
          if (!pending) {
            continue;
          }

          double t = pending->creation_time();

          oldest_signal_time = std::min(t, oldest_signal_time);
        }

        xt_handler_->cleanProxiesGlobal(oldest_signal_time);

        proxy_counter = 0;
      }
    }
  }
}

void MicroCal::addSignal(const T_PixId& pix_id, const Signal& new_signal) {
  addSignalImpl(pix_id, new_signal);
}

void MicroCal::addSignal(const T_PixId& pix_id, Signal&& new_signal) {
  addSignalImpl(pix_id, std::move(new_signal));
}

// Delegates to the operation strategy, passing this array and the stop time.
void MicroCal::propagateReadout(double tstop) {
  auto& strategy = operation_strategy_;
  strategy.propagate(*this, tstop);
}

void MicroCal::finishReadout(double) {
  // read out every remaining pixel, even if tstop is smaller than
  // the end of the record
  auto active_pixels = getActivePixels();

  for (auto id: active_pixels) {
    operation_strategy_.runReadout(*this, id, std::nullopt);
  }
}

/**
 * Run the operation strategy's post-processing, then warn if at least 10% of
 * the recorded events deposited energy below threshold_event_lo_keV_; those are
 * not correctly treated by the current implementation.
 */
void MicroCal::postprocessing(Absorber& absorber,
    NewAttitude& attitude,
    const GTICollection& gti) {
  operation_strategy_.postprocessing(*this, absorber, attitude, gti);

  // Without any events the ratio below is 0/0; there is nothing to warn about.
  if (num_events_ == 0) {
    return;
  }

  if (float(num_evt_below_trigger_) / float(num_events_) >= 0.1) {
    // Note: a literal '%' needs no escape; "\%" is an unknown escape sequence.
    std::cout << "\n*** WARNING: Number of non-triggering photons "
      << "(i.e., energy below " << threshold_event_lo_keV_ << " keV)"
      << " is greater than 10% of total photons ("
      << num_evt_below_trigger_ << "/" << num_events_
      << ")!\n***          These are not correctly treated in "
      << "the current Microcalorimeter implementation!"
      << std::endl;
  }
}

// True while at least one pixel is tracked as active.
bool MicroCal::anySignal() {
  const bool nothing_active = active_pixels_.empty();
  return !nothing_active;
}

// Read-only view of the signal stored in pixel `pix_id`.
const std::optional<Signal>& MicroCal::getSignal(T_PixId pix_id) {
  const auto& pixel = pixels_[pix_id];
  return pixel.getSignal();
}

const std::set<T_PixId>& MicroCal::getActivePixels() const {
  return active_pixels_;
}

// Clear pixel `pix_id` and unconditionally drop it from the active set.
void MicroCal::clearSignal(T_PixId pix_id) {
  auto& pixel = pixels_[pix_id];
  pixel.clearSignal();
  this->active_pixels_.erase(pix_id);
}

// Delegate to the crosstalk handler if crosstalk is enabled; otherwise
// return (0, 0.0).
std::pair<unsigned int, double> MicroCal::calcTotalCrosstalk(
    T_PixId i_pix,
    const Signal &pix_signal,
    unsigned int grade_id) {
  if (!xt_handler_.has_value()) {
    return {0u, 0.0};
  }
  return xt_handler_->calcTotalCrosstalk(i_pix, pix_signal, grade_id);
}

// Last readout time of pixel `pix_id`, or empty if it was never read out.
std::optional<double> MicroCal::getLastReadoutTime(T_PixId pix_id) {
  const auto& stored = last_readout_time_[pix_id];
  return stored;
}

// Remember `t` as the last readout time of pixel `pix_id`.
void MicroCal::setLastReadoutTime(T_PixId pix_id, double t) {
  auto& slot = last_readout_time_[pix_id];
  slot = t;
}

/**
 * Clear every active pixel.
 *
 * clearSignal() erases the pixel from active_pixels_. Iterating that same set
 * with a range-for would invalidate the loop iterator (undefined behavior).
 * The iterator is therefore advanced past each element before that element
 * is cleared and erased.
 */
void MicroCal::clear() {
  auto it = active_pixels_.begin();
  while (it != active_pixels_.end()) {
    const auto pix_id = *it;
    ++it;  // step off the node before clearSignal() erases it
    clearSignal(pix_id);
  }
}

// MicroCal keeps no start-time state, so this is intentionally a no-op.
void MicroCal::setStartTime(double /*tstart*/) {
}

/**
 * Factory: build the Sensor implementation matching <detector type="...">.
 *   "ccd", "depfet"  -> ShiftArray
 *   "microcal"       -> MicroCal
 *   "CdZnTe" (NuSTAR), "ssd" -> PixArray
 * Throws SixteException for any other type.
 */
std::unique_ptr<Sensor> createSensor(XMLData& xml_data, const ObsInfo& obs_info, const OutputFiles& outfiles,
                                    bool skip_invalids, size_t xml_id,
                                    std::shared_ptr<const Ebounds> canonical_ebounds,
                                    std::shared_ptr<RmfRegistry> rmf_registry) {
  // TODO: Use enum+switch
  const auto detector_type = xml_data.child("detector").attributeAsString("type");

  const bool is_shift_array = detector_type == "ccd" || detector_type == "depfet";
  if (is_shift_array) {
    return std::make_unique<ShiftArray>(xml_data, obs_info, outfiles, skip_invalids, xml_id, canonical_ebounds);
  }

  if (detector_type == "microcal") {
    return std::make_unique<MicroCal>(xml_data, obs_info, outfiles, skip_invalids, xml_id,
                                      std::move(rmf_registry));
  }

  const bool is_pix_array = detector_type == "CdZnTe" || detector_type == "ssd";
  if (is_pix_array) {
    return std::make_unique<PixArray>(xml_data, obs_info, outfiles, skip_invalids, xml_id, canonical_ebounds);
  }

  throw SixteException("Failed to read detector type");
}

}
