/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2022 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "ArrayGeometry.h"
#include "Polygon.h"
#include "SixteException.h"
#include "sixte_random.h"
#include "Telescope_attitude.h"
#include "rndgen.h"
#include <algorithm>
#include <cmath>
#include <math.h>
#include <string>

namespace sixte {

std::pair<double,double> detToSky(double detx, double dety, double focal_length,
    double time, NewAttitude& ac, const Geometry& absorber_geometry) {

  SixtePoint detpos_3d(detx, dety, 0.);
  auto fpos = absorber_geometry.transformDetToFocal(detpos_3d);

  double fx = fpos.x();
  double fy = fpos.y();

  Telescope_attitude telescope_attitude = ac.getTelescopeAxes(time);

  // Determine the source position on the sky using the telescope
  // axis pointing vector and a vector from the point of the intersection
  // of the optical axis with the sky plane to the source position.
  SixteVector srcpos;
  srcpos = telescope_attitude.nz + (fx / focal_length * telescope_attitude.nx) + (fy / focal_length * telescope_attitude.ny);
  srcpos = normalizeVector(srcpos);
  
  // (RA and DEC are in the range [-pi:pi] and [-pi/2:pi/2] respectively).
  return calculateRaDec(srcpos);
}

std::vector<Rectangle2d> removeOverlaps(const std::vector<Rectangle2d> &pixels) {
  std::vector<Rectangle2d> out;

  // if pixels overlap, keep only the newest one
  for (size_t i_pix = 0; i_pix < pixels.size(); i_pix++) {
    // check if any pixels down the line overlap with this one
    auto thispix = pixels[i_pix];
    if (std::none_of(pixels.begin()+i_pix+1, pixels.end(),
        [&thispix](Rectangle2d pix) {return pix.doOverlap(thispix);}
        ))
    {
      out.push_back(thispix);
    }
  }

  return out;
}

RectangularArray::RectangularArray(XMLData& xml_data) {
  // Build a regular xwidth x ywidth pixel array from the detector XML.
  //
  // The parameters are read from <detector><geometry>. If there is no
  // <geometry> node, they are read directly from <detector> (backward
  // compatibility):
  //   <dimensions xwidth ywidth>     number of pixels along x and y
  //   <wcs xrpix yrpix xdelt ydelt>  reference pixel and pixel pitch
  //   <pixelborder x y>              optional; getXY() does not assign
  //                                  positions this close to a pixel edge
  auto detector_node = xml_data.child("detector");
  auto geom_node = detector_node.optionalChild("geometry").value_or(detector_node);

  auto dims_node = geom_node.child("dimensions");
  xwidth_ = dims_node.attributeAsInt("xwidth");
  ywidth_ = dims_node.attributeAsInt("ywidth");

  auto wcs_node = geom_node.child("wcs");
  xrpix_ = wcs_node.attributeAsDouble("xrpix");
  yrpix_ = wcs_node.attributeAsDouble("yrpix");
  xdelt_ = wcs_node.attributeAsDouble("xdelt");
  ydelt_ = wcs_node.attributeAsDouble("ydelt");

  // The <pixelborder> node is optional, but once present both attributes are required.
  auto border_node = geom_node.optionalChild("pixelborder");
  if (border_node) {
    xborder_ = border_node->attributeAsDouble("x");
    yborder_ = border_node->attributeAsDouble("y");
  }

  // Offset (in pixels) between the lower edge of pixel 0 and the reference pixel.
  const double x_shift = xrpix_ - 0.5;
  const double y_shift = yrpix_ - 0.5;

  // Prototype pixel: the rectangle of pixel (0, 0) in focal-plane coordinates.
  // getPolygon() translates it by whole pixel pitches to any other pixel.
  Point_2 proto_lo(- xdelt_ * x_shift, - ydelt_ * y_shift);
  Point_2 proto_hi(- xdelt_ * (xrpix_ - 1.5), - ydelt_ * (yrpix_ - 1.5));
  rectangle_ = Rectangle2d(proto_lo, proto_hi);

  total_surface_area_ = xwidth_ * xdelt_ * ywidth_ * ydelt_;

  // Outer edges of the complete array.
  Point_2 array_lo((0. - x_shift) * xdelt_, (0. - y_shift) * ydelt_);
  Point_2 array_hi((xwidth_ - x_shift) * xdelt_, (ywidth_ - y_shift) * ydelt_);
  bounding_box_ = {array_lo, array_hi};
}

std::optional<std::pair<size_t, size_t>> RectangularArray::getXY(const Point_2& position) const {
  // Map a focal-plane position to integer pixel indices (x, y).
  //
  // Pixel i along x spans [(i - xrpix_ + 0.5) * xdelt_, (i - xrpix_ + 1.5) * xdelt_],
  // and likewise along y. Returns std::nullopt if
  //   - the position lies outside the xwidth_ x ywidth_ array (or is NaN/inf), or
  //   - pixel borders are configured and the position is closer than
  //     xborder_ / yborder_ to an edge of its pixel.
  // TODO check if this works without xrval
  const double xd = position.x();
  const double yd = position.y();

  // Real-valued pixel indices: pixel i covers [i, i+1).
  const double xb = xd / xdelt_ + (xrpix_ - 0.5);
  const double yb = yd / ydelt_ + (yrpix_ - 0.5);

  // floor() maps e.g. -0.5 to -1, so positions just below/left of the array
  // are rejected. Going through double also avoids an int overflow on far-off
  // positions.
  const double xf = std::floor(xb);
  const double yf = std::floor(yb);

  // Check if this is a valid pixel. The negated form also rejects NaN.
  if (!(xf >= 0. && xf < xwidth_ && yf >= 0. && yf < ywidth_)) {
    return std::nullopt;
  }

  // Check if the impact is located on one of the pixel borders.
  if (xborder_ > 0. || yborder_ > 0.) {
    const double x_left  = (xf - xrpix_ + 0.5) * xdelt_;
    const double x_right = (xf - xrpix_ + 1.5) * xdelt_;
    if ((x_right - xd < xborder_) || (xd - x_left < xborder_)) {
      return std::nullopt;
    }

    const double y_bottom = (yf - yrpix_ + 0.5) * ydelt_;
    const double y_top    = (yf - yrpix_ + 1.5) * ydelt_;
    if ((y_top - yd < yborder_) || (yd - y_bottom < yborder_)) {
      return std::nullopt;
    }
  }

  return std::make_pair(static_cast<size_t>(xf), static_cast<size_t>(yf));
}

std::optional<T_PixId> RectangularArray::getPixId(const Point_2& position) const {
  // Pixel ID of the pixel containing `position`. Yields std::nullopt whenever
  // getXY() does not map the position to a pixel.
  const auto indices = getXY(position);
  if (!indices) {
    return std::nullopt;
  }
  return xy2PixId(indices->first, indices->second);
}

std::vector<T_PixId> RectangularArray::getPixIds(const BoundingBox2d& box) const {
  std::vector<T_PixId> pix_ids;
  std::optional<T_PixId> pix_id;

  auto xy_top_left = getXY(box.top_left());
  auto xy_top_right = getXY(box.top_right());
  auto xy_bottom_left = getXY(box.bottom_left());
  auto xy_bottom_right = getXY(box.bottom_right());

  size_t x_min, x_max; // x is counted from left to right
  size_t y_min, y_max; // y is counted from bottom to top (!!!)

  if (xy_top_left) {
    x_min = xy_top_left->first;
    y_max = xy_top_left->second;
  } else if (xy_top_right) {
    x_min = 0;
    y_max = xy_top_right->second;
  } else if (xy_bottom_left) {
    x_min = xy_bottom_left->first;
    y_max = 0;
  } else if (xy_bottom_right) {
    x_min = 0;
    y_max = 0;
  } else {
    return pix_ids;
  }

  if (xy_bottom_right) {
    x_max = xy_bottom_right->first;
    y_min = xy_bottom_right->second;
  } else if (xy_bottom_left) {
    x_max = xwidth_ - 1;
    y_min = xy_bottom_left->second;
  } else if (xy_top_right) {
    x_max = xy_top_right->first;
    y_min = ywidth_ - 1;
  } else if (xy_top_left) {
    x_max = xwidth_ - 1;
    y_min = ywidth_ - 1;
  } else {
    return pix_ids;
  }

  for (size_t xi = x_min; xi <= x_max; xi++) {
    for (size_t yi = y_min; yi <= y_max; yi++) {
      pix_id = xy2PixId(xi, yi);
      if (!pix_id.has_value()) throw SixteException ("Invalid PixID in BoundingBox");
      if (std::find(std::begin(pix_ids), std::end(pix_ids), pix_id.value()) == std::end(pix_ids)) {
        pix_ids.push_back(pix_id.value());
      }
    }
  }

  return pix_ids;
}

T_PixId RectangularArray::xy2PixId(unsigned int xi, unsigned int yi) const {
  // Row-major numbering: all rows below yi (xwidth_ pixels each) come first,
  // then the column offset xi within row yi. PixId2xy() is the inverse.
  auto row_start = yi * xwidth_;
  return row_start + xi;
}

T_PixId RectangularArray::getRandomPixId() const {
  // Draw a random pixel by picking a column and a row independently from
  // getUniformRandomNumber(). The column is drawn first: the order of the two
  // RNG calls determines the result, so the two draws are kept as separate
  // statements (function-argument evaluation order is unspecified).
  const int column = static_cast<int>(getUniformRandomNumber() * xwidth_);
  const int row = static_cast<int>(getUniformRandomNumber() * ywidth_);

  return xy2PixId(column, row);

  // TODO (after verification): Would be more efficient, but gives different
  //  results due to RNG calls:
  //return (int) (getUniformRandomNumber() * static_cast<double>(numpix()));
}

std::pair<double,double> RectangularArray::getRandPosInPixel(T_PixId pix_id) const {
  // Random focal-plane position inside pixel `pix_id`.
  // Pixel i along x starts at (i - xrpix_ + 0.5) * xdelt_. A random fraction
  // (from getUniformRandomNumber(), presumably in [0, 1)) of one pitch is added
  // on top. The x fraction is drawn before the y fraction.
  const auto indices = PixId2xy(pix_id);

  const double frac_x = getUniformRandomNumber();
  const double frac_y = getUniformRandomNumber();

  const double detx = (static_cast<double>(indices.first) - xrpix_ + 0.5 + frac_x) * xdelt_;
  const double dety = (static_cast<double>(indices.second) - yrpix_ + 0.5 + frac_y) * ydelt_;

  return {detx, dety};
}

double RectangularArray::totalSurfaceArea() const {
  // Area of the full array (xwidth_ * xdelt_ * ywidth_ * ydelt_), fixed at construction.
  const double area = total_surface_area_;
  return area;
}

std::pair<unsigned int, unsigned int> RectangularArray::PixId2xy(T_PixId pix_id) const {
  // Inverse of xy2PixId(): the column is the remainder, and the row the
  // quotient, of the ID divided by the row length xwidth_.
  const unsigned int column = pix_id % xwidth_;
  const unsigned int row = (pix_id - column) / xwidth_;
  return {column, row};
}

Rectangle2d RectangularArray::getPolygon(T_PixId pix_id) const {
  // Outline of pixel `pix_id`: the prototype pixel (0, 0) shifted by whole
  // pixel pitches along x and y.
  const auto [column, row] = PixId2xy(pix_id);
  const Vector_2 shift(column * xdelt_, row * ydelt_);
  return translate(rectangle_, shift);
}

size_t RectangularArray::numpix() const {
  // Total number of pixels in the array.
  //
  // This is computed on every call on purpose. The previous function-local
  // `static` was initialised only once and then shared by every
  // RectangularArray instance, so each array after the first reported the
  // first array's pixel count.
  return static_cast<size_t>(xwidth_) * static_cast<size_t>(ywidth_);
}

int RectangularArray::getXWidth() const {
  return xwidth_;
}

int RectangularArray::getYWidth() const {
  return ywidth_;
}

void RectangularArray::doPhotonProjection(NewAttitude& ac,
                                          const Geometry& absorber_geometry,
                                          double focal_length,
                                          fitsfile *fptr,
                                          double tstart,
                                          double tstop) const {
  // Fill the RA/DEC columns (in degrees) of the EVENTS extension of `fptr`.
  //
  // For every row with tstart <= TIME <= tstop:
  //   - a random position inside pixel (RAWX, RAWY) is drawn,
  //   - the position is projected to the sky with detToSky(),
  //   - the result is written back to the same row.
  // Rows with TIME < tstart are skipped. Processing stops at the first row with
  // TIME > tstop, which assumes the rows are sorted by TIME.
  //
  // Throws (via checkStatusThrow) on any CFITSIO error, and SixteException if
  // RAWX/RAWY lie outside the xwidth_ x ywidth_ array.
  //ToDo: use CCFits
  int status = EXIT_SUCCESS;

  // need to do this copy business because the fits routines don't use
  // const char* - so I can't just use inline strings...
  char buffer[MAXMSG];

  strncpy(buffer, "EVENTS", 7);
  fits_movnam_hdu(fptr, BINARY_TBL, buffer, 1, &status);
  checkStatusThrow(status, "Failed to find EVENTS extension for photon projection");

  int c_rawx, c_rawy, c_time, c_ra, c_dec;

  strncpy(buffer, "RAWX", 5);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_rawx, &status);
  checkStatusThrow(status, "Failed to find RAWX column for photon projection");

