/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#pragma once

#include "array"
#include "iostream"
#include "sixte_random.h"
#include "vector"
#include <cassert>
#include <functional>
#include <numeric>
#include <utility>

namespace sixte {

/** Stochastic interpolation class.

  This class contains data on a multidimensional grid along which stochastic
  interpolation is performed.
  Concretely, for a single evaluation, stochastic interpolation does not return
  an interpolated value but only a grid point which is near the requested
  point.
  The stochastic part here is that the chosen nearby grid point is random,
  weighted towards grid points that are closer to the requested point.

  As an example, in the one dimensional case:
  Say you have grid points 0 and 1 and are interpolating at x=0.3. In this case,
  for 70% of cases, this class would return point 0, and 30% of cases point 1.

  Note that this class does not extrapolate - interpolation below the lowest
  grid point always returns the lowest grid point, and interpolation above the
  highest grid point always returns the highest.
*/
template <size_t DIMEN, typename Contained> class StochInterp {
public:
  /** Constructor.
   * Simply allocates memory for the stored data - the grids need to be
   * defined via the set_grid method, and the data of each grid point via the
   * set_interpolation_point method.
   *
   * @param grid_sizes   Number of grid points for each dimension
   */
  StochInterp(std::array<size_t, DIMEN> grid_sizes);

  /** Constructor using a default value
   * Allocates memory for the stored data and sets every entry to the default
   * value. This is useful for classes which can not be default-constructed
   *
   * @param grid_sizes   Number of grid points for each dimension
   * @param default_val  Value that is copied into every grid point
   */
  StochInterp(std::array<size_t, DIMEN> grid_sizes, Contained &default_val);

  /** Reallocate memory for internal storage
   *
   * Resizes the data storage to match the new grid sizes and clears all
   * interpolation grids, so the grids have to be set again via set_grid.
   * This can be used when the dimensions of the interpolation are not yet clear
   * at time of construction
   *
   * @param grid_sizes   New number of grid points for each dimension
   */
  void reallocate(std::array<size_t, DIMEN> grid_sizes);

  /** Define interpolation grid
   *
   * @param dimen  Index of the dimension to assign
   * @param grid   Grid points to use; the number of points has to match the
   *               grid size of this dimension
   */
  void set_grid(size_t dimen, std::vector<double> &&grid);

  /** Assign a value to a grid point
   *
   * @param indices  Grid point indices, one per dimension
   * @param data     Data to assign
   */
  void set_interpolation_point(std::array<size_t, DIMEN> indices,
                               Contained &&data);

  /** Validate that an interpolation grid is fully defined
   *
   * @return   true, if a non-empty grid is defined for every dimension
   */
  bool validate() const;

  /** Perform stochastic interpolation to get an index in one dimension
   *
   * @param dimen       Dimension in which to interpolate
   * @param interp_val  Interpolation value
   *
   * @return            Interpolated index
   */
  size_t interpolate_1D(size_t dimen, double interp_val) const;

  /** Simultaneously perform interpolation for all dimensions
   *
   * @param interp_vals  Interpolation values, one per dimension
   *
   * @return            Interpolated indices
   */
  std::array<size_t, DIMEN>
  interpolate(const std::array<double, DIMEN> &interp_vals) const;
  // std::array<size_t, DIMEN>
  // interpolate(std::array<double, DIMEN> interp_vals) const;

  /** Retrieve contained data
   *
   * @param indices  Grid point indices, e.g. as returned by interpolate
   *
   * @return         Stored value at the given indices
   */
  const Contained &retrieve(const std::array<size_t, DIMEN> &indices) const;
  // const Contained &retrieve(std::array<size_t, DIMEN> indices) const;

  /** Get the interpolation grid of a given dimension
   *
   * @param dimen  Grid dimension
   *
   * @return       Grid vector
   */
  const std::vector<double>& get_grid(size_t dimen) const {
    assert(dimen < DIMEN);
    return interp_grids_[dimen];
  }

  /** Get the numeric value of a given interpolation point
   *
   * @param dimen  Grid dimension
   * @param idx    Grid point index
   *
   * @return       Grid value at given dimension and index
   */
  double get_interp_value(size_t dimen, size_t idx) const {
    assert(dimen < DIMEN);
    // also guard the point index, operator[] itself does no range check
    assert(idx < interp_grids_[dimen].size());
    return interp_grids_[dimen][idx];
  }

private:
  /// Number of grid points per dimension
  std::array<size_t, DIMEN> grid_sizes_;
  /// Grid values per dimension; empty until assigned via set_grid
  std::array<std::vector<double>, DIMEN> interp_grids_;
  /// Flattened storage of the data of all grid points
  std::vector<Contained> data_;

  /** Convert per-dimension grid indices into a position in the flat storage
   *
   * @param indices  Grid point indices, one per dimension
   *
   * @return         Index into the flat data storage
   */
  size_t convert_indices(const std::array<size_t, DIMEN> &indices) const;
};

/* IMPLEMENTATION */

// Construct an interpolator whose data points are value-initialized.
// The total number of data points is the product of all grid sizes.
template <size_t DIMEN, typename Contained>
StochInterp<DIMEN, Contained>::StochInterp(std::array<size_t, DIMEN> grid_sizes)
    : grid_sizes_(grid_sizes) {

  // The type of the initial value is the accumulator type of std::accumulate,
  // so it has to be a size_t: a plain `1` would squeeze the running product
  // through an int after every multiplication.
  const size_t num_pts =
      std::accumulate(grid_sizes.begin(), grid_sizes.end(), size_t{1},
                      std::multiplies<size_t>());

  data_.resize(num_pts);

  // interp_grids_ is left empty here, the grids are assigned via set_grid
}

// Construct an interpolator whose data points are all copies of default_val.
// The total number of data points is the product of all grid sizes.
template <size_t DIMEN, typename Contained>
StochInterp<DIMEN, Contained>::StochInterp(std::array<size_t, DIMEN> grid_sizes,
                                           Contained &default_val)
    : grid_sizes_(grid_sizes) {

  // The type of the initial value is the accumulator type of std::accumulate,
  // so it has to be a size_t: a plain `1` would squeeze the running product
  // through an int after every multiplication.
  const size_t num_pts =
      std::accumulate(grid_sizes.begin(), grid_sizes.end(), size_t{1},
                      std::multiplies<size_t>());

  // Fill by copy-construction only, so Contained does not need to be
  // default-constructible.
  data_.reserve(num_pts);
  for (size_t ii = 0; ii < num_pts; ii++) {
    data_.push_back(default_val);
  }

  // interp_grids_ is left empty here, the grids are assigned via set_grid
}

// Switch the interpolator to new grid sizes.
// All grids are cleared and all stored data is replaced by value-initialized
// entries, so both have to be filled again afterwards.
template <size_t DIMEN, typename Contained>
void StochInterp<DIMEN, Contained>::reallocate(std::array<size_t, DIMEN> grid_sizes) {
  grid_sizes_ = grid_sizes;

  // reset the interpolation grids
  for (auto &grid : interp_grids_) {
    grid.clear();
  }

  // The type of the initial value is the accumulator type of std::accumulate,
  // so it has to be a size_t: a plain `1` would squeeze the running product
  // through an int after every multiplication.
  const size_t num_pts =
      std::accumulate(grid_sizes.begin(), grid_sizes.end(), size_t{1},
                      std::multiplies<size_t>());

  // A bare resize() would keep the old entries. With changed grid sizes those
  // would end up at flat indices belonging to different grid points, so drop
  // them before allocating the new storage.
  data_.clear();
  data_.resize(num_pts);
}

// Assign the grid values of one dimension.
// The grid has to contain exactly as many points as were announced for this
// dimension when the storage was allocated.
template <size_t DIMEN, typename Contained>
void StochInterp<DIMEN, Contained>::set_grid(size_t dimen,
                                             std::vector<double> &&grid) {

  assert(dimen < DIMEN);
  assert(grid.size() == grid_sizes_[dimen]);

  // TODO validate that the grid is monotonically increasing

  // `grid` is a named variable and therefore an lvalue in here; without
  // std::move the vector would be copied instead of taken over.
  interp_grids_[dimen] = std::move(grid);
}

// Store the data belonging to one grid point.
// The grid point is addressed by one index per dimension.
template <size_t DIMEN, typename Contained>
void StochInterp<DIMEN, Contained>::set_interpolation_point(
    std::array<size_t, DIMEN> indices, Contained &&data) {

  const size_t idx = convert_indices(indices);

  assert(idx < data_.size());

  // `data` is a named variable and therefore an lvalue in here; without
  // std::move it would be copy-assigned, which is wasteful and does not even
  // compile for move-only types.
  data_[idx] = std::move(data);
}

// Check whether the interpolator is ready for use.
// Reports the first dimension without grid points on stdout and returns
// false in that case, otherwise returns true.
template <size_t DIMEN, typename Contained>
bool StochInterp<DIMEN, Contained>::validate() const {

  // every dimension needs at least one grid point
  size_t dim = 0;
  while (dim < DIMEN) {
    if (interp_grids_[dim].empty()) {
      std::cout << "StochInterp: Grid in dimension " << dim
                << " is not defined!" << std::endl;
      return false;
    }
    ++dim;
  }

  // TODO the stored data is not inspected: entries that still hold their
  // initial value might be checked here, but they are not necessarily an error

  return true;
}

// Randomly pick one of the two grid points enclosing interp_val in the given
// dimension. The probability of picking the upper point is the fractional
// position of interp_val between the two points.
// Values at or above the last grid point yield the last index; values below
// the first grid point yield index 0.
template <size_t DIMEN, typename Contained>
size_t StochInterp<DIMEN, Contained>::interpolate_1D(size_t dimen,
                                                     double interp_val) const {

  assert(dimen < DIMEN);

  const auto &thisgrid = interp_grids_[dimen];

  // a grid that was never set is a usage error, see validate()
  assert(!thisgrid.empty());

  // With fewer than two points there is nothing to choose between. Returning
  // here also keeps `size() - 1` below from wrapping around for an empty
  // grid, which would lead to out-of-bounds reads.
  if (thisgrid.size() < 2) {
    return 0;
  }
  const size_t last = thisgrid.size() - 1;

  // find the last grid point that is not larger than the interpolation value
  // (stays at 0 if the value lies below the whole grid)
  size_t index;
  for (index = 0; index < last; index++) {
    if (thisgrid[index + 1] > interp_val)
      break;
  }

  // perform the stochastic interpolation
  // core idea: if the interpolation value is between two grid points, you
  // should be more likely to pick the point you are closer to
  if (index < last) {
    const double rnd = getUniformRandomNumber(); // 0 to 1
    // fractional position between the two points; negative below the grid,
    // so that the lowest point is always kept in that case
    const double upper_weight = (interp_val - thisgrid[index]) /
                                (thisgrid[index + 1] - thisgrid[index]);
    if (rnd < upper_weight) {
      index++;
    }
  }
  return index;
}

// Run the one-dimensional stochastic interpolation for every dimension, in
// ascending order of the dimension index, and collect the chosen indices.
template <size_t DIMEN, typename Contained>
std::array<size_t, DIMEN> StochInterp<DIMEN, Contained>::interpolate(
    const std::array<double, DIMEN> &interp_vals) const {

  std::array<size_t, DIMEN> chosen;

  size_t axis = 0;
  for (const double value : interp_vals) {
    chosen[axis] = interpolate_1D(axis, value);
    ++axis;
  }

  return chosen;
}

/*
template <size_t DIMEN, typename Contained>
std::array<size_t, DIMEN> StochInterp<DIMEN, Contained>::interpolate(
    std::array<double, DIMEN> interp_vals) const {

  return interpolate(&interp_vals);
}
*/

// Look up the data stored for a grid point.
// The access is range-checked via vector::at, which throws std::out_of_range
// if the flat index lies outside of the storage.
template <size_t DIMEN, typename Contained>
const Contained &StochInterp<DIMEN, Contained>::retrieve(
    const std::array<size_t, DIMEN> &indices) const {
  const size_t flat_idx = convert_indices(indices);
  return data_.at(flat_idx);
}

/*
template <size_t DIMEN, typename Contained>
const Contained &StochInterp<DIMEN, Contained>::retrieve(
    std::array<size_t, DIMEN> indices) const {
  return data_.at(convert_indices(indices));
}
*/

// Map per-dimension grid indices onto a position in the flat data storage.
// Dimension 0 has stride 1; the stride of every further dimension is the
// product of the grid sizes of all dimensions before it.
template <size_t DIMEN, typename Contained>
size_t StochInterp<DIMEN, Contained>::convert_indices(
    const std::array<size_t, DIMEN> &indices) const {
  size_t flat = 0;
  size_t stride = 1;

  size_t dim = 0;
  for (const size_t pos : indices) {
    flat += pos * stride;
    stride *= grid_sizes_[dim];
    ++dim;
  }

  return flat;
}
} // namespace sixte
