/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2023 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "makespec.hpp"

#include <exception>

#include "Parameters.h"
#include "sixt.h"
#include "sixte_random.h"
#include "SixteCCFits.h"

#define TOOLSUB makespec_main
#include "sixt_main.c"


/**
 * Query all makespec tool parameters and collect them in a ClinePars.
 *
 * The parameters are queried in a fixed order: the string parameters
 * first, then Seed, clobber and usepha. Note that the parameter names
 * "ANCRfile" and "RESPfile" are stored in the ARFfile and RMFfile members.
 *
 * @return the filled parameter struct.
 */
ClinePars read_par() {

  ClinePars params;

  // --- file names, paths and filter expressions ---
  params.EvtFile     = sixte::queryParameterString("EvtFile");
  params.EventFilter = sixte::queryParameterString("EventFilter");
  params.Spectrum    = sixte::queryParameterString("Spectrum");
  params.RSPPath     = sixte::queryParameterString("RSPPath");
  params.ARFfile     = sixte::queryParameterString("ANCRfile");
  params.RMFfile     = sixte::queryParameterString("RESPfile");
  params.GTIfile     = sixte::queryParameterString("GTIfile");
  params.regfile     = sixte::queryParameterString("regfile");

  // --- numeric and boolean switches ---
  params.Seed    = sixte::queryParameterInt("Seed");
  params.clobber = sixte::queryParameterBool("clobber");
  params.usepha  = sixte::queryParameterBool("usepha");

  return params;
}

/**
 * Fill a HeaderVals struct from the header of an event-file extension.
 *
 * Reads the mandatory keywords TELESCOP, INSTRUME, FILTER, DATE-OBS,
 * TIME-OBS, DATE-END, TIME-END, ANCRFILE and RESPFILE, plus the optional
 * keyword PHA2PI. The exposure is taken from the total exposure of the
 * given GTIs, not from a header keyword.
 *
 * @param vals  Output structure; the fields listed above are overwritten.
 *              detchans, gtifile and regfile are not touched here.
 * @param hdu   Extension whose header is read (the caller passes EVENTS).
 * @param gtis  GTIs whose total exposure is stored in vals.exposure.
 * @throws sixte::SixteException if any mandatory keyword is missing.
 */
void getHeaderVals(HeaderVals &vals,
    CCfits::ExtHDU &hdu,
    const sixte::GTICollection &gtis) {

  // mandatory keywords: any missing one aborts the extraction
  try {
    hdu.readKey("TELESCOP", vals.telescop);
    hdu.readKey("INSTRUME", vals.instrume);
    hdu.readKey("FILTER", vals.filter);
    hdu.readKey("DATE-OBS", vals.dateobs);
    hdu.readKey("TIME-OBS", vals.timeobs);
    hdu.readKey("DATE-END", vals.dateend);
    hdu.readKey("TIME-END", vals.timeend);
    hdu.readKey("ANCRFILE", vals.ancrfile);
    hdu.readKey("RESPFILE", vals.respfile);
  } catch (const CCfits::HDU::NoSuchKeyword &e) {
    std::cout << e.message() << std::endl;
    throw sixte::SixteException("Missing keyword in EvtFile");
  }

  vals.exposure = gtis.totalExposure();

  // PHA2PI is optional; an empty string marks it as absent
  // (the caller only writes PHA2PI to the spectrum when non-empty)
  try {
    hdu.readKey("PHA2PI", vals.pha2pi);
  } catch (const CCfits::HDU::NoSuchKeyword &) {
    vals.pha2pi = "";
  }

  // to set later: detchans, gtifile, regfile
}

/**
 * Locate the event-file column that holds the measured event energy.
 *
 * If usepha is set, the PHA column is required. Otherwise the PI column
 * is preferred and the SIGNAL column is used as a fallback.
 *
 * @param evtfile  Open event file whose current extension is searched.
 * @param usepha   Force the use of the PHA column.
 * @return (column index as reported by CCfits::Column::index(), column kind).
 * @throws sixte::SixteException if the required/any usable column is missing.
 */
