/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.

   Copyright 2023 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "sixte_arfgen.hpp"
#include "Parameters.h"
#include "PhotonProjection.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <random>
#include <string>
#include <utility>
#include <vector>

#define TOOLSUB sixte_arfgen_main
#include "sixt_main.c"

namespace sixte {

/**
 * Read all sixte_arfgen parameters through the queryParameter* helpers.
 *
 * Optional file parameters (attitude, GTI and SIMPUT file) that are given
 * as a "none" placeholder are normalized to an empty string via isNone(),
 * so later code only has to test for emptiness.
 *
 * @throws SixteException if n_photons or sampling_factor is smaller than 1:
 *         a zero photon count makes every correction factor a division by
 *         zero, and a zero sampling factor never advances the ARF-bin loop.
 */
Parameters::Parameters() {
  ARFCorr = queryParameterString("ARFCorr");
  XMLFile = queryParameterString("XMLFile");

  Seed = queryParameterInt("Seed");
  clobber = queryParameterBool("clobber");

  // Treat the "none" placeholder the same way for every optional file.
  attitude_file = queryParameterString("Attitude");
  if (isNone(attitude_file)) {
    attitude_file = "";
  }

  gti_file = queryParameterString("GTIfile");
  if (isNone(gti_file)) {
    gti_file = "";
  }

  SourceRA = queryParameterDouble("SourceRA");
  SourceDec = queryParameterDouble("SourceDec");
  simputfile = queryParameterString("Simput");
  if (isNone(simputfile)) {
    simputfile = "";
  }

  RefRA = queryParameterDouble("RefRA");
  RefDec = queryParameterDouble("RefDec");
  Projection = queryParameterString("Projection");

  // Fixed pixel grid used to evaluate the region file: reference pixel at
  // (0, 0), 0.05 arcsec pixels (CDELT in degrees), negative RA increment.
  crpix1 = 0.0;
  crpix2 = 0.0;
  cdelt1 = -0.05 / 3600.;
  cdelt2 = 0.05 / 3600.;

  regfilter = queryParameterString("regfile");

  // Validate the raw values before storing them, so a negative input is
  // rejected instead of wrapping around in an unsigned member.
  const auto n_photons_par = queryParameterInt("n_photons");
  if (n_photons_par < 1) {
    throw SixteException("Parameter n_photons must be at least 1");
  }
  n_photons = n_photons_par;

  const auto sampling_par = queryParameterInt("sampling_factor");
  if (sampling_par < 1) {
    throw SixteException("Parameter sampling_factor must be at least 1");
  }
  arf_sampling_factor = sampling_par;

  show_progress = queryParameterBool("progressbar");
}

std::vector<double> calcPhotonTimes(unsigned int n_photons, GTICollection& gtis) {
  auto bins = gtis.getAllGtiBins();

  auto totexp = gtis.totalExposure();

  // calculate number of photons per bin
  std::vector<unsigned int> phots_per_bin(bins.size(), 0);

  std::transform(bins.begin(), bins.end(), phots_per_bin.begin(),
      [n_photons, totexp](std::pair<double,double> p){
      return (p.second - p.first)/totexp * n_photons;
      });

  // make sure we actually have the full number of photons
  while (std::accumulate(phots_per_bin.begin(), phots_per_bin.end(), 0U) < n_photons) {
    // if not, weigh the highest bin more
    *std::max_element(phots_per_bin.begin(), phots_per_bin.end()) += 1;
  }

  std::vector<double> times;

  // linearly space photons in each GTI bin
  for (unsigned int i_bin=0; i_bin<bins.size(); i_bin++) {
    if (phots_per_bin[i_bin]==0) {continue;}

    double dt = (bins[i_bin].second - bins[i_bin].first) / phots_per_bin[i_bin];

    for (unsigned int i_phot=0; i_phot<phots_per_bin[i_bin]; i_phot++) {
      times.push_back(i_phot * dt + bins[i_bin].first);
    }
  }

  return times;
}

/**
 * Set up the ARF-correction machinery for the instrument described by the
 * XML files.
 *
 * Loads the GTIs, attitude, observation info and telescope, plus one
 * geometry pair per XML. It then selects every arf_sampling_factor-th ARF
 * energy bin (always including the last one, so the later interpolation
 * covers the full energy range) and initializes its correction factor to 1.
 *
 * @param arf        C ARF whose energy grid is sampled
 * @param xml_datas  instrument XMLs; the first one defines telescope and
 *                   focal length
 * @param pars       tool parameters
 * @throws SixteException if sampling_factor < 1 or the ARF has no bins
 */
ARFCorr::ARFCorr(const struct ARF *const arf, std::vector<sixte::XMLData>& xml_datas, const Parameters &pars)
 : gtis(pars.gti_file, pars.obstime),
   att(pars.obspointing, pars.obstime, gtis.totalExposure()/pars.n_photons),
   obsinfo(xml_datas[0], pars.obspointing, gtis),
   tel("", false, xml_datas[0], obsinfo)
{

  focal_length = xml_datas[0].child("telescope").child("focallength").attributeAsDouble("value");

  // Assign explicitly in both cases instead of only ever setting it to true.
  is_theseus_ = xml_datas[0].root().attributeAsString("telescop") == "THESEUS" &&
                xml_datas[0].root().attributeAsString("instrume") == "SXI";

  // load all the geometries
  geometries.reserve(xml_datas.size());
  for (auto & xml_data : xml_datas) {
    GeoPair p;
    p.abs = std::make_unique<Geometry>(xml_data);
    p.arr = createGeometry(xml_data);

    geometries.emplace_back(std::move(p));
  }

  // A zero step would never terminate the sampling loop below.
  if (pars.arf_sampling_factor < 1) {
    throw SixteException("ARF sampling factor must be at least 1");
  }
  if (arf->NumberEnergyBins <= 0) {
    throw SixteException("ARF contains no energy bins");
  }

  const size_t total_bins = arf->NumberEnergyBins;
  const auto step = static_cast<size_t>(pars.arf_sampling_factor);
  for (size_t ii = 0; ii < total_bins; ii += step) {
    bin_idxs.push_back(ii);
  }

  // Make sure the final bin is sampled too, for the interpolation.
  // (bin_idxs is non-empty here because total_bins > 0.)
  if (bin_idxs.back() != total_bins - 1) {
    bin_idxs.push_back(total_bins - 1);
  }

  corr_fac.reserve(bin_idxs.size());
  corr_energ.reserve(bin_idxs.size());
  for (auto ii: bin_idxs) {
    corr_fac.push_back(1.);
    corr_energ.push_back(0.5 * (arf->LowEnergy[ii] + arf->HighEnergy[ii]));
  }
}

void ARFCorr::applyCorrection(const std::string& arfpath, const struct ARF *const arf,
    const Parameters &pars) {

    double exposure = gtis.totalExposure();

    // build interpolator
    std::vector< std::vector<double>::iterator > grid_iter_list;
    grid_iter_list.push_back(corr_energ.begin());

    std::array<size_t , 1> grid_sizes{};
    grid_sizes[0] = corr_energ.size();

    linterp::InterpMultilinear<1,double> interp(
        grid_iter_list.begin(), grid_sizes.begin(),
        corr_fac.data(), corr_fac.data()+grid_sizes[0]);

    // load original ARF file
    CCfits::FITS infile(arfpath);
    // copy it
    std::string outfile = (pars.clobber ? "!" : "") + pars.ARFCorr;
    CCfits::FITS corrfile(outfile, infile);
    corrfile.copy(infile.extension("SPECRESP"));

    // calculate new SPECRESP column
    std::vector<double> e_cen(arf->NumberEnergyBins);
    for (size_t ii=0; ii<e_cen.size(); ii++) {
      e_cen[ii] = 0.5 * (arf->LowEnergy[ii] + arf->HighEnergy[ii]);
    }
    std::vector<double> specresp_corr(arf->NumberEnergyBins);

    std::vector< std::vector<double>::iterator > interp_x_list;
    interp_x_list.push_back(e_cen.begin());

    interp.interp_vec((int)arf->NumberEnergyBins,
        interp_x_list.begin(), interp_x_list.end(),
        specresp_corr.begin());

    for (int ii=0; ii<arf->NumberEnergyBins; ii++) {
      specresp_corr[ii] *= arf->EffArea[ii];
    }

    // write
    corrfile.extension("SPECRESP").column("SPECRESP").write(specresp_corr, 1);
    corrfile.extension("SPECRESP").addKey("EXPOSURE", exposure, "Exposure time");

    // save input parameters
    int status = EXIT_SUCCESS;
    HDpar_stamp(corrfile.fitsPointer(), 1, &status);
    sixte::checkStatusThrow(status, "Failed doing par_stamp");
}

/**
 * Image one photon and check whether its reconstructed sky position lies
 * inside the extraction region.
 *
 * The photon is imaged by the telescope. The first detector on which it
 * hits a pixel determines the reconstructed sky position: via the THESEUS
 * SXI projection if applicable, otherwise via a random position in the hit
 * pixel projected with detToSky().
 *
 * @param phot    photon to test
 * @param region  region evaluated in pixel coordinates of @p wcspr
 * @param wcspr   WCS mapping sky coordinates (degrees) to region pixels
 * @return true if the photon hits a detector pixel and its projected
 *         position is inside @p region
 * @throws SixteException if the WCS conversion fails
 */
bool ARFCorr::testPhoton(SixtePhoton& phot, SAORegion* region, wcsprm& wcspr) {
  // Photon imaging
  auto imaged = tel.doImaging(att, phot);
  if (!imaged.has_value()) return false;

  std::optional<std::pair<double,double>> skypos;

  // Loop over detectors; the first one with a hit pixel wins.
  for (GeoPair& geo: geometries) {
    auto imgpos = geo.abs->transformFocalToDet(imaged->detector_position().value());

    auto hit_pix = geo.arr->getPixId(Point_2(imgpos.sixte_point().hx(), imgpos.sixte_point().hy()));
    if (!hit_pix.has_value()) continue; // check next detector

    if (is_theseus_) {
      skypos = PhotonProjection::projectRectTheseusEvent(
          *imaged->photon_metainfo().origin_, att.getTelescopeAxes(phot.time()));
    } else {
      // project a random position inside the hit pixel onto the sky
      auto pixpos = geo.arr->getRandPosInPixel(hit_pix.value());

      skypos = sixte::detToSky(pixpos.first, pixpos.second,
          focal_length, phot.time(), att, *geo.abs);
    }

    // no need to test further chips
    break;
  }

  if (!skypos.has_value()) return false; // photon landed on no detector

  // Check whether the projected position is inside the region.
  // 1. convert the sky position (radians) to world coordinates in degrees
  double world[2] = {
      skypos->first * 180. / M_PI,
      skypos->second * 180. / M_PI
  };

  double imgcrd[2] = {};
  double pixcrd[2] = {};
  double phi = 0.;
  double theta = 0.;
  int wcsstat = 0;

  // wcss2p reports failures through its return value (e.g. an invalid WCS
  // setup) as well as through the per-coordinate stat array; check both.
  const int wcsret = wcss2p(&wcspr, 1, 2, world, &phi, &theta, imgcrd, pixcrd, &wcsstat);
  if (0 != wcsret || 0 != wcsstat) {
    const int code = (0 != wcsstat) ? wcsstat : wcsret;
    std::string message = "WCS coordinate conversion failed (RA=" + std::to_string(world[0])
                          + ", Dec=" + std::to_string(world[1])
                          + ", error code " + std::to_string(code)
                          + ")";
    throw sixte::SixteException(message);
  }

  // 2. check if the reprojected impact is inside the region
  return fits_in_region(pixcrd[0], pixcrd[1], region);
}

void ARFCorr::calcCorrection(
    const std::pair<double,double>& srcpos, 
    const NewArf& arf,
    unsigned int num_photons,
    SAORegion* region, wcsprm& wcspr, const Parameters &pars)
{
  sixte::Progressbar progress(pars.show_progress, (double)bin_idxs.size());

  for (size_t i_chan = 0; i_chan <bin_idxs.size(); i_chan++) {
    double e_lo = arf.c_arf()->LowEnergy[bin_idxs[i_chan]];
    double e_hi = arf.c_arf()->HighEnergy[bin_idxs[i_chan]];

    unsigned int passed_photons = 0;

    auto photon_times = calcPhotonTimes(num_photons, gtis);

    for (size_t ph_idx = 0; ph_idx < pars.n_photons; ph_idx++) {
      // Generate photon within specified region and energy range
      double time = photon_times[ph_idx];
      double en = e_lo + sixte::getUniformRandomNumber() * (e_hi - e_lo);
      sixte::SixtePhoton phot(time, en, srcpos, sixte::PhotonMetainfo(1, 1));

      if (testPhoton(phot, region, wcspr)) {
        passed_photons += 1;
      }
    }

    corr_fac[i_chan] = (double) passed_photons / pars.n_photons;
    progress.update((double)i_chan);
  }

  progress.finish();

}

void ARFCorr::calcCorrection(
    const std::string& catalog_file, 
    const NewArf& arf,
    unsigned int num_photons,
    SAORegion* region, wcsprm& wcspr, const Parameters &pars)
{
  int status = EXIT_SUCCESS;
  NewSourceCatalog cat(catalog_file, arf);
  checkStatusThrow(status, "Failed Loading SIMPUT catalog");

  sixte::Progressbar progress(pars.show_progress, (double)bin_idxs.size());

  auto photon_times = calcPhotonTimes(num_photons, gtis);

  std::mt19937 rng(pars.Seed);

  for (size_t i_chan = 0; i_chan <bin_idxs.size(); i_chan++) {
    double e_lo = arf.c_arf()->LowEnergy[bin_idxs[i_chan]];
    double e_hi = arf.c_arf()->HighEnergy[bin_idxs[i_chan]];

    double* relfluxes = getSimputCtlgRelFluxes(cat.simput(),
        e_lo, e_hi,
        pars.obstime.tstart, pars.obstime.mjdref,
        &status);

    checkStatusThrow(status, "Failed generating relative SIMPUT fluxes");

    std::vector<unsigned int> photons_per_source(cat.simput()->nentries, 0);
    for (size_t ii=0; ii<photons_per_source.size(); ii++) {
      photons_per_source[ii] = pars.n_photons * (unsigned int) relfluxes[ii];
      //TODO: properly make an unsigned int from relfluxes[ii]
    }
    free(relfluxes);

    // what if all relative fluxes are zero?
    if (std::accumulate(photons_per_source.begin(), photons_per_source.end(), 0) == 0) {
      std::stringstream ss;
      std::cout << "Warning: All sources have zero flux in channel " << bin_idxs[i_chan]
        << " (" << e_lo << " to " << e_hi << " keV).\n";
      std::cout << "         Assuming even photon distribution between sources" << std::endl;
      for (size_t ii=0; ii<photons_per_source.size(); ii++) {
        photons_per_source[ii] = pars.n_photons / photons_per_source.size();
      }
    }

    // make sure we actually have the full number of photons
    while (std::accumulate(photons_per_source.begin(), photons_per_source.end(), 0U) < num_photons) {
      // if not, weigh the brightest source more
      *std::max_element(photons_per_source.begin(), photons_per_source.end()) += 1;
    }

    unsigned int total_photons = std::accumulate(
        photons_per_source.begin(), photons_per_source.end(), 0);

    unsigned int passed_photons = 0;
    unsigned int ph_idx = 0;

    // randomize photon times if we have multiple sources
    if(photons_per_source[0] != total_photons) {
      std::shuffle(std::begin(photon_times), std::end(photon_times), rng);
    }

    for (long i_src = 0; i_src < cat.simput()->nentries; i_src++) {
      SimputSrc* src = getSimputSrc(cat.simput(), i_src+1, &status);
      checkStatusThrow(status, "Failed getting SIMPUT source");

      for (unsigned int i_ph = 0; i_ph < photons_per_source[i_src]; i_ph++) {
        // Generate photon within specified region and energy range
        double time = photon_times[ph_idx];
        ph_idx++;

        double en = e_lo + sixte::getUniformRandomNumber() * (e_hi - e_lo);

        double ra, dec;
        getSimputPhotonCoord(cat.simput(), src,
            pars.obstime.tstart, pars.obstime.mjdref,
            e_lo, e_hi,
            &ra, &dec,
            &status);
        checkStatusThrow(status, "Failed photon coordinates");

        std::pair<double,double> srcpos = std::make_pair(ra, dec);

        sixte::SixtePhoton phot(time, en, srcpos, sixte::PhotonMetainfo(1, 1));

        if (testPhoton(phot, region, wcspr)) {
          passed_photons += 1;
        }
      }
    }

    corr_fac[i_chan] = (double) passed_photons / total_photons;
    progress.update((double)i_chan);
  }

  progress.finish();

  checkStatusThrow(status, "Failed freeing source catalog");
}

}

