/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>
#include <cmath>

#include "StochInterp.h"

using namespace sixte;


TEST_CASE("StochInterp test", "[StochInt]") {
  setenv("SIXTE_USE_PSEUDO_RNG","1",1);
  sixte::initSixteRng(1);
 
  SECTION("1D") {
    // setup
    std::vector<double> ingrid = {0,1};
    std::vector<double> gridvals = {0,10};

    StochInterp<1,double> interp({ingrid.size()});

    interp.set_grid(0,std::move(ingrid));

    for (size_t ii=0; ii<gridvals.size(); ii++) {
      double val = gridvals[ii];
      interp.set_interpolation_point({ii}, std::move(val));
    }

    // testing
    std::vector<std::array<double,1>> refpts{{0.7}, {0.2}, {0.5}, {-1}, {2}};
    std::vector<double> refvals = {7, 2, 5, 0, 10};

    for (size_t i_test=0; i_test < refpts.size(); i_test++) {
      size_t num_tries = 10000;
      double interp_val = 0;

      for (size_t ii=0; ii<num_tries; ii++) {
        auto idx = interp.interpolate(refpts[i_test]);
        double res = interp.retrieve(idx);
        interp_val += res;
      }

      interp_val /= num_tries;

      REQUIRE_THAT(interp_val, Catch::Matchers::WithinAbs(refvals[i_test], 1e-1));

    }

  }

  SECTION("2D") {
    // setup
    std::vector<double> ingrid1 = {0,1};
    std::vector<double> ingrid2 = {0,1};

    StochInterp<2,double> interp({ingrid1.size(),ingrid2.size()});

    interp.set_grid(0,std::move(ingrid1));
    interp.set_grid(1,std::move(ingrid2));

    // approximate a radial function
    interp.set_interpolation_point({0,0}, 0);
    interp.set_interpolation_point({0,1}, 1);
    interp.set_interpolation_point({1,0}, 1);
    interp.set_interpolation_point({1,1}, sqrt(2));

    // testing
    std::vector<std::array<double,2>> refpts{{0.5,0.5}, {0.8,0.8}, {0.5,0}, {0.7,0.3}, {0.3, 2}, {10,10}};
    std::vector<double> refvals = {0.8536, 1.225, 0.5, 0.877, 1.12426, sqrt(2)};

    for (size_t i_test=0; i_test < refpts.size(); i_test++) {
      size_t num_tries = 10000;
      double interp_val = 0;

      for (size_t ii=0; ii<num_tries; ii++) {
        auto idx = interp.interpolate(refpts[i_test]);
        double res = interp.retrieve(idx);
        interp_val += res;
      }

      interp_val /= num_tries;

      REQUIRE_THAT(interp_val, Catch::Matchers::WithinAbs(refvals[i_test], 1e-2));

    }

  }

}