std::pair<int, SignalType> getSigCol(CCfits::FITS &evtfile, bool usepha) {
  CCfits::ExtHDU &events = evtfile.currentExtension();

  if (usepha) {
    // PHA was explicitly requested: its absence is a hard error
    try {
      const int colidx = events.column("PHA").index();
      healog(5) << "Reading energies from column PHA\n";
      return {colidx, PHA};
    } catch (const CCfits::Table::NoSuchColumn &) {
      throw sixte::SixteException(
          "No PHA column present in event file, even though usepha was set to true");
    }
  }

  // we must either have a PI column or a SIGNAL column. Try PI first
  try {
    const int colidx = events.column("PI").index();
    healog(5) << "Reading energies from column PI\n";
    return {colidx, PI};
  } catch (const CCfits::Table::NoSuchColumn &) {
    // no PI column: fall through to SIGNAL
  }

  try {
    const int colidx = events.column("SIGNAL").index();
    healog(5) << "Reading energies from column SIGNAL\n";
    return {colidx, SIGNAL};
  } catch (const CCfits::Table::NoSuchColumn &) {
    // neither PI nor SIGNAL: reported below
  }

  throw sixte::SixteException(
      "Unable to locate a column specifying event energy in EvtFile (must be PHA, SIGNAL, or PI)");
}

/**
 * Check whether two ARFs share a compatible energy grid.
 *
 * Identical file names are trivially compatible and are not loaded.
 * Otherwise both files are loaded from rsppath and compared on:
 * the number of energy bins, the lower edge of the first bin, and the
 * upper edge of the last bin.
 *
 * @param rsppath  Prefix (directory with trailing '/', or "") prepended to both names.
 * @param file_a   First ARF file name.
 * @param file_b   Second ARF file name.
 * @return true if compatible; false if the grids differ, either ARF has no
 *         energy bins, or either file could not be loaded.
 */
bool check_ARF_compatibility(std::string rsppath, std::string file_a, std::string file_b) {

  if (file_a == file_b) {
    return true;
  }

  const std::string path_a = rsppath + file_a;
  const std::string path_b = rsppath + file_b;

  int status = EXIT_SUCCESS;

  struct ARF* arf_a = loadARF(const_cast<char *> (path_a.c_str()), &status);
  CHECK_STATUS_RET(status, false);
  struct ARF* arf_b = loadARF(const_cast<char *> (path_b.c_str()), &status);
  if (status != EXIT_SUCCESS) {
    // do not leak the already-loaded first ARF
    freeARF(arf_a);
    return false;
  }

  bool is_valid = true;

  if (arf_a->NumberEnergyBins <= 0 || arf_b->NumberEnergyBins <= 0) {
    // an empty grid cannot be compared (indexing it would be out of bounds)
    is_valid = false;
    std::cout << "ARF without energy bins" << std::endl;
  } else {
    if (arf_a->NumberEnergyBins != arf_b->NumberEnergyBins) {
      is_valid = false;
      std::cout << "ARFS have differing number of energy bins" << std::endl;
    }
    if (arf_a->LowEnergy[0] != arf_b->LowEnergy[0]) {
      is_valid = false;
      std::cout << "ARFS have differing energy grid" << std::endl;
    }
    // compare the upper edge of the LAST bin of each grid
    const long last_a = arf_a->NumberEnergyBins - 1;
    const long last_b = arf_b->NumberEnergyBins - 1;
    if (arf_a->HighEnergy[last_a] != arf_b->HighEnergy[last_b]) {
      is_valid = false;
      std::cout << "ARFS have differing energy grid" << std::endl;
    }
  }

  freeARF(arf_a);
  freeARF(arf_b);

  if (!is_valid) {
    std::cout << "ARFs " << path_a << " and " << path_b << " are incompatible!" << std::endl;
  }

  return is_valid;
}

/**
 * Check whether two RMFs share a compatible channel (EBOUNDS) grid.
 *
 * Identical file names are trivially compatible and are not loaded.
 * Otherwise both files are loaded from rsppath and compared on:
 * the number of channels, the lower energy of the first channel, the
 * upper energy of the last channel, and FirstChannel.
 *
 * @param rsppath  Prefix (directory with trailing '/', or "") prepended to both names.
 * @param file_a   First RMF file name.
 * @param file_b   Second RMF file name.
 * @return true if compatible; false if the grids differ, either RMF has no
 *         channels, or either file could not be loaded.
 */
