/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "SixteAUXBackground.h"

#include "SixtePhoton.h"
#include "sixte_random.h"

#include "simput.h"

namespace sixte {

/**
 * @brief Loads an AUX background impact list from a FITS table and prepares it for sampling.
 *
 * @param filename           FITS file with columns "primaryid", "edep", "X" and "Y".
 * @param chip_area          Chip area; multiplied by 10000 so the rate refers to cm^2
 *                           (NOTE(review): presumably chip_area is given in m^2 — confirm).
 * @param tel_id             Telescope id stored in the generated photon metainfo.
 * @param chip_id            Chip id stored in the generated photon metainfo.
 * @param aux_rate           Optional rate; if absent, the RATE keyword of the FITS file is used.
 * @param simput_lc_filename Optional SIMPUT light curve for a time-variable background.
 *
 * @throws SixteException if no rate is available, or the rate is zero, negative or NaN.
 */
AUXBackground::AUXBackground(const std::string& filename,
                             double chip_area,
                             size_t tel_id,
                             size_t chip_id,
                             std::optional<double> aux_rate,
                             const std::optional<std::string>& simput_lc_filename)
    : input_fits_file_(filename, FileMode::read), tel_id_(tel_id), chip_id_(chip_id) {

  // An explicitly passed rate takes precedence over the RATE keyword in the FITS file.
  if (aux_rate) {
    aux_rate_ = *aux_rate;
  } else {
    try {
      input_fits_file_.readKey("RATE", aux_rate_);
    } catch (const CCfits::FitsException&) {
      throw SixteException("RATE not given in XML or FITS file. Must be given in either one");
    }
  }

  if (aux_rate_ == 0.) throw SixteException("No AUX Rate given, generation of AUX Background failed");
  // The rate later becomes a Poisson mean; negative (or NaN) values are meaningless there.
  if (!(aux_rate_ > 0.)) throw SixteException("AUX Rate must be positive, generation of AUX Background failed");

  aux_rate_ *= chip_area * 10000; // area needs to be in cm^2 here

  // PHOTONS keyword: non-zero marks a photon background; absent means a particle background.
  int is_pho_bkg;
  if (input_fits_file_.checkKeyExists("PHOTONS", is_pho_bkg)) {
    input_fits_file_.readKey("PHOTONS", is_pho_bkg);
    is_photon_bkg_ = is_pho_bkg;
  } else {
    is_photon_bkg_ = false; // it's a particle bkg
  }

  num_rows_input_file_ = input_fits_file_.getNumRows();

  input_fits_file_.readCol("primaryid", 1, num_rows_input_file_, impact_ids_);
  input_fits_file_.readCol("edep", 1, num_rows_input_file_, impact_energies_);
  input_fits_file_.readCol("X", 1, num_rows_input_file_, impact_x_positions_);
  input_fits_file_.readCol("Y", 1, num_rows_input_file_, impact_y_positions_);

  // Energies are kept in keV internally; convert when the column unit is eV.
  auto energy_col_num = input_fits_file_.getColNum("edep");
  bool energy_input_is_keV = energyInputIsKeV(energy_col_num);

  if (!energy_input_is_keV) { // energy given in eV then
    for (size_t ii = 0; ii < num_rows_input_file_; ii++) {
      impact_energies_[ii] /= 1000;
    }
  }

  event_list_ = fillEventList();

  if (simput_lc_filename) {
    background_rate_form_lc_ = BackgroundRateFormLC(*simput_lc_filename);
    simput_lc_ = true;
  } else {
    background_rate_form_lc_ = std::nullopt;
    simput_lc_ = false;
  }
}


bool AUXBackground::energyInputIsKeV(size_t col_num) {

  std::string tunit_name = "TUNIT" + std::to_string(col_num);
  std::string tunit_val;
  input_fits_file_.readKey(tunit_name, tunit_val);

  if (tunit_val.empty()) {
    printWarning("Energy unit is not given! Assuming it to be in keV...");
    return true;
  }

  if (boost::iequals(tunit_val, "kev")) return true;
  if (boost::iequals(tunit_val,"ev")) {
    printWarning("Energy is given in eV. Background output will be in keV!");
    return false;
  }

  throw SixteException("Energy unit must be keV or eV! Generation of AUX background failed");
}


std::vector<size_t> AUXBackground::fillEventList() {
  std::vector<size_t> eventlist;

  for(size_t ii=0; ii<num_rows_input_file_-1; ii++) {
    if (impact_ids_[ii] != impact_ids_[ii+1]) {
      num_events_++;
      eventlist.emplace_back(ii + 1);
    }
  }

  if (eventlist.empty()) {
    eventlist.resize(1);
    eventlist[0] = 0;
  }

  return eventlist;
}


/**
 * @brief Expected number of background events within a time span.
 *
 * @param interval Length of the time span.
 * @return The area-scaled AUX rate multiplied by @p interval.
 */
double AUXBackground::calcEventRate (double interval) const {
  const double expected_events = interval * aux_rate_;
  return expected_events;
}

/**
 * @brief Discards the impacts generated for the previous interval and resets the counters.
 */
void AUXBackground::clearBackgroundList() {
  // Parallel per-impact output vectors.
  out_impact_times_.clear();
  out_impact_energies_.clear();
  out_impact_ids_.clear();
  out_impact_x_pos_.clear();
  out_impact_y_pos_.clear();

  // Counters describing the (now empty) output lists.
  out_num_impacts_ = 0;
  out_num_events_ = 0;
}


/**
 * @brief Generates the background impacts for the time span [time_start, time_end).
 *
 * The number of events is Poisson-distributed with mean calcEventRate(interval). Each event
 * gets a uniformly drawn time and a uniformly drawn event from the input list. All impacts
 * (sub-events) of that input event are appended with the same time. The output vectors
 * out_impact_* are parallel and sorted by time.
 *
 * @throws SixteException if time_end <= time_start, or if a light curve is configured
 *         (time-variable background is not implemented yet).
 */
void AUXBackground::bkgGetBackgroundList(double time_start, double time_end) {

  double interval = time_end - time_start;

  if (interval <= 0) throw SixteException("Invalid interval for background generation specified");

  clearBackgroundList();

  double events_per_interval = calcEventRate(interval);

  // If we use a rate function we multiply the output poisson events by the normalized rate(s) for this interval.
  // Otherwise, we just take the unchanged result of the poisson random function.
  if (simput_lc_) {
    throw SixteException("Time variable aux background not yet implemented");
    // TODO: test and adapt light curve calculation
    auto background_rates = background_rate_form_lc_->bkgGetRates(time_start, time_end);
    size_t num_elements = background_rates.size();

    for(size_t ii = 0; ii < num_elements; ii++) {
      out_num_events_ += getPoissonGSLRngNumber(events_per_interval * background_rates[ii]);
    }

  } else {
    out_num_events_ = getPoissonGSLRngNumber(events_per_interval);
  }

  if (out_num_events_ <= 0 || event_list_.empty()) return;

  // Draw and sort all event times up front. Sorting only the output time vector afterwards
  // would detach times from the energies/ids/positions stored in the parallel vectors.
  std::vector<double> event_times;
  event_times.reserve(out_num_events_);
  for (size_t ii = 0; ii < out_num_events_; ii++) {
    event_times.emplace_back(getFlatGSLRngNumber(time_start, time_end));
  }
  std::sort(event_times.begin(), event_times.end());

  const size_t num_listed_events = event_list_.size();

  for (const double time : event_times) {
    // Uniform choice among all listed events; the clamp guards against a draw of exactly 1.
    auto event_index = static_cast<size_t>(static_cast<double>(num_listed_events) * getFlatGSLRngNumber(0, 1));
    event_index = std::min(event_index, num_listed_events - 1);

    // Append the event's first row and every following row with the same primary id
    // (its sub-events); all of them share the event time.
    for (size_t current_impact = event_list_[event_index]; ; ++current_impact) {
      out_impact_times_.emplace_back(time);
      out_impact_energies_.emplace_back(impact_energies_[current_impact]);
      out_impact_ids_.emplace_back(impact_ids_[current_impact]);
      out_impact_x_pos_.emplace_back(impact_x_positions_[current_impact]);
      out_impact_y_pos_.emplace_back(impact_y_positions_[current_impact]);

      out_num_impacts_++;

      const size_t next_impact = current_impact + 1;
      if (next_impact >= num_rows_input_file_ ||
          impact_ids_[next_impact] != impact_ids_[current_impact]) {
        break;
      }
    }
  }
}


/**
 * @brief Returns the next AUX background carrier, generating new background as needed.
 *
 * Buffered carriers are returned first. When the buffer is empty, background is generated
 * for the next interval [time_start_, min(time_start_ + attitude_dt, dt.second)), converted
 * to carriers via the detector and buffered.
 *
 * @param detector    Detector that absorbs the generated photons / energy depositions.
 * @param dt          Overall time range (start, end) in which background may be generated.
 * @param attitude_dt Length of one generation step.
 * @return The next carrier, or nullptr once the time range is exhausted.
 */
CarrierPtr AUXBackground::getNextAUXBkgCarrier (Detector& detector,
                                                std::pair<double, double> dt,
                                                double attitude_dt) {

  for (;;) {
    if (!aux_photon_carrier_buffer_.empty()) return pop_next_carrier(aux_photon_carrier_buffer_);

    if (time_start_ < dt.first) time_start_ = dt.first;
    double time_end = std::min(time_start_ + attitude_dt, dt.second);

    // The buffer is known to be empty here, so nothing is left to hand out.
    if (time_start_ < last_time_interval_end_ || time_start_ >= time_end) return nullptr;

    bkgGetBackgroundList(time_start_, time_end);

    time_start_ += attitude_dt;
    last_time_interval_end_ = time_end;

    double cos_rota = cos(detector.absorberGeometry()->rota_);
    double sin_rota = sin(detector.absorberGeometry()->rota_);

    for (size_t ii = 0; ii < out_num_impacts_; ++ii) {

      // Input positions are scaled by 0.001 (NOTE(review): presumably mm -> m; confirm) and then
      // rotated by the absorber rotation angle. This must be a proper rotation: the previous
      // "y = x*sin - y*cos" made the matrix singular at 45 degrees.
      double x_in = out_impact_x_pos_[ii] * 0.001;
      double y_in = out_impact_y_pos_[ii] * 0.001;
      double x_pos = x_in * cos_rota - y_in * sin_rota;
      double y_pos = x_in * sin_rota + y_in * cos_rota;

      SixtePoint position(x_pos, y_pos, 0.);

      // NOTE(review): PhId::mxs_id is used as the photon id for AUX background; confirm intended.
      PhotonMetainfo photon_metainfo(PhId::mxs_id, SrcType::aux_bkg, tel_id_, chip_id_);

      if (is_photon_bkg_) {
        // Photon background: let the detector absorb a photon at this position.
        SixtePhoton sixte_photon(out_impact_times_[ii],
                                 out_impact_energies_[ii],
                                 position,
                                 photon_metainfo);

        auto aux_photon_carriers = detector.absorb(sixte_photon);
        transferCarriers(aux_photon_carriers, aux_photon_carrier_buffer_);

      } else {
        // Particle background: deposit the energy directly.
        EnergyDepositions energy_depositions;
        energy_depositions.emplace_back(out_impact_energies_[ii], position, photon_metainfo, out_impact_times_[ii]);

        auto aux_carriers = detector.absorbEnergy(energy_depositions);
        transferCarriers(aux_carriers, aux_photon_carrier_buffer_);
      }
    }
  }
}


/**
 * @brief Loads a time-resolved SIMPUT light curve used to modulate the background rate.
 *
 * Copies the TIME column into times_ and the FLUX column into rates_, and stores the number
 * of entries and the light-curve time zero.
 *
 * @param simput_filename SIMPUT file containing the light curve.
 * @throws SixteException if the filename is empty, the light curve cannot be loaded, or
 *         it has no TIME column (periodic light curves are not supported) or no FLUX data.
 */
BackgroundRateFormLC::BackgroundRateFormLC(const std::string& simput_filename) {
  if (simput_filename.empty()) throw SixteException("no SIMPUT filename specified for rate function!");

  int status = EXIT_SUCCESS;
  SimputLC* rate_lc = loadSimputLC(simput_filename.c_str(), &status);

  if (status != EXIT_SUCCESS || rate_lc == nullptr) {
    if (rate_lc != nullptr) freeSimputLC(&rate_lc);
    throw SixteException("Failed to load light curve for background variation from SIMPUT file");
  }

  // Make sure that the light curve is given as a function of time.
  // Periodic light curves cannot be processed here.
  // The light curve is freed before every throw to avoid leaking it.
  if (rate_lc->time == nullptr) {
    freeSimputLC(&rate_lc);
    throw SixteException("Light curve for background variation does not contain TIME column");
  }
  if (rate_lc->flux == nullptr) {
    freeSimputLC(&rate_lc);
    throw SixteException("Light curve for background variation does not contain FLUX column");
  }

  num_elements_ = rate_lc->nentries;
  start_time_ = rate_lc->timezero;

  times_.assign(rate_lc->time, rate_lc->time + rate_lc->nentries);
  // The flux values are the rates (they were previously appended to times_ by mistake).
  rates_.assign(rate_lc->flux, rate_lc->flux + rate_lc->nentries);

  freeSimputLC(&rate_lc);
}


// Splits [time_start, time_end) into pieces following the background light curve and returns
// one weight per piece. Each weight is already scaled by (piece length / full interval length).
// In this file the caller (simput_lc_ branch of AUXBackground::bkgGetBackgroundList) multiplies
// every weight by the expected number of events for the whole interval.
//
// NOTE(review): This routine is unfinished. In this file it is only reached after
// bkgGetBackgroundList throws "Time variable aux background not yet implemented", so it is
// currently unused. Concerns to resolve before enabling it:
//  - The *_ "state" variables below are locals reset on every call, so the position within the
//    light curve is not carried across consecutive intervals. They look like they were meant to
//    be members. time_start is only used for the interval length and the start_time_ check, and
//    interval_sum_ always starts at 0 regardless of time_start.
//  - end_fraction is computed from rates_[current_rate_ + 1] minus a time value; a light-curve
//    time (times_[current_time_ + 1]) looks intended there.
//  - Assumes at least two light-curve points: rates_[current_rate_ + 1] and
//    times_[num_elements_ - 1] are read without bounds checks.
//  - `interval == 0` is an exact floating-point comparison.
//  - current_rates may still be empty when the final "EOF" branch indexes its last element.
std::vector<double> BackgroundRateFormLC::bkgGetRates(double time_start, double time_end) {
  double interval_sum_{0};
  size_t current_time_{0};
  size_t current_rate_{0};
  double rate_interval_sum_{0};


  // full_interval stays fixed (used for normalization); interval is consumed piece by piece.
  auto full_interval = time_end - time_start;
  double interval = time_end - time_start;

  // Slope of the first light-curve segment (between points 0 and 1).
  double rate_current_slope = getCurrentLCSlope(current_time_, current_rate_);

  std::vector<double> current_rates;

  // Does the interval start before the light curve?
  if (time_start < start_time_) {
    // if yes we assume the interval fraction in front of the lightcurve as rate 1 and process the rest.
    if ((interval_sum_ + interval) >= start_time_) {
      current_rates.emplace_back((start_time_ - interval_sum_) / full_interval);
      interval -= (start_time_ - interval_sum_);
      interval_sum_ = start_time_;
    } else {
      // if not we assume rate 1 and return.
      current_rates.emplace_back( 1);
      interval_sum_ += interval;
      return current_rates;
    }
  }

  // Add interpolated lightcurve data points to rate array until we've covered the interval.
  // All data will be normalized to the respective length fraction of the interval as the
  // rate output is multiplied later with the total number of events directly.
  while ((interval_sum_ >= start_time_)
      && (interval > 0)
      && (times_[current_time_] < times_[num_elements_ - 1])) {

    auto start_fraction = times_[0] + rate_interval_sum_ - times_[current_time_];
    // NOTE(review): subtracts a time from a rate value; times_[current_time_ + 1] looks intended.
    auto end_fraction = rates_[current_rate_ + 1] - (times_[0] + rate_interval_sum_);
    interval -= end_fraction;
    rate_interval_sum_ += end_fraction;
    interval_sum_ += end_fraction;

    // if the interval is larger than the current bin we jump to the next and update the slope
    if(interval > 0) {
      // Trapezoid average of the rate over the remainder of the current bin, weighted by its length.
      current_rates.emplace_back((0.5 * ((rates_[current_rate_] + rate_current_slope * start_fraction) +
              rates_[current_rate_ + 1])) * end_fraction / full_interval);
      current_rate_++;
      current_time_++;
      rate_current_slope = (rates_[current_rate_ + 1] - rates_[current_rate_])/
          (times_[current_time_ + 1] - times_[current_time_]);
    } else {
      // The interval ends inside the current bin (interval <= 0 is the overshoot past its end).
      current_rates.emplace_back((0.5 * ((rates_[current_rate_] + rate_current_slope * start_fraction) +
                rates_[current_rate_ + 1] + rate_current_slope * interval)
                * (times_[current_time_ + 1] - (times_[current_time_] + start_fraction) + interval)
                / full_interval));
    }
  }

  // if we hit the end of the last bin we increase the pointers and recalculate the slope
  if(interval == 0) {
    current_rate_++;
    current_time_++;
    rate_current_slope = getCurrentLCSlope(current_time_, current_rate_);
  } else {
    rate_interval_sum_ += interval;
    interval_sum_ += interval;
  }

  // if we are at the EOF but still have some interval left we set the rate to 1.
  // (Rate 1 is stored as the normalized fraction interval / full_interval; the last weight is overwritten.)
  if ((interval > 0) && (times_[current_time_] >= times_[num_elements_ - 1])) {
    current_rates[current_rates.size() - 1] = interval / full_interval;
  }

  return current_rates;
}


/**
 * @brief Slope of the light curve between point @p time / @p rate and the following point.
 *
 * @param time Index into times_ of the segment start.
 * @param rate Index into rates_ of the segment start.
 * @return (rates_[rate+1] - rates_[rate]) / (times_[time+1] - times_[time]), or 0 when there is
 *         no following point (flat continuation instead of reading past the end of the data).
 */
double BackgroundRateFormLC::getCurrentLCSlope(size_t time, size_t rate) const {
  // bkgGetRates recomputes the slope after stepping past the last bin; guard that read.
  if (rate + 1 >= rates_.size() || time + 1 >= times_.size()) return 0.;

  return (rates_[rate + 1] - rates_[rate]) / (times_[time + 1] - times_[time]);
}

} // sixte
