/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2022 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "NewPSF.h"

#include "SixteCCFits.h"
#include "SixteVector.h"
#include "sixte_random.h"

#include "healog.h"

namespace sixte {

// Load a PSF library from a FITS file.
//
// Every HDU that carries an image is visited twice: the first pass collects
// the distinct ENERGY/THETA/PHI header values that span the interpolation
// grid, the second pass stores each image at its grid position.
//
//   filename      PSF FITS file; must not be empty
//   focal_length  forwarded to addPSFTo3DArray
//   vig_filename  handed to the constructor of the vignetting_ member
//
// Throws SixteException if no PSF file name is given.
NewPSF::NewPSF(const std::string &filename, double focal_length, const std::string& vig_filename)
               : vignetting_(vig_filename) {

  if (filename.empty()) {
    throw SixteException("PSF file not specified!");
  }

  std::unique_ptr<CCfits::FITS> inFile = sixteOpenFITSFileRead(filename, "PSF file");

  const size_t extension_count = inFile->extension().size();

  // Calls `visit` on the primary HDU (skipped if it is empty) and then on the
  // extensions 1..extension_count, in that order.
  auto for_each_psf_hdu = [&inFile, extension_count](auto&& visit) {
    if (inFile->pHDU().axes() != 0) {
      visit(inFile->pHDU());
    }
    for (size_t ext = 1; ext <= extension_count; ext++) {
      visit(inFile->extension(static_cast<int>(ext)));
    }
  };

  // Pass 1: gather the grid values.
  for_each_psf_hdu([this](CCfits::HDU& hdu) {
    getValuesForEnergyThetaPhi(hdu);
  });

  std::sort(energies_.begin(), energies_.end());
  std::sort(thetas_.begin(), thetas_.end());
  std::sort(phis_.begin(), phis_.end());

  printAvailablePSFEntries();

  // The interpolator takes ownership of its own copy of each grid axis.
  interpolator_.reallocate({energies_.size(), thetas_.size(), phis_.size()});
  auto energy_grid = energies_;
  interpolator_.set_grid(0, std::move(energy_grid));
  auto theta_grid = thetas_;
  interpolator_.set_grid(1, std::move(theta_grid));
  auto phi_grid = phis_;
  interpolator_.set_grid(2, std::move(phi_grid));

  // Pass 2: fill the grid with the PSF images.
  for_each_psf_hdu([this, &inFile, focal_length](CCfits::HDU& hdu) {
    addPSFTo3DArray(*inFile, hdu, focal_length);
  });
}

std::optional<Point_2> NewPSF::get_NewPSF_pos(const SixtePhoton& photon,
                                              const Telescope_attitude& telescope,
                                              double focal_length) {
  auto photon_direction=sixte::unitVector(*photon.ra(), *photon.dec());

  double theta = calculateOffAxisAngleTheta(telescope.nz, photon_direction);
  double phi = calculateAzimuthalAnglePhi(telescope.nx, telescope.ny, photon_direction);

  double random_number = getUniformRandomNumber();

  if (random_number > vignetting_.getVignettingFactor(photon.energy(), theta, phi)) {
    // The photon does not hit the detector at all (e.g. it is absorbed).
    return std::nullopt;
  }

  size_t energy_ind = interpolator_.interpolate_1D(0, photon.energy());
  size_t theta_ind = interpolator_.interpolate_1D(1, theta);
  size_t phi_ind = interpolator_.interpolate_1D(2, phi);

  const auto& psf_item = interpolator_.retrieve({energy_ind,theta_ind,phi_ind});

  std::optional<Point_2> position = samplePosFromPSFImage(psf_item, focal_length, theta, phi, phis_[phi_ind]);

  return position;
}

void NewPSF::saveNewPSFImage(const std::string& filename) {
  int nhdus=0;
  
  const string &inFileName(filename);
  CCfits::FITS inFile(inFileName, CCfits::Write);
  
  // Loop over the different PSFs in the storage.
  size_t energy_ind, theta_ind, phi_ind;
  for (energy_ind=0; energy_ind<energies_.size(); energy_ind++) {
    for (theta_ind=0; theta_ind<thetas_.size(); theta_ind++) {
      for (phi_ind=0; phi_ind<phis_.size(); phi_ind++) {
        
        const auto& psf_item = interpolator_.retrieve({energy_ind,theta_ind,phi_ind});
        // Determine size of PSF sub-rectangles (don't save entire PSF but only
        // the relevant region around the central peak, which has a probability
        // greater than 0).
        long width = psf_item.naxis1_;
        long height = psf_item.naxis2_;
        
        std::valarray<double> sub_psf;
        
        // Store the PSF in the 1D array to handle it to the FITS routine.
        int x0 = 0;
        int y0 = 0; // coordinates of lower left corner of sub-rectangle
        
        for (int x=x0; x<(x0+width); x++) {
          for (int y=y0; y<(y0+height); y++) {
            sub_psf[((x-x0)*width+y-y0)] = psf_item.data_array_[x][y];
          }
        }
        
        // Create an image in the FITS-file (primary HDU):
        std::vector<long> naxes={(psf_item.naxis1_), (psf_item.naxis2_)};
        inFile.addImage(filename, DOUBLE_IMG, naxes);
        
        nhdus++;
        auto &image_extension = inFile.extension(nhdus);
        
        writeHeaderKeywords(image_extension, psf_item, energies_[energy_ind], thetas_[theta_ind], phis_[phi_ind], nhdus);
        
        // Lower left corner -> FITS coordinates start at (1,1)
        std::vector<long> fpixel={x0+1, y0+1};
        
        // Upper right corner.
        std::vector<long> lpixel={psf_item.naxis1_, psf_item.naxis2_};

        image_extension.write(fpixel, lpixel, sub_psf);
          
        image_extension.writeChecksum();
      }
    }
  }
}

// Read the ENERGY, THETA and PHI header keywords of one HDU and register
// each value in the corresponding grid list (energies_, thetas_, phis_)
// unless it is already present. The values pass through convertUnits
// before they are stored.
void NewPSF::getValuesForEnergyThetaPhi (CCfits::HDU& table) {
  // The header values are parked in the scratch members of the object.
  table.readKey("ENERGY", this->energy_);
  table.readKey("THETA", this->theta_);
  table.readKey("PHI ", this->phi_);

  convertUnits(this->energy_, this->theta_, this->phi_);

  // Duplicates are filtered out by addDValue2Vector.
  addDValue2Vector(this->energy_, this->energies_);
  addDValue2Vector(this->theta_, this->thetas_);
  addDValue2Vector(this->phi_, this->phis_);
}


void NewPSF::printAvailablePSFEntries () {
  healog(5) << "PSF - available energies:" << std::endl;
  for (double energ : energies_) {
    healog(5) <<  energ << " keV" << std::endl;
  }
  healog(5) << "PSF - available off-axis angles:" << std::endl;
  for (double thet : thetas_) {
    healog(5) << thet/M_PI*180.*60. << " arc min" << std::endl;
  }
  healog(5) << "PSF - available azimuthal angles:" << std::endl;
  for (double ph : phis_) {
    healog(5) << ph/M_PI*180. << " deg" << std::endl;
  }
}

// Log (at chatter level 5) one line for the PSF image stored at the given
// grid indices.
//
//   sum        total of the image as returned by createCumulativeDistribution
//   *_ind      indices into energies_, thetas_ and phis_
//
// NOTE(review): the reported percentage is sum/sum*100, i.e. 100 for any
// non-zero finite sum and NaN for sum == 0. Presumably sum*100 was meant at
// some point - confirm the intended value before changing it.
void NewPSF::printAvailablePSFImages (double sum, size_t energy_ind, size_t theta_ind, size_t phi_ind) {
  const double percentage = sum/sum * 100.;
  const double energy = energies_[energy_ind];
  const double theta_arcmin = thetas_[theta_ind]/M_PI*180.*60.;
  const double phi_deg = phis_[phi_ind]/M_PI*180.;

  healog(5) << "PSF: images "
            << percentage << "% of incident photons for "
            << energy << "keV, "
            << theta_arcmin << " arc min, "
            << phi_deg << " deg" << std::endl;
}

// Read the PSF image of one HDU and store it in the interpolator at the grid
// position given by its ENERGY/THETA/PHI header keywords.
//
//   file          FITS file the HDU belongs to (the pixel data is read from it)
//   table         HDU whose header describes the image
//   focal_length  forwarded to getWCSKeywords
//
// find_index throws SixteException if a header value is not part of the grid.
void NewPSF::addPSFTo3DArray (CCfits::FITS& file, CCfits::HDU& table, double focal_length) {
  // Header values as stored in the file ...
  table.readKey("ENERGY", energy_); // [eV]
  table.readKey("THETA", theta_); // [arcmin]
  table.readKey("PHI ", phi_); // [deg]

  // ... converted to the units of the grid lists.
  convertUnits(energy_, theta_, phi_);

  // Grid position of this image.
  const size_t i_energy = find_index(energy_, energies_);
  const size_t i_theta = find_index(theta_, thetas_);
  const size_t i_phi = find_index(phi_, phis_);

  // Image dimensions.
  int axis1_len, axis2_len;
  table.readKey("NAXIS1", axis1_len);
  table.readKey("NAXIS2", axis2_len);

  PSF_Item_ item;
  item.naxis1_ = static_cast<long>(axis1_len);
  item.naxis2_ = static_cast<long>(axis2_len);

  getWCSKeywords (table, focal_length, item);

  // Pixel data -> normalized cumulative distribution inside the item.
  std::valarray<double> pixels = readPSFImage(file, table.index());
  const double total = createCumulativeDistribution(item, pixels);

  interpolator_.set_interpolation_point({i_energy, i_theta, i_phi}, std::move(item));

  printAvailablePSFImages (total, i_energy, i_theta, i_phi);
}


// Append `value` to `vec` unless an element comparing exactly equal to it is
// already stored. The order of the existing elements is left untouched.
void addDValue2Vector(double value, std::vector<double>& vec) {
  const auto match = std::find(vec.cbegin(), vec.cend(), value);
  if (match == vec.cend()) {
    vec.push_back(value);
  }
}

// Return the index of the first element of `val_vec` that matches `val`
// within a relative tolerance of 1e-6 (scaled by |val|; for val == 0 only an
// exact match is accepted).
//
// Throws SixteException if no element matches.
size_t find_index(double val, std::vector<double> val_vec) {
  const double tolerance = fabs(val * 1.e-6);

  // Strictly below size(): the previous '<=' bound read one element past the
  // end of the vector whenever no match was found.
  for (size_t index = 0; index < val_vec.size(); index++) {
    if (fabs(val - val_vec[index]) <= tolerance) return index;
  }
  throw SixteException("could not find appropriate PSF entry");
}

size_t which_psf(double val, std::vector<double> val_vec) {
  size_t index;
  for (index=0; index<val_vec.size()-1; index++) {
    if (val_vec[index+1] > val) break;
  }
  if (index < val_vec.size()-1) {
    double rnd = getUniformRandomNumber();
    if (rnd < ((val-val_vec[index]) / (val_vec[index+1]-val_vec[index]))) {
      index++;
    }
  }
  return index;
}

// Convert the three PSF header quantities in place:
//   energy  [eV]      -> [keV]
//   theta   [arc min] -> [rad]
//   phi     [deg]     -> [rad]
void convertUnits (double& energy_eV, double& theta_arcmin, double& phi_deg) {
  const double kev_per_ev = 1.e-3;
  const double rad_per_deg = M_PI / 180.;
  const double rad_per_arcmin = rad_per_deg / 60.;

  energy_eV = energy_eV * kev_per_ev;
  theta_arcmin = theta_arcmin * rad_per_arcmin;
  phi_deg = phi_deg * rad_per_deg;
}

// Read the complete image of one HDU into a flat array.
//
//   file           opened FITS file
//   extension_num  0 selects the primary HDU, any other value the extension
//                  with that index
std::valarray<double> readPSFImage (CCfits::FITS& file, int extension_num) {
  std::valarray<double> pixels;

  const bool is_primary = (extension_num == 0);
  if (is_primary) {
    file.pHDU().read(pixels);
  } else {
    file.extension(extension_num).read(pixels);
  }

  return pixels;
}

// Turn a flat PSF image into a normalized cumulative distribution.
//
// data_array_ of `current_psf_item` is (re)allocated as naxis1_ x naxis2_ and
// filled with the running sum of the image, accumulated with the second
// index varying fastest; the result is then divided by the total so that the
// last element is 1.
//
//   current_psf_item  item whose naxis1_/naxis2_ are already set
//   dat_arr           flat image, pixel (ii,jj) stored at jj*naxis1_+ii
//
// Returns the total of the image before normalization.
// Throws SixteException if dat_arr holds fewer than naxis1_*naxis2_ values
// or if the total is not positive (normalization would produce NaN/inf).
double createCumulativeDistribution (PSF_Item_& current_psf_item, const std::valarray<double>& dat_arr) {
  const long nx = current_psf_item.naxis1_;
  const long ny = current_psf_item.naxis2_;

  // Guard against reading past the end of the pixel buffer.
  if (nx < 0 || ny < 0 ||
      dat_arr.size() < static_cast<size_t>(nx) * static_cast<size_t>(ny)) {
    throw SixteException("PSF image holds fewer pixels than NAXIS1*NAXIS2");
  }

  using Column = std::vector<double>;
  using Image = std::vector<Column>;
  current_psf_item.data_array_ = Image(nx, Column(ny));

  // Running sum over the image.
  double sum = 0.;
  for (long ii = 0; ii < nx; ii++) {
    for (long jj = 0; jj < ny; jj++) {
      sum += dat_arr[jj*nx + ii];
      current_psf_item.data_array_[ii][jj] = sum;
    }
  }

  // A PSF without any (or with a non-positive/NaN) total cannot be normalized.
  if (!(sum > 0.)) {
    throw SixteException("PSF image cannot be normalized (sum of pixel values is not positive)");
  }

  // Explicitly normalize the PSF.
  const double norm = 1./sum;
  for (long ii = 0; ii < nx; ii++) {
    for (long jj = 0; jj < ny; jj++) {
      current_psf_item.data_array_[ii][jj] *= norm;
    }
  }

  return sum;
}

// Angle between the telescope axis `telescope_nz` and `photon_direction`,
// obtained as the arc cosine of their scalar product.
//
// A scalar product exceeding 1 by less than 1e-10 is treated as exactly 1;
// a larger excess raises a SixteException.
double calculateOffAxisAngleTheta(const SixteVector& telescope_nz, const SixteVector& photon_direction) {
  double cos_theta = scalarProduct(telescope_nz, photon_direction);

  if (cos_theta > 1.0) {
    const bool rounding_noise = (cos_theta - 1.0 < 1.e-10);
    if (!rounding_noise) {
      throw SixteException("Off-axis angle theta is larger than 1!");
    }
    // Avoid numerical problems with numbers slightly larger than 1.
    cos_theta = 1.0;
  }

  return acos(cos_theta);
}

// Azimuthal angle of `photon_direction` in the plane spanned by the
// telescope axes `telescope_nx` and `telescope_ny`, measured from nx
// towards ny. Negative atan2 results are shifted up by 2*PI.
double calculateAzimuthalAnglePhi(const SixteVector& telescope_nx, const SixteVector& telescope_ny, const SixteVector& photon_direction){
  const double angle = atan2(scalarProduct(telescope_ny, photon_direction),
                             scalarProduct(telescope_nx, photon_direction));

  // atan2 delivers [-PI,PI]; the PSF grid uses [0,2*PI].
  return (angle < 0.0) ? angle + 2.0*M_PI : angle;
}

std::optional<Point_2> samplePosFromPSFImage (const PSF_Item_& psf_item, double focal_length, double theta, double phi_photon, double phi_psf) {
  // Get a random position from the best fitting PSF image.
  
  // Perform a binary search to determine a random position:
  // -> one binary search for each of the 2 coordinates x and y
  auto random_number = getUniformRandomNumber();
  
  if (psf_item.naxis1_ == 0 || psf_item.naxis2_ == 0) throw SixteException("psf_item->naxis is 0!");
  
  if (random_number > psf_item.data_array_[psf_item.naxis1_-1][psf_item.naxis2_-1]) {
    // The photon does not hit the detector at all (e.g. it is absorbed).
    sixt_warning("PSF contains vignetting contributions");
    return std::nullopt;
  }
  
  long x1, y1;
  // Perform a binary search to obtain the x-coordinate.
  long high = psf_item.naxis1_-1;
  long low = 0;
  long mid;
  long ymax = psf_item.naxis2_-1;
  while (high > low) {
    mid = (low+high)/2;
    if (psf_item.data_array_[mid][ymax] < random_number) {
      low = mid+1;
    } else {
      high = mid;
    }
  }
  x1 = low;
  
  // Search for the y coordinate:
  high=psf_item.naxis2_-1;
  low=0;
  while (high > low) {
    mid=(low+high)/2;
    if (psf_item.data_array_[x1][mid] < random_number) {
      low=mid+1;
    } else {
      high=mid;
    }
  }
  y1=low;
  // Now x1 and y1 have pixel positions [integer pixel].
  
  // Determine the distance ([m]) of the central reference position
  // from the optical axis according to the off-axis angle theta.
  double distance=focal_length * tan(theta);
  
  // rotate to the phi used for evaluating the psf
  double sinp, cosp;
  
  sincos(phi_psf, &sinp, &cosp);
  
  // TODO why is this saved in a Point_2 when it is extracted immediately afterwards?
  Point_2 position(cosp*distance, sinp*distance);
  
  // Add the relative position obtained from the PSF image (randomized pixel indices x1 and y1).
  double x2 = position.x()
              + (((double)x1 - psf_item.crpix1_ + 0.5 + getUniformRandomNumber()) * psf_item.cdelt1_)
              + psf_item.crval1_; // [m]
  
  double y2 = position.y()
              + (((double)y1 - psf_item.crpix2_ + 0.5 + getUniformRandomNumber()) * psf_item.cdelt2_)
              + psf_item.crval2_; // [m]
  
  // Rotate the position [m] according to the final azimuthal angle.
  sincos(phi_photon-phi_psf, &sinp, &cosp);
  
  Point_2 return_position((cosp*x2 - sinp*y2), (sinp*x2 + cosp*y2));
  
  return return_position;
}

// Read the WCS keywords (CDELTn, CRPIXn, CRVALn) of a PSF image into
// `current_psf_item` and convert CDELTn/CRVALn to meters.
//
// CUNIT1/CUNIT2 select the conversion and must agree with each other:
//   "m"       values are used as they are
//   "arcsec"  values are multiplied by tan(1 arcsec) * focal_length
//   "deg"     values are multiplied by tan(1 deg) * focal_length
//
// Throws SixteException for any other (or a mixed) pair of units.
void getWCSKeywords (CCfits::HDU& table, double focal_length, PSF_Item_& current_psf_item) {
  table.readKey("CDELT1", current_psf_item.cdelt1_);
  table.readKey("CDELT2", current_psf_item.cdelt2_);
  table.readKey("CRPIX1", current_psf_item.crpix1_);
  table.readKey("CRPIX2", current_psf_item.crpix2_);
  table.readKey("CRVAL1", current_psf_item.crval1_);
  table.readKey("CRVAL2", current_psf_item.crval2_);

  // Determine the units of the PSF image.
  std::string cunit1, cunit2;
  table.readKey("CUNIT1", cunit1);
  table.readKey("CUNIT2", cunit2);

  const bool both_arcsec = (cunit1=="arcsec") && (cunit2=="arcsec");
  const bool both_deg = (cunit1=="deg") && (cunit2=="deg");
  const bool both_meter = (cunit1=="m") && (cunit2=="m");

  double scaling;
  if (both_arcsec) {
    scaling = tan(1./3600.*M_PI/180.)*focal_length; // [m/arcsec]
  } else if (both_deg) {
    scaling = tan(M_PI/180.)*focal_length; // [m/deg]
  } else if (both_meter) {
    // Already in [m], nothing to convert.
    return;
  } else {
    // Unknown unit, or different units on the two axes. The previous check
    // only fired when neither axis was "m", so a mixed pair such as
    // "m"/"arcsec" slipped through without any conversion.
    throw SixteException("PSF pixel width must be given in [m], [arcsec] or [deg] (same unit on both axes)");
  }

  current_psf_item.cdelt1_ *= scaling;
  current_psf_item.cdelt2_ *= scaling;
  current_psf_item.crval1_ *= scaling;
  current_psf_item.crval2_ *= scaling;
}

// Write the header keywords of one PSF image extension.
//
//   image_extension  extension receiving the keywords
//   psf_item         PSF image (dimensions and pixel increments are used)
//   energy           energy of the PSF [keV] (written as ENERGY in [eV])
//   theta            off-axis angle [rad] (written as THETA in [arc min])
//   phi              azimuthal angle [rad] (written as PHI in [degree])
//   nhdus            HDU number handed to HDpar_stamp
//
// Throws SixteException if HDpar_stamp reports an error.
//
// Numeric keyword values are handed to addKey by value. Passing the address
// of the value (the cfitsio C convention) would make CCfits treat the
// pointer itself as the keyword value.
void writeHeaderKeywords (CCfits::ExtHDU& image_extension, const PSF_Item_& psf_item, double energy, double theta, double phi, long nhdus) {
  // Write the header keywords for PSF FITS-files (CAL/GEN/92-027):
  image_extension.addKey("CTYPE1", "DETX", "detector coordinate system", false);
  image_extension.addKey("CTYPE2", "DETY", "detector coordinate system", false);
  image_extension.addKey("HDUCLASS", "OGIP", "Extension is OGIP defined", false);
  image_extension.addKey("HDUDOC", "CAL/GEN/92-020", "Document containing extension definition", false);
  image_extension.addKey("HDUVERS", "1.0.0", "giving the version of the format", false);
  image_extension.addKey("HDUCLAS1", "Image", "", false);
  image_extension.addKey("HDUCLAS2", "PSF", "", false);
  image_extension.addKey("HDUCLAS3", "PREDICTED", "", false);
  image_extension.addKey("HDUCLAS4", "NET", "", false);

  image_extension.addKey("CUNIT1", "m", "", false);
  image_extension.addKey("CUNIT2", "m", "", false);

  // The reference pixel is placed at the center of the image.
  const double crpix1 = (double)psf_item.naxis1_*0.5+0.5;
  image_extension.addKey("CRPIX1", crpix1, "X axis reference pixel", false);

  const double crpix2 = (double)psf_item.naxis2_*0.5+0.5;
  image_extension.addKey("CRPIX2", crpix2, "Y axis reference pixel", false);

  const double zero = 0.;
  image_extension.addKey("CRVAL1", zero, "coord of X ref pixel", false);
  image_extension.addKey("CRVAL2", zero, "coord of Y ref pixel", false);

  const double cdelt1 = psf_item.cdelt1_; // [m]
  image_extension.addKey("CDELT1", cdelt1, "X axis increment", false);
  const double cdelt2 = psf_item.cdelt2_; // [m]
  image_extension.addKey("CDELT2", cdelt2, "Y axis increment", false);

  image_extension.addKey("BACKGRND", zero, "background count rate per pixel", false);

  // Mission
  image_extension.addKey("TELESCOP", "", "Mission name", false);
  image_extension.addKey("INSTRUME", "", "Instrument", false);
  image_extension.addKey("FILTER", "NONE", "Filter", false);

  // Creator.
  image_extension.addKey("ORIGIN", "ECAP", "", false);

  // Write the ENERGY, THETA, and PHI for this particular PSF set.
  // This information is used to find the appropriate PSF for
  // an incident photon with particular energy and off-axis angle.
  const double energy_eV = energy*1000.;
  image_extension.addKey( "ENERGY", energy_eV,
                          "photon energy for the PSF generation in [eV]", false);

  const double theta_arcmin = theta*180.*60./M_PI;
  image_extension.addKey("THETA", theta_arcmin,
                         "off-axis angle in [arc min]", false);

  const double phi_deg = phi*180./M_PI;
  image_extension.addKey( "PHI", phi_deg,
                          "azimuthal angle in [degree]", false);

  image_extension.addKey( "ENERG_LO", energy, "[keV]", false);
  image_extension.addKey("ENERG_HI", energy, "[keV]", false);

  const double no_channel = -99.0;
  image_extension.addKey( "CHANMIN", no_channel, "", false);
  image_extension.addKey( "CHANMAX", no_channel, "", false);
  image_extension.addKey( "CHANTYPE", "PI", "", false);

  image_extension.addKey( "CCLS0001", "BCF", "", false);
  image_extension.addKey( "CDTP0001", "TASK", "", false);
  image_extension.addKey( "CCNM0001", "2D_PSF", "", false);

  std::string sbuffer = "ENERGY( " + std::to_string(energy) + ")keV";
  image_extension.addKey("CBD10001", sbuffer, "", false);

  sbuffer = "THETA( " + std::to_string(theta*180.*60./M_PI) + ")arcmin";
  image_extension.addKey("CBD20001", sbuffer, "", false);

  sbuffer = "PHI( " + std::to_string(phi*180./M_PI) + ")deg";
  image_extension.addKey("CBD30001", sbuffer, "", false);

  image_extension.addKey("CVSD0001", "2000-01-01", "", false);
  image_extension.addKey("CVST0001", "00:00:00", "", false);
  image_extension.addKey("CDES0001", "Theoretical images",
                         "", false);

  // Add the runtime (parameter) information to the HDU.
  int status = 0;
  HDpar_stamp(image_extension.fitsPointer(), (int)nhdus, &status);
  if (status) throw SixteException("Adding runtime information to psf image failed!");
}

}