bool check_RMF_compatibility(std::string rsppath, std::string file_a, std::string file_b) {
  if (file_a == file_b) {
    return true;
  }

  const std::string path_a = rsppath + file_a;
  const std::string path_b = rsppath + file_b;

  int status = EXIT_SUCCESS;

  struct RMF* rmf_a = loadRMF(const_cast<char *> (path_a.c_str()), &status);
  CHECK_STATUS_RET(status, false);
  struct RMF* rmf_b = loadRMF(const_cast<char *> (path_b.c_str()), &status);
  if (status != EXIT_SUCCESS) {
    // do not leak the already-loaded first RMF
    freeRMF(rmf_a);
    return false;
  }

  bool is_valid = true;

  if (rmf_a->NumberChannels <= 0 || rmf_b->NumberChannels <= 0) {
    // an empty channel grid cannot be compared (indexing it would be out of bounds)
    is_valid = false;
    std::cout << "RMF without channels" << std::endl;
  } else {
    if (rmf_a->NumberChannels != rmf_b->NumberChannels) {
      is_valid = false;
      std::cout << "RMFs have differing number of channels" << std::endl;
    }
    if (rmf_a->ChannelLowEnergy[0] != rmf_b->ChannelLowEnergy[0]) {
      is_valid = false;
      std::cout << "RMFs have differing energy grid" << std::endl;
    }
    // compare the upper energy of the LAST channel of each grid
    const long last_a = rmf_a->NumberChannels - 1;
    const long last_b = rmf_b->NumberChannels - 1;
    if (rmf_a->ChannelHighEnergy[last_a] != rmf_b->ChannelHighEnergy[last_b]) {
      is_valid = false;
      std::cout << "RMFs have differing energy grid" << std::endl;
    }
  }

  // channel numbering offset must agree, otherwise PHA values map to different bins
  if (rmf_a->FirstChannel != rmf_b->FirstChannel) {
    is_valid = false;
    std::cout << "RMFs have a differing FirstChannel" << std::endl;
  }

  freeRMF(rmf_a);
  freeRMF(rmf_b);

  if (!is_valid) {
    std::cout << "RMFs " << path_a << " and " << path_b << " are incompatible!" << std::endl;
  }

  return is_valid;
}

/**
 * Test whether a time stamp lies inside any of the given intervals.
 *
 * @param gtibins  List of (start, stop) intervals; both bounds are inclusive.
 * @param time     Time stamp to test.
 * @return true if start <= time <= stop holds for at least one interval.
 */
bool filterTime(const std::vector<std::pair<double,double>>& gtibins, double time) {
  return std::any_of(gtibins.cbegin(), gtibins.cend(),
      [time](const std::pair<double,double>& interval) {
        return interval.first <= time && time <= interval.second;
      });
}