  strncpy(buffer, "RAWY", 5);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_rawy, &status);
  checkStatusThrow(status, "Failed to find RAWY column for photon projection");

  strncpy(buffer, "TIME", 5);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_time, &status);
  checkStatusThrow(status, "Failed to find TIME column for photon projection");

  strncpy(buffer, "RA", 3);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_ra, &status);
  checkStatusThrow(status, "Failed to find RA column for photon projection");

  strncpy(buffer, "DEC", 4);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_dec, &status);
  checkStatusThrow(status, "Failed to find DEC column for photon projection");

  long numrows;
  fits_get_num_rows(fptr, &numrows, &status);
  checkStatusThrow(status, "Failed to determine number of rows in FITS file");

  // The variable types match the CFITSIO datatype codes used below (TUINT, TDOUBLE).
  unsigned int rawx, rawy;
  double time;

  for (long row = 0; row < numrows; row++) {

    fits_read_col(fptr, TUINT, c_rawx, row+1, 1, 1, NULL, &rawx, NULL, &status);
    checkStatusThrow(status, "Failed to read RAWX column");
    fits_read_col(fptr, TUINT, c_rawy, row+1, 1, 1, NULL, &rawy, NULL, &status);
    checkStatusThrow(status, "Failed to read RAWY column");
    fits_read_col(fptr, TDOUBLE, c_time, row+1, 1, 1, NULL, &time, NULL, &status);
    checkStatusThrow(status, "Failed to read TIME column");

    // Check whether we are still within the requested time interval.
    if (time < tstart) continue;
    if (time > tstop) break;

    // Out-of-range raw coordinates would otherwise be silently folded by
    // xy2PixId() into a different pixel, giving a wrong sky position.
    if (rawx >= static_cast<unsigned int>(xwidth_) ||
        rawy >= static_cast<unsigned int>(ywidth_)) {
      throw SixteException("RAWX/RAWY (" + std::to_string(rawx) + ", "
          + std::to_string(rawy) + ") in row " + std::to_string(row + 1)
          + " lies outside of the " + std::to_string(xwidth_) + "x"
          + std::to_string(ywidth_) + " pixel array");
    }

    // determine position within pixel
    // TODO could also use SUBX and SUBY columns here, if they exist
    auto [detx, dety] = getRandPosInPixel(xy2PixId(rawx, rawy));

    // get sky coordinates and write
    auto [ra, dec] = detToSky(detx, dety, focal_length, time, ac, absorber_geometry);

    // detToSky() returns radians; the FITS columns hold degrees.
    ra *= 180/M_PI;
    dec *= 180/M_PI;

    fits_write_col(fptr, TDOUBLE, c_ra, row+1, 1, 1, &ra, &status);
    checkStatusThrow(status, "Failed to write RA column");
    fits_write_col(fptr, TDOUBLE, c_dec, row+1, 1, 1, &dec, &status);
    checkStatusThrow(status, "Failed to write DEC column");
  }
}