/**
 * Build the two WCS descriptions of the pixel grid used to evaluate the
 * region file.
 *
 * The CFITSIO WCSdata is needed to read the region file. The wcslib wcsprm
 * is used to project photon positions. Both describe the same grid:
 * reference pixel (crpix1, crpix2), reference value (RefRA, RefDec),
 * increments (cdelt1, cdelt2) and the projection code pars.Projection.
 *
 * @param pars  tool parameters
 * @return pair of (WCSdata, wcsprm); the wcsprm was initialized by wcsini
 * @throws SixteException if the projection code does not fit the fixed-size
 *         CTYPE buffers, or wcsini fails
 */
std::pair<WCSdata, wcsprm> getWCSStructs(const sixte::Parameters& pars) {
  // TODO alternatively load from image
  const std::string wcsdat_type = "-" + pars.Projection;
  const std::string ra_ctype = "RA---" + pars.Projection;
  const std::string dec_ctype = "DEC--" + pars.Projection;

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
  // Zero-initialize all remaining members and mark the WCS as present.
  WCSdata wcsdat = { .exists=1 };
  struct wcsprm wcspr = { .flag=-1 };  // flag=-1 required before first wcsini

  // The type/ctype members are fixed-size char buffers: reject projection
  // codes that would overflow them instead of corrupting memory in strcpy.
  if (wcsdat_type.size() >= sizeof(wcsdat.type) ||
      ra_ctype.size() >= sizeof(wcspr.ctype[0]) ||
      dec_ctype.size() >= sizeof(wcspr.ctype[0])) {
    throw sixte::SixteException("Projection '" + pars.Projection + "' is too long for a WCS type");
  }

  wcsdat.xrefpix = pars.crpix1;
  wcsdat.yrefpix = pars.crpix2;
  wcsdat.xrefval = pars.RefRA;
  wcsdat.yrefval = pars.RefDec;
  wcsdat.xinc = pars.cdelt1;
  wcsdat.yinc = pars.cdelt2;
  wcsdat.rot = 0.;
  strcpy(wcsdat.type, wcsdat_type.c_str());

  if (0 != wcsini(1, 2, &wcspr)) {
    throw sixte::SixteException("Failed Initializing WCS");
  }
  wcspr.crpix[0] = pars.crpix1;
  wcspr.crpix[1] = pars.crpix2;
  wcspr.crval[0] = pars.RefRA;
  wcspr.crval[1] = pars.RefDec;
  wcspr.cdelt[0] = pars.cdelt1;
  wcspr.cdelt[1] = pars.cdelt2;
  strcpy(wcspr.cunit[0], "deg");
  strcpy(wcspr.cunit[1], "deg");
  strcpy(wcspr.ctype[0], ra_ctype.c_str());
  strcpy(wcspr.ctype[1], dec_ctype.c_str());
#pragma GCC diagnostic pop

  return std::make_pair(wcsdat, wcspr);
}