int makespec_main() {
  // Register HEAdas task
  set_toolname("makespec");
  set_toolversion("1.0");

  // declare some NULL pointers to free later
  RMF* rmf = NULL;
  SAORegion* region = NULL;
  int status = EXIT_SUCCESS;

  try {
    ClinePars pars = read_par();

    // apply event filter if requested
    std::string evtfilename;
    if (!sixte::isNone(pars.EventFilter)) {
      std::stringstream ss;
      ss << pars.EvtFile
        << "[EVENTS]["
        << pars.EventFilter
        << "]";
      evtfilename = ss.str();
    } else {
      evtfilename = pars.EvtFile;
    }

    CCfits::FITS evtfile(evtfilename, CCfits::Read);
    evtfile.extension("EVENTS");

    // get the GTIs, either from the input event file or user specified
    sixte::GTICollection gtis(pars.EvtFile + "[STDGTI]");
    bool filter_gti = !sixte::isNone(pars.GTIfile);
    if (filter_gti) {
      sixte::GTICollection usrgti(pars.GTIfile);
      gtis.filterWith(usrgti);
    }

    // prepare region filtering
    if (!sixte::isNone(pars.regfile)) {
      region = sixte::getRegion(pars.regfile, evtfile.currentExtension());
    }

    // read values from the event file header
    HeaderVals headvals;
    getHeaderVals(headvals, evtfile.currentExtension(), gtis);

    // determine the column containing the measured energies
    // this is either PI, PHA or SIGNAL
    auto [sigcolidx, sigcoltype] = getSigCol(evtfile, pars.usepha);

    // get the RMF and ARF for spectral extraction
    std::string rsppath = (sixte::isNone(pars.RSPPath) || pars.RSPPath.empty()) ? "" : (pars.RSPPath+"/");

    // decide which RMF is used to build the spectrum
    std::string rmffile;
    if (sixte::isNone(pars.RMFfile)) {
      // use something from the header
      if (sigcoltype == PI) {
        try {
          evtfile.currentExtension().readKey("PIRMF", rmffile);
        } catch (CCfits::HDU::NoSuchKeyword) {
          throw sixte::SixteException("Could not find 'PIRMF' keyword in EvtFile, even though the PI column was requested.\nPlease either give a 'RESPfile' on the command line or set 'usepha' accordingly.");
        }
      } else {
        rmffile = headvals.respfile;
      }
    } else {
      rmffile = pars.RMFfile;
    }

    // then need to determine which ARF to use
    std::string arffile;
    if (sixte::isNone(pars.ARFfile)) {
      arffile = headvals.ancrfile;
    } else {
      arffile = pars.ARFfile;
    }

    // Check whether the simulated and makespec RMF/ARF have compatible binning
    if (!check_ARF_compatibility(rsppath, arffile, headvals.ancrfile)) {
      throw sixte::SixteException("Simulated ARF and given ARF are not compatible");
    }

    if (!check_RMF_compatibility(rsppath, rmffile, headvals.respfile)) {
      throw sixte::SixteException("Simulated RMF and given RMF are not compatible");
    }


    // then, the actual binning
    rmf = loadRMF(const_cast<char *> ((rsppath+rmffile).c_str()), &status);
    sixte::checkStatusThrow(status, "Failed opening RMF file");

    sixte::initSixteRng(pars.Seed);

    healog(3) << "calculate spectrum ..." << std::endl;

    std::vector<long> spec(rmf->NumberChannels, 0);

    std::vector<long> x;
    std::vector<long> y;
    std::vector<double> times;
    auto gtibins = gtis.getAllGtiBins();

    auto nrows = evtfile.currentExtension().rows();

    if (region != NULL) {
      evtfile.currentExtension().column("X").read(x, 1, nrows);
      evtfile.currentExtension().column("Y").read(y, 1, nrows);
    }

    if (filter_gti) {
      evtfile.currentExtension().column("TIME").read(times, 1, nrows);
    }

    std::vector<long> phas;

    if (sigcoltype == SIGNAL) {
      // need to convert signal in eV to PHA value
      std::vector<double> signals;
      evtfile.currentExtension().column(sigcolidx).read(signals, 1, nrows);
      std::transform(signals.begin(), signals.end(), std::back_inserter(phas),
          [rmf](double s){return getEBOUNDSChannel(s,rmf);});
    } else {
      // just read directly
      evtfile.currentExtension().column(sigcolidx).read(phas, 1, nrows);
    }

    for (size_t ii=0; ii<(size_t) nrows; ii++) {
      bool time_valid = filter_gti ? filterTime(gtibins, times[ii]) : true;
      bool region_valid = (region!=NULL) ? sixte::filterReg(region, x[ii],y[ii]) : true;

      if (region_valid && time_valid) {
        long idx=phas[ii]-rmf->FirstChannel;
        if (idx >= 0 && idx < rmf->NumberChannels) {
          spec[idx]++;
        }
      }
    }

    healog(3) << "store spectrum ..." << std::endl;

    std::string filename = pars.Spectrum;
    if (pars.clobber) filename = "!"+filename;
    CCfits::FITS outfile(filename, CCfits::Write);

    std::vector<std::string> ttype{"CHANNEL", "COUNTS"};
    std::vector<std::string> tform{"J", "J"};
    std::vector<std::string> tunit{"ADU", "counts"};

    auto table = outfile.addTable("SPECTRUM", spec.size(),ttype,tform,tunit);

    // write data
    std::vector<long> chan(spec.size(), 0);
    std::iota(chan.begin(), chan.end(), rmf->FirstChannel);
    table->column("CHANNEL").write(chan, 1);
    table->column("COUNTS").write(spec, 1);

    // write header
    table->addKey("ORIGIN", "ECAP", "Origin of FITS File");
    table->addKey("CREATOR", "makespec", "Program that created this FITS file");
    table->addKey("BACKFILE", "", "background file");
    table->addKey("CORRFILE", "", "correlation file");
    table->addKey("CORRSCAL", 0, "");
    table->addKey("SYS_ERR", 0, "");
    table->addKey("QUALITY", 0, "");
    table->addKey("GROUPING", 0, "");
    table->addKey("BACKSCAL", 1., "");
    table->addKey("AREASCAL", 1., "");
    table->addKey("HDUCLASS", "OGIP", "");
    table->addKey("HDUCLAS1", "SPECTRUM", "");
    table->addKey("HDUVERS", "1.2.1", "");
    table->addKey("HDUVERS1", "1.1.0", "");
    table->addKey("HDUCLAS2", "TOTAL", "");
    table->addKey("HDUCLAS3", "COUNT", "");
    table->addKey("CHANTYPE", "PI", "");
    table->addKey("DETCHANS", spec.size(), "");
    table->addKey("TELESCOP", headvals.telescop, "");
    table->addKey("INSTRUME", headvals.instrume, "");
    table->addKey("FILTER", headvals.filter, "");
    table->addKey("DATE-OBS", headvals.dateobs, "");
    table->addKey("TIME-OBS", headvals.timeobs, "");
    table->addKey("DATE-END", headvals.dateend, "");
    table->addKey("TIME-END", headvals.timeend, "");
    table->addKey("Exposure", headvals.exposure, "exposure time");
    table->addKey("POISSERR", true, "poissonian error");

    // now, some long strings
    // apparently, there are bugs in CCFITS - use our cfitsio
    fitsfile* fptr = table->fitsPointer();

    sixte::cfitsio::fitsUpdateKey(fptr, "LONGSTRN", "OGIP 1.0",
        "The OGIP long string convention may be used");

    // table->addKey("ANCRFILE", rsppath+arffile, "ancillary response file", true);
    sixte::cfitsio::fitsUpdateKeyLongstr(fptr, "ANCRFILE", rsppath+arffile/*, "ancillary response file"*/);
    // table->addKey("RESPFILE", rsppath+rmffile, "response file", true);
    sixte::cfitsio::fitsUpdateKeyLongstr(fptr, "RESPFILE", rsppath+rmffile/*, "response file"*/);
    if (!headvals.pha2pi.empty()) {
      //table->addKey("PHA2PI", headvals.pha2pi, "PHA2PI correction file", true);
      sixte::cfitsio::fitsUpdateKeyLongstr(fptr, "PHA2PI", headvals.pha2pi/*, "PHA2PI correction file"*/);
    }
    if (!sixte::isNone(pars.EventFilter)) {
      //table->addKey("FilterExpr", pars.EventFilter, "used filter Expression", true);
      sixte::cfitsio::fitsUpdateKeyLongstr(fptr, "FilterExpr", pars.EventFilter/*,"used filter Expression"*/);
    }

    if (region != NULL) {
      //table->addKey("REGFILE", pars.regfile, "Region filter file", true);
      sixte::cfitsio::fitsUpdateKeyLongstr(fptr, "REGFILE", pars.regfile/*, "Region filter file"*/);
    }



    // add GTI extension
    gtis.saveGtiExtension(outfile.fitsPointer(), headvals.telescop, headvals.instrume);

    // save input parameters
    HDpar_stamp(outfile.fitsPointer(), 1, &status);
    sixte::checkStatusThrow(status, "Failed doing par_stamp");

    // TODO copy fits region extension
  } catch (sixte::SixteException &e) {
    std::cout << e.what() << std::endl;
    status=EXIT_FAILURE;
  } catch (CCfits::FitsException &e) {
    std::cout << e.message() << std::endl;
    status=EXIT_FAILURE;
  }

  // cleanup
  if (rmf != NULL) {
    freeRMF(rmf);
  }
  if (region!=NULL) {
    fits_free_region(region);
  }
  sixt_destroy_rng();
  return status;
}