const BoundingBox2d& RectangularArray::boundingBox() const {
  // Outer extent of the whole array in focal-plane coordinates, set in the constructor.
  const BoundingBox2d& extent = bounding_box_;
  return extent;
}

FreeGeometry::FreeGeometry(XMLData& xml_data) {
  // Build a free-form pixel geometry from
  //   <detector><geometry npix xoff yoff>
  //     <pixel><shape posx posy delx dely width height/></pixel> ...
  //   </geometry></detector>
  //
  // Each <shape> is an axis-aligned rectangle of size width x height, centred
  // on (posx*delx + xoff, posy*dely + yoff). Overlapping pixels are resolved
  // by removeOverlaps(), which keeps the later pixel. The number of pixels
  // left must equal the "npix" attribute.
  //
  // Throws SixteException if npix is negative or does not match the number of
  // pixels remaining after overlap removal.
  auto geometry = xml_data.child("detector").child("geometry");

  const auto npix_attr = geometry.attributeAsInt("npix");
  if (npix_attr < 0) {
    throw SixteException(
        "Number of pixels specified in geometry tag must not be negative ("
        + std::to_string(npix_attr) + ")");
  }
  const size_t npix = static_cast<size_t>(npix_attr);
  double xoff = geometry.attributeAsDouble("xoff");
  double yoff = geometry.attributeAsDouble("yoff");

  std::vector<Rectangle2d> pixels_read;
  pixels_read.reserve(npix);

  // Record extrema to construct a bounding box. Initialised so that a geometry
  // without any <pixel> entries yields a well-defined degenerate box at the
  // origin instead of reading uninitialised values.
  double min_x = 0., min_y = 0., max_x = 0., max_y = 0.;
  bool extrema_set = false;

  for (auto pixnode: geometry.children("pixel")) {
    auto shapenode = pixnode.child("shape");
    double posx = shapenode.attributeAsDouble("posx");
    double posy = shapenode.attributeAsDouble("posy");

    double delx = shapenode.attributeAsDouble("delx");
    double dely = shapenode.attributeAsDouble("dely");

    double width = shapenode.attributeAsDouble("width");
    double height = shapenode.attributeAsDouble("height");

    double x_lo = posx*delx - width/2 + xoff;
    double x_hi = posx*delx + width/2 + xoff;
    double y_lo = posy*dely - height/2 + yoff;
    double y_hi = posy*dely + height/2 + yoff;

    Point_2 bottom_left(x_lo, y_lo);
    Point_2 top_right(x_hi, y_hi);

    pixels_read.emplace_back(bottom_left, top_right);

    if (!extrema_set) {
      min_x = x_lo;
      max_x = x_hi;
      min_y = y_lo;
      max_y = y_hi;
      extrema_set = true;
    } else {
      min_x = std::min(x_lo, min_x);
      max_x = std::max(x_hi, max_x);
      min_y = std::min(y_lo, min_y);
      max_y = std::max(y_hi, max_y);
    }
  }

  pixels_ = removeOverlaps(pixels_read);

  healog(0) << "Number of pixels after removing overlaps: "
    << pixels_.size() << "\n" << std::endl;

  if (pixels_.size() != npix) {
    throw SixteException(
        "Number of pixels specified in geometry tag ("
        + std::to_string(npix)
        + ") does not match number of pixels read from pixel tags ("
        + std::to_string(pixels_.size())
        + ")");
  }

  // Note: the extrema include pixels that removeOverlaps() discarded.
  Point_2 bottom_left(min_x, min_y);
  Point_2 top_right(max_x, max_y);

  bounding_box = {bottom_left, top_right};

  // Sum into a local accumulator so the result does not depend on how the
  // member was initialised before the constructor body ran.
  double area = 0.;
  for (const auto& pixel: pixels_) {
    area += pixel.height()*pixel.width();
  }
  total_surface_area_ = area;
}

