/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2022 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "PhotonInteractionStrategy.h"
#include "NewSIXT.h"
#include "sixte_random.h"
#include <map>
#include <cmath>
#include <algorithm>

namespace sixte {

/**
 * Turns a photon into at most one energy deposition by sampling the RMF.
 *
 * The photon position is transformed into detector coordinates. If it falls
 * outside the absorber geometry, nothing is deposited. Otherwise a channel is
 * drawn from the RMF for the photon energy. A channel below the RMF's first
 * channel means "not detected"; any other channel produces one deposition
 * carrying that channel's EBOUNDS energy.
 *
 * @param absorber_geometry Geometry used for the focal->detector transform
 *                          and the containment test.
 * @param photon            Incoming photon.
 * @return An empty container, or a container with exactly one deposition.
 */
EnergyDepositions RMFBasedInteraction::doInteraction(const Geometry& absorber_geometry,
                                                     const SixtePhoton &photon) const {
  // TODO: detector_position should be named focal_position, shouldn't it?
  // TODO: General: Do transformation to detector coordinates here?
  auto det_position = absorber_geometry.transformFocalToDet(*(photon.detector_position()));

  if (absorber_geometry.contains(det_position)) {
    auto sampled_channel = rmf_->sampleChannel(photon.energy());

    // The RMF reports "not detected" with a channel below its first channel
    // (e.g. -1); only channels at or above it are real measurements.
    if (sampled_channel >= rmf_->firstChannel()) {
      double measured_energy = rmf_->sixteGetEBOUNDSEnergy(sampled_channel);

      // NOTE(review): the deposition carries the focal-plane position, not
      // det_position; the other strategies in this file do the same.
      return EnergyDepositions{ EnergyDeposition(measured_energy, *(photon.detector_position()),
                                                 photon.photon_metainfo(), photon.time()) };
    }
  }

  return EnergyDepositions{};
}

/*
class ScatterPhys: public PhotonInteractionStrategy {
};

class Thermalization: public PhotonInteractionStrategy {
};
*/


RMFBasedInteraction::RMFBasedInteraction(std::shared_ptr<const NewRMF> rmf)
  : rmf_(std::move(rmf)),
    canonical_ebounds_(std::make_shared<Ebounds>(rmf_->exportEbounds())) {}

/**
 * Deposits the complete photon energy at the photon's position.
 *
 * The absorber geometry is ignored: no containment test is performed. The
 * result always holds exactly one deposition.
 *
 * @param photon Incoming photon.
 * @return A container with a single deposition.
 */
EnergyDepositions FullAbsorption::doInteraction(
    const Geometry& /*absorber_geometry*/,
    const SixtePhoton &photon) const {
  const auto deposited_energy = photon.energy();
  return EnergyDepositions{
    EnergyDeposition(deposited_energy, *photon.detector_position(),
                     photon.photon_metainfo(), photon.time())
  };
}

/// Sorts @p vec ascending in place and removes duplicate values (compared
/// with ==), so each distinct value remains exactly once.
static void uniqueSorted(std::vector<double>& vec) {
  std::sort(std::begin(vec), std::end(vec));
  // std::unique compacts adjacent equal values to the front and returns the
  // new logical end; everything after it is discarded.
  const auto new_end = std::unique(std::begin(vec), std::end(vec));
  vec.erase(new_end, std::end(vec));
}

/**
 * Folds an angle in degrees into the range [0, 45].
 *
 * The angle is first reduced modulo 90 into [0, 90). Values above 45 are then
 * mirrored as 90 - x, so e.g. 80 -> 10, 100 -> 10 and -10 -> 10.
 */
double AngleDependentRMFBasedInteraction::foldPhiFourfold(double phi_deg) {
  // std::fmod keeps the sign of the dividend; shift negatives into range.
  const double remainder = std::fmod(phi_deg, 90.0);
  const double in_sector = (remainder < 0) ? remainder + 90.0 : remainder;

  // Reflect the upper half of the 90-degree sector onto the lower half.
  if (in_sector > 45.0) {
    return 90.0 - in_sector;
  }
  return in_sector;
}

/**
 * Verifies that all RMFs in @p nodes share one channel layout and EBOUNDS.
 *
 * The first RMF is the reference. Every other RMF must have the same number
 * of channels and the same first channel. For each channel, its lower and
 * upper EBOUNDS energies must agree with the reference according to
 * nearEqual(..., 1e-5, 1e-6).
 *
 * @param nodes RMFs to compare.
 * @throws SixteException if @p nodes is empty, contains a null RMF, or any
 *         RMF differs from the reference.
 */
void AngleDependentRMFBasedInteraction::validateEboundsUniform(
  const std::vector<std::shared_ptr<const NewRMF>>& nodes) {

  // nodes.front() on an empty vector would be undefined behaviour.
  if (nodes.empty()) {
    throw SixteException("AngleDependentRMF: no RMFs to validate");
  }
  if (std::any_of(nodes.begin(), nodes.end(),
                  [](const std::shared_ptr<const NewRMF>& node) { return !node; })) {
    throw SixteException("AngleDependentRMF: null RMF in grid");
  }

  const auto& ref = nodes.front();
  const auto nchan = ref->numberChannels();
  const auto first_channel = ref->firstChannel();

  // The reference trivially matches itself, so comparison starts at index 1.
  for (std::size_t node_index = 1; node_index < nodes.size(); ++node_index) {
    const auto& node = nodes[node_index];
    if (node->numberChannels() != nchan || node->firstChannel() != first_channel) {
      throw SixteException("AngleDependentRMF: channel layout differs across RMFs");
    }
    for (long ii = 0; ii < static_cast<long>(nchan); ++ii) {
      const long channel = first_channel + ii;
      auto [lo_ref, hi_ref] = ref->sixteGetEBOUNDSEnergyLoHi(channel);
      auto [lo_cur, hi_cur] = node->sixteGetEBOUNDSEnergyLoHi(channel);
      if (!nearEqual(lo_ref, lo_cur, 1e-5, 1e-6) || !nearEqual(hi_ref, hi_cur, 1e-5, 1e-6)) {
        throw SixteException("AngleDependentRMF: EBOUNDS mismatch across RMFs");
      }
    }
  }
}

/**
 * Builds an angle-dependent RMF strategy from the XML section
 * <detector><rmf_angle_grid><rmf theta=".." phi=".." filename=".."/>...
 *
 * Each phi is folded with foldPhiFourfold before use. The resulting
 * (theta, folded phi) pairs must be unique and form a complete rectangular
 * grid. Every RMF is loaded through @p rmf_registry, and all RMFs are checked
 * for identical EBOUNDS. They are then stored in the interpolator at index
 * {theta_index, phi_index}, with both axes sorted ascending.
 *
 * @param xml_data     Parsed XML configuration; RMF file names are prefixed
 *                     with xml_data.dirname().
 * @param rmf_registry Registry used to load and hand out the RMFs.
 * @throws SixteException if the registry is null, the grid is empty, grid
 *         points are duplicated or missing, or EBOUNDS differ between RMFs.
 */
AngleDependentRMFBasedInteraction::AngleDependentRMFBasedInteraction(const XMLData& xml_data,
                                                                     const std::shared_ptr<RmfRegistry>& rmf_registry) {
  if (!rmf_registry) {
    throw SixteException("AngleDependentRMF: RMF registry must not be null");
  }

  const auto grid = xml_data.child("detector").child("rmf_angle_grid");

  // (theta, folded phi) -> full RMF path. The map both rejects duplicates and
  // serves as the lookup table when the rectangular grid is filled below.
  std::map<std::pair<double,double>, std::string> angle_to_path;

  for (const auto& rmf : grid.children("rmf")) {
    double theta = rmf.attributeAsDouble("theta");
    double phi_folded = foldPhiFourfold(rmf.attributeAsDouble("phi"));
    auto key = std::make_pair(theta, phi_folded);
    auto inserted = angle_to_path.emplace(key, xml_data.dirname() + rmf.attributeAsString("filename"));
    if (!inserted.second) {
      throw SixteException("AngleDependentRMF: duplicate entry for theta=" +
                           std::to_string(theta) + " phi=" + std::to_string(phi_folded) + " after folding");
    }
  }

  // Without any entry, nodes would be empty and nodes.front() (here and in
  // validateEboundsUniform) would be undefined behaviour.
  if (angle_to_path.empty()) {
    throw SixteException("AngleDependentRMF: no <rmf> entries found in <rmf_angle_grid>");
  }

  // Distinct, ascending axis values of the grid.
  std::vector<double> theta_axis, phi_axis;
  theta_axis.reserve(angle_to_path.size());
  phi_axis.reserve(angle_to_path.size());
  for (const auto& entry : angle_to_path) {
    const auto& angles = entry.first;
    theta_axis.push_back(angles.first);
    phi_axis.push_back(angles.second);
  }
  uniqueSorted(theta_axis);
  uniqueSorted(phi_axis);

  // Resize stochastic interpolator to the rectangular grid we are about to fill.
  interp_.reallocate({theta_axis.size(), phi_axis.size()});

  // Load RMFs in phi-major order; the fill loop below uses the same order.
  std::vector<std::shared_ptr<const NewRMF>> nodes;
  nodes.reserve(theta_axis.size() * phi_axis.size());
  for (double phi_folded : phi_axis) for (double theta : theta_axis) {
    auto it = angle_to_path.find({theta, phi_folded});
    if (it == angle_to_path.end()) {
      throw SixteException("AngleDependentRMF: rectangular grid required; missing theta=" +
                           std::to_string(theta) + " phi=" + std::to_string(phi_folded));
    }
    rmf_registry->load(it->second);
    nodes.emplace_back(rmf_registry->get(it->second));
  }

  validateEboundsUniform(nodes);
  canonical_ebounds_ = std::make_shared<Ebounds>(nodes.front()->exportEbounds());

  // Capture the sizes before the axes are moved into the interpolator.
  const size_t n_theta = theta_axis.size();
  const size_t n_phi = phi_axis.size();

  interp_.set_grid(0, std::move(theta_axis));
  interp_.set_grid(1, std::move(phi_axis));
  size_t idx = 0;
  for (size_t phi_index = 0; phi_index < n_phi; ++phi_index) {
    for (size_t theta_index = 0; theta_index < n_theta; ++theta_index) {
      interp_.set_interpolation_point({theta_index, phi_index}, std::move(nodes[idx++]));
    }
  }
}

/**
 * Turns a photon into at most one energy deposition, using an RMF selected
 * by the photon's incident angles.
 *
 * Photons whose detector-coordinate position lies outside the absorber
 * produce nothing. Otherwise the incident angles are read from the photon
 * metainfo, and phi is folded with foldPhiFourfold. The interpolator then
 * selects one of the grid RMFs for (theta, folded phi), and a channel is
 * sampled from it. A channel below that RMF's first channel means "not
 * detected".
 *
 * @param absorber_geometry Geometry used for the focal->detector transform
 *                          and the containment test.
 * @param photon            Incoming photon; must carry incident angles.
 * @return An empty container, or a container with exactly one deposition.
 * @throws SixteException if the photon has no incident angles.
 */
EnergyDepositions AngleDependentRMFBasedInteraction::doInteraction(const Geometry& absorber_geometry,
                                                                   const SixtePhoton& photon) const {
  auto det_position = absorber_geometry.transformFocalToDet(*(photon.detector_position()));
  if (!absorber_geometry.contains(det_position)) {
    return EnergyDepositions{};
  }

  const auto angles = photon.photon_metainfo_ref().incidentAnglesDeg();
  if (!angles) {
    throw SixteException("AngleDependentRMF: incident angles missing");
  }

  // Only phi is folded; theta goes to the interpolator unchanged.
  const double theta = angles->first;
  const double phi = foldPhiFourfold(angles->second);

  const auto node_index = interp_.interpolate({theta, phi});
  const auto& node_rmf = interp_.retrieve(node_index);

  const long sampled_channel = node_rmf->sampleChannel(photon.energy());
  if (!(sampled_channel >= node_rmf->firstChannel())) {
    return EnergyDepositions{};
  }

  const double measured_energy = node_rmf->sixteGetEBOUNDSEnergy(sampled_channel);
  return EnergyDepositions{ EnergyDeposition(measured_energy, *(photon.detector_position()),
                                             photon.photon_metainfo(), photon.time()) };
}

}