int sixte_arfgen_main() {
  // Register HEAdas task
  set_toolname("sixte_arfgen");
  set_toolversion("1.0");

  try {
    sixte::Parameters pars;

    sixte::initSixteRng(pars.Seed);

    // set up the XMLs
    sixte::XMLDetectorFiles files = sixte::loadXMLFiles();

    if (files.num_chips.size() != 1) {
      throw(sixte::SixteException("sixte_arfgen only supports a single XML!"));
    }

    std::string xml_path = sixte::loadXMLPath(); 

    std::vector<sixte::XMLData> xml_datas;
    xml_datas.reserve(files.xml_documents.size());
    
    for (auto & xml_document : files.xml_documents) {
      xml_datas.emplace_back(xml_document, xml_path);
    }

    // load needed stuff from the XML
    std::vector<std::string> arffile;
    arffile.push_back(xml_datas[0].dirname() + xml_datas[0].child("telescope").child("arf").attributeAsString("filename"));
    sixte::NewArf cpp_arf(arffile);
    // TODO directly use the NewArf
    const ARF* arf = cpp_arf.c_arf();

    // get the two WCS data structs
    auto [wcsdat, wcspr] = getWCSStructs(pars);

    // Load the region file
    SAORegion *region;
    int status = EXIT_SUCCESS;
    fits_read_rgnfile(pars.regfilter.c_str(), &wcsdat, &region, &status);
    sixte::checkStatusThrow(status, "Failed reading region from file");

    // set up ARFCorr
    sixte::ARFCorr corr(arf, xml_datas, pars);

    healog(3) << "Calculating ARF correction..." << std::endl;

    if (pars.simputfile.empty()) {
      std::pair<double,double> srcpos = std::make_pair(
          pars.SourceRA * M_PI / 180.,
          pars.SourceDec * M_PI / 180.
          );

      corr.calcCorrection(
          srcpos, cpp_arf,
          pars.n_photons,
          region, wcspr, pars);
    } else {
      corr.calcCorrection(
          pars.simputfile, cpp_arf,
          pars.n_photons,
          region, wcspr, pars);
    }

    healog(3) << "Applying ARF correction..." << std::endl;

    // apply correction to the ARF
    std::string inarfpath = xml_datas[0].dirname() + xml_datas[0].child("telescope").child("arf").attributeAsString("filename");
    corr.applyCorrection(inarfpath, arf, pars);

    // cleanup
    fits_free_region(region);
    sixt_destroy_rng();

    healog(3) << "Finished successfully!" << std::endl;

    return EXIT_SUCCESS;

  } catch (const std::exception& e) {
    sixte::printError(e.what());
    sixt_destroy_rng();
    return EXIT_FAILURE;
  }
}