T_PixId FreeGeometry::getRandomPixId() const {
  // Pick a pixel index by scaling one getUniformRandomNumber() draw
  // (presumably in [0, 1)) by the number of pixels and truncating.
  const double scaled = getUniformRandomNumber() * static_cast<double>(numpix());
  return static_cast<int>(scaled);
}

std::pair<double,double> FreeGeometry::getRandPosInPixel(T_PixId pix_id) const {
  // Random position inside the rectangle of pixel `pix_id`: the lower-left
  // corner plus a random fraction of the width (x) and height (y). The x
  // fraction is drawn first so the RNG sequence stays unchanged.
  // `pix_id` is not range-checked here.
  const double frac_x = getUniformRandomNumber();
  const double frac_y = getUniformRandomNumber();

  Rectangle2d pixel = pixels_[pix_id];
  const double detx = pixel.bottom_left().x() + pixel.width() * frac_x;
  const double dety = pixel.bottom_left().y() + pixel.height() * frac_y;

  return {detx, dety};
}

double FreeGeometry::totalSurfaceArea() const {
  // Sum of width * height over all pixels kept after overlap removal.
  const double area = total_surface_area_;
  return area;
}

std::optional<T_PixId> FreeGeometry::getPixId(const Point_2& position) const  {

  // first compare with the outer bounding box
  if (!bounding_box.containsPoint(position)) {
    return std::nullopt;
  }

  for (size_t ii=0; ii<numpix(); ii++) {
    if (pixels_[ii].containsPoint(position)) {
      return ii;
    }
  }

  return std::nullopt;
}


std::vector<T_PixId> FreeGeometry::getPixIds(const BoundingBox2d& box) const {

  std::vector<T_PixId> out {};

  for (size_t ii=0; ii<numpix(); ii++) {
    if (pixels_[ii].doOverlap(box)) {
      out.push_back(ii);
    }
  }
   return out;
}

Rectangle2d FreeGeometry::getPolygon(T_PixId pix_id) const {
  // Copy of the rectangle of pixel `pix_id` (no range check).
  const Rectangle2d& outline = pixels_[pix_id];
  return outline;
}


size_t FreeGeometry::numpix() const {
  // Number of pixels kept after overlap removal.
  const size_t count = pixels_.size();
  return count;
}

void FreeGeometry::doPhotonProjection(NewAttitude& ac, const Geometry& absorber_geometry,
    double focal_length, fitsfile *fptr, double tstart, double tstop) const {
  // Fill the RA/DEC (degrees) and DETX/DETY columns of the EVENTS extension.
  //
  // For every row with tstart <= TIME <= tstop, a random position inside the
  // pixel given by the 1-based PIXID column is drawn. It is projected to the
  // sky with detToSky(), and both the detector position and the sky position
  // are written back to the row. Rows with TIME < tstart are skipped.
  // Processing stops at the first row with TIME > tstop, which assumes the
  // rows are sorted by TIME.
  //
  // Throws (via checkStatusThrow) on any CFITSIO error, and SixteException if
  // a PIXID lies outside [1, numpix()].

  //ToDo: use CCFits
  int status = EXIT_SUCCESS;

  // need to do this copy business because the fits routines don't use
  // const char* - so I can't just use inline strings...
  char buffer[MAXMSG];

  strncpy(buffer, "EVENTS", 7);
  fits_movnam_hdu(fptr, BINARY_TBL, buffer, 1, &status);
  checkStatusThrow(status, "Failed to find EVENTS extension for photon projection");

  int c_pixid, c_time, c_ra, c_dec, c_detx, c_dety;

  strncpy(buffer, "PIXID", 6);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_pixid, &status);
  checkStatusThrow(status, "Failed to find PIXID column for photon projection");

  strncpy(buffer, "TIME", 5);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_time, &status);
  checkStatusThrow(status, "Failed to find TIME column for photon projection");

  strncpy(buffer, "RA", 3);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_ra, &status);
  checkStatusThrow(status, "Failed to find RA column for photon projection");

  strncpy(buffer, "DEC", 4);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_dec, &status);
  checkStatusThrow(status, "Failed to find DEC column for photon projection");

  strncpy(buffer, "DETX", 5);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_detx, &status);
  checkStatusThrow(status, "Failed to find DETX column for photon projection");

  strncpy(buffer, "DETY", 5);
  fits_get_colnum(fptr, CASEINSEN, buffer, &c_dety, &status);
  checkStatusThrow(status, "Failed to find DETY column for photon projection");

  long numrows;
  fits_get_num_rows(fptr, &numrows, &status);
  checkStatusThrow(status, "Failed to determine number of rows in FITS file");

  // CFITSIO's TINT writes exactly an `int`. Reading into an `int` (rather than
  // directly into T_PixId) stays correct whatever type T_PixId is.
  int pixid_fits = 0;  // 1-based PIXID as stored in the FITS file
  double time;

  for (long row = 0; row < numrows; row++) {

    fits_read_col(fptr, TINT, c_pixid, row+1, 1, 1, NULL, &pixid_fits, NULL, &status);
    checkStatusThrow(status, "Failed to read PIXID column");
    fits_read_col(fptr, TDOUBLE, c_time, row+1, 1, 1, NULL, &time, NULL, &status);
    checkStatusThrow(status, "Failed to read TIME column");

    // Check whether we are still within the requested time interval.
    if (time < tstart) continue;
    if (time > tstop) break;

    // getRandPosInPixel() indexes pixels_ without a range check, so reject
    // invalid IDs here instead of reading out of bounds.
    if (pixid_fits < 1 || static_cast<size_t>(pixid_fits) > numpix()) {
      throw SixteException("PIXID " + std::to_string(pixid_fits) + " in row "
          + std::to_string(row + 1) + " is outside of the valid range [1, "
          + std::to_string(numpix()) + "]");
    }

    // determine position within pixel
    // PIXIDs are 1-based in FITS file!
    auto [detx, dety] = getRandPosInPixel(static_cast<T_PixId>(pixid_fits - 1));

    // get sky coordinates and write
    auto [ra, dec] = detToSky(detx, dety, focal_length, time, ac, absorber_geometry);
    // detToSky() returns radians; the FITS columns hold degrees.
    ra *= 180/M_PI;
    dec *= 180/M_PI;

    fits_write_col(fptr, TDOUBLE, c_ra, row+1, 1, 1, &ra, &status);
    checkStatusThrow(status, "Failed to write RA column");
    fits_write_col(fptr, TDOUBLE, c_dec, row+1, 1, 1, &dec, &status);
    checkStatusThrow(status, "Failed to write DEC column");
    fits_write_col(fptr, TDOUBLE, c_detx, row+1, 1, 1, &detx, &status);
    checkStatusThrow(status, "Failed to write DETX column");
    fits_write_col(fptr, TDOUBLE, c_dety, row+1, 1, 1, &dety, &status);
    checkStatusThrow(status, "Failed to write DETY column");
  }
}

const BoundingBox2d& FreeGeometry::boundingBox() const {
  // Box enclosing all pixel rectangles read in the constructor.
  const BoundingBox2d& extent = bounding_box;
  return extent;
}

std::unique_ptr<ArrayGeometry> createGeometry(XMLData& xml_data) {
  // Factory for the pixel geometry described in the detector XML.
  //
  // Uses the "type" attribute of <detector><geometry>: "rectarray" selects a
  // RectangularArray and "free" a FreeGeometry. Without a <geometry> node a
  // warning is printed and a RectangularArray is built (backward
  // compatibility). Any other type throws SixteException.
  // TODO: Use enum+switch
  auto detector = xml_data.child("detector");
  auto geometry = detector.optionalChild("geometry");

  if (!geometry) {
    printWarning("Geometry keyword not set in XML, will be set to default (rectangular array)");
    return std::make_unique<RectangularArray>(xml_data);
  }

  const auto type = geometry->attributeAsString("type");

  if (type == "rectarray") {
    return std::make_unique<RectangularArray>(xml_data);
  }
  if (type == "free") {
    return std::make_unique<FreeGeometry>(xml_data);
  }

  throw SixteException("Unknown geometry type '" + type + "'. Supported types: rectarray, free");
}

}
