/*
This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "InverseTransformSampler.h"

#include <algorithm>
#include <cmath>
#include <fstream>
#include <iterator>
#include <limits>
#include <map>
#include <numeric>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "sixte_random.h"

namespace sixte {

// Constructs a sampler from a data file.
//
// The file is read by initializeFromFile(), which parses (x, weight) pairs
// and forwards them to the discrete-points setup; the input type is therefore
// fixed to discrete_points. After loading, validateInput() checks the
// configuration and computeTruncatedBounds() caches the CDF range selected by
// the truncation bounds.
//
// filename: path of the data file to read.
// mode:     sampling mode stored in sampling_mode_.
// bounds:   truncation bounds stored in truncation_bounds_.
//
// Any SixteException raised by the three setup steps propagates to the caller.
InverseTransformSampler::InverseTransformSampler(const std::string& filename,
                                                 SamplingMode mode,
                                                 const TruncationBounds& bounds)
    : sampling_mode_(mode),
      input_type_(InputType::discrete_points),
      truncation_bounds_(bounds) {
  initializeFromFile(filename);
  validateInput();
  computeTruncatedBounds();
}

// Constructs a sampler from explicit support points and their weights.
//
// x_values and weights are handed to initializeFromDiscretePoints(); the
// input type is discrete_points. Afterwards validateInput() checks the
// configuration and computeTruncatedBounds() caches the CDF range selected by
// the truncation bounds.
//
// x_values: support points of the distribution.
// weights:  one weight per support point.
// mode:     sampling mode stored in sampling_mode_.
// bounds:   truncation bounds stored in truncation_bounds_.
//
// Any SixteException raised by the three setup steps propagates to the caller.
InverseTransformSampler::InverseTransformSampler(
    const std::vector<double>& x_values, const std::vector<double>& weights,
    SamplingMode mode, const TruncationBounds& bounds)
    : sampling_mode_(mode),
      input_type_(InputType::discrete_points),
      truncation_bounds_(bounds) {
  initializeFromDiscretePoints(x_values, weights);
  validateInput();
  computeTruncatedBounds();
}

// Constructs a sampler from a set of observed sample values.
//
// sample_data is handed to initializeFromEmpiricalData(), which stores a
// sorted copy; the input type is empirical_data. Afterwards validateInput()
// checks the configuration and computeTruncatedBounds() caches the CDF range
// selected by the truncation bounds.
//
// sample_data: observed values defining the empirical distribution.
// mode:        sampling mode stored in sampling_mode_.
// bounds:      truncation bounds stored in truncation_bounds_.
//
// Any SixteException raised by the three setup steps propagates to the caller.
InverseTransformSampler::InverseTransformSampler(
    const std::vector<double>& sample_data, SamplingMode mode,
    const TruncationBounds& bounds)
    : sampling_mode_(mode),
      input_type_(InputType::empirical_data),
      truncation_bounds_(bounds) {
  initializeFromEmpiricalData(sample_data);
  validateInput();
  computeTruncatedBounds();
}

// Constructs a sampler from weighted bins.
//
// The bin edges and weights are handed to initializeFromBins(); the input
// type is bins. Afterwards validateInput() checks the configuration and
// computeTruncatedBounds() caches the CDF range selected by the truncation
// bounds.
//
// lower_bins: lower edge of each bin.
// upper_bins: upper edge of each bin.
// weights:    one weight per bin.
// mode:       sampling mode stored in sampling_mode_.
// bounds:     truncation bounds stored in truncation_bounds_.
//
// Any SixteException raised by the three setup steps propagates to the caller.
InverseTransformSampler::InverseTransformSampler(
    const std::vector<double>& lower_bins,
    const std::vector<double>& upper_bins, const std::vector<double>& weights,
    SamplingMode mode, const TruncationBounds& bounds)
    : sampling_mode_(mode),
      input_type_(InputType::bins),
      truncation_bounds_(bounds) {
  initializeFromBins(lower_bins, upper_bins, weights);
  validateInput();
  computeTruncatedBounds();
}

// Draws one random value from the distribution.
//
// Dispatches to the sampling routine matching the kind of input the sampler
// was built from. Throws SixteException if the stored input type is not one
// of the known kinds.
double InverseTransformSampler::sample() const {
  if (input_type_ == InputType::discrete_points) {
    return sampleFromDiscretePoints();
  }
  if (input_type_ == InputType::empirical_data) {
    return sampleFromEmpiricalData();
  }
  if (input_type_ == InputType::bins) {
    return sampleFromBins();
  }
  throw SixteException("Unknown input type");
}

// Builds the discrete-point representation (sorted support points plus their
// cumulative probabilities) from x values and weights.
//
// x_values: support points; must all be finite.
// weights:  one finite, non-negative weight per support point; they are
//           normalised to sum to one.
//
// Points sharing the same x value are merged by adding their weights.
//
// Throws SixteException if the sizes differ, the input is empty, an x value
// is non-finite, or the weights are rejected by validateAndNormalizeWeights().
void InverseTransformSampler::initializeFromDiscretePoints(
    const std::vector<double>& x_values, const std::vector<double>& weights) {
  if (x_values.size() != weights.size()) {
    throw SixteException("x_values and weights must have the same size");
  }

  if (x_values.empty()) {
    throw SixteException("x_values and weights cannot be empty");
  }

  // NaN has no place in an ordering: it would corrupt both the std::map used
  // for merging and the std::sort below, so reject non-finite points up front.
  for (size_t ii = 0; ii < x_values.size(); ++ii) {
    if (!std::isfinite(x_values[ii])) {
      throw SixteException("x_values contains non-finite value at index " +
                           std::to_string(ii));
    }
  }

  std::vector<double> x_copy = x_values;
  std::vector<double> weights_copy = weights;

  // Validate every individual weight BEFORE merging. Merging first would let
  // a negative weight hide behind a larger positive weight at the same x, and
  // the index reported on failure would no longer match the caller's input.
  validateAndNormalizeWeights(weights_copy);
  mergeDuplicatePoints(x_copy, weights_copy);

  std::vector<std::pair<double, double>> x_weight_pairs;
  x_weight_pairs.reserve(x_copy.size());
  for (size_t ii = 0; ii < x_copy.size(); ++ii) {
    x_weight_pairs.emplace_back(x_copy[ii], weights_copy[ii]);
  }

  // Sort pairs by x value (the merge step leaves the input order untouched
  // when there were no duplicates)
  std::sort(x_weight_pairs.begin(), x_weight_pairs.end());

  // Build vectors
  discrete_x_values_.clear();
  discrete_cumulative_probabilities_.clear();
  discrete_x_values_.reserve(x_weight_pairs.size());
  discrete_cumulative_probabilities_.reserve(x_weight_pairs.size());

  double cumulative_probability = 0.0;
  for (const auto& pair : x_weight_pairs) {
    cumulative_probability += pair.second;
    discrete_x_values_.push_back(pair.first);
    discrete_cumulative_probabilities_.push_back(cumulative_probability);
  }
}

// Stores a sorted copy of the observed samples as the empirical distribution.
//
// Throws SixteException if sample_data is empty or contains a non-finite
// value; the message carries the index of the first offending entry.
void InverseTransformSampler::initializeFromEmpiricalData(
    const std::vector<double>& sample_data) {
  if (sample_data.empty()) {
    throw SixteException("Sample data empty");
  }

  // Locate the first NaN/inf entry, if any
  const auto first_bad =
      std::find_if(sample_data.begin(), sample_data.end(),
                   [](double value) { return !std::isfinite(value); });
  if (first_bad != sample_data.end()) {
    const auto bad_index =
        static_cast<size_t>(std::distance(sample_data.begin(), first_bad));
    throw SixteException("Sample data contains non-finite value (" +
                         std::to_string(bad_index) + ")");
  }

  // Keep the samples in ascending order for the CDF lookups
  empirical_data_ = sample_data;
  std::sort(empirical_data_.begin(), empirical_data_.end());
}

// Builds the binned representation: one interval per bin, each carrying the
// cumulative probability up to and including that bin.
//
// Throws SixteException if the three vectors differ in size, are empty,
// contain a non-finite or inverted bin, are not sorted/non-overlapping, or if
// the weights are rejected by validateAndNormalizeWeights().
void InverseTransformSampler::initializeFromBins(
    const std::vector<double>& lower_bins,
    const std::vector<double>& upper_bins, const std::vector<double>& weights) {
  const size_t bin_count = lower_bins.size();

  const bool sizes_agree =
      bin_count == upper_bins.size() && bin_count == weights.size();
  if (!sizes_agree) {
    throw SixteException(
        "lower_bins, upper_bins, and weights must have the same size");
  }

  if (bin_count == 0) {
    throw SixteException("Bin vectors cannot be empty");
  }

  // Each bin needs finite edges with lower <= upper
  for (size_t bin = 0; bin < bin_count; ++bin) {
    const double lo = lower_bins[bin];
    const double hi = upper_bins[bin];

    const bool edges_finite = std::isfinite(lo) && std::isfinite(hi);
    if (!edges_finite) {
      throw SixteException("Bin bounds non-finite at index " +
                           std::to_string(bin));
    }
    if (lo > hi) {
      throw SixteException("Lower bin bound > upper bin bound at index " +
                           std::to_string(bin));
    }
  }

  validateBinsSorted(lower_bins, upper_bins);

  // Work on a normalised copy of the weights
  std::vector<double> probabilities(weights);
  validateAndNormalizeWeights(probabilities);

  // Accumulate the CDF bin by bin
  bin_intervals_.clear();
  bin_intervals_.reserve(bin_count);

  double running_cdf = 0.0;
  for (size_t bin = 0; bin < bin_count; ++bin) {
    running_cdf += probabilities[bin];
    bin_intervals_.push_back({lower_bins[bin], upper_bins[bin], running_cdf});
  }
}

// Loads (x, weight) pairs from the given file and sets up the discrete-point
// representation from them.
void InverseTransformSampler::initializeFromFile(const std::string& filename) {
  std::vector<double> points;
  std::vector<double> point_weights;

  parseDataFile(filename, points, point_weights);
  initializeFromDiscretePoints(points, point_weights);
}

// Sanity-checks the sampler state after initialization.
//
// Throws SixteException if the truncation bounds are inverted or if the
// container belonging to the active input type is empty.
void InverseTransformSampler::validateInput() const {
  const auto& limits = truncation_bounds_;
  if (limits.lower_bound > limits.upper_bound) {
    throw SixteException("Truncation lower bound must be <= upper bound");
  }

  // Only the storage of the active input type has to be populated
  switch (input_type_) {
    case InputType::discrete_points:
      if (discrete_x_values_.empty()) {
        throw SixteException("No valid discrete points after initialization");
      }
      break;
    case InputType::empirical_data:
      if (empirical_data_.empty()) {
        throw SixteException("No empirical data after initialization");
      }
      break;
    case InputType::bins:
      if (bin_intervals_.empty()) {
        throw SixteException("No bin intervals after initialization");
      }
      break;
  }
}

// Caches the CDF values at the truncation bounds.
//
// Without active truncation the full range [0, 1] is used. Otherwise the CDF
// of the active input type is evaluated at the lower and upper bound. Throws
// SixteException if the resulting CDF interval is empty, i.e. the lower CDF
// value is not below the upper one.
void InverseTransformSampler::computeTruncatedBounds() {
  const auto& limits = truncation_bounds_;

  if (!limits.isActive()) {
    truncated_cdf_lower_ = 0.0;
    truncated_cdf_upper_ = 1.0;
    return;
  }

  if (input_type_ == InputType::discrete_points) {
    truncated_cdf_lower_ = evaluateDiscretePointsCdf(limits.lower_bound);
    truncated_cdf_upper_ = evaluateDiscretePointsCdf(limits.upper_bound);
  } else if (input_type_ == InputType::empirical_data) {
    truncated_cdf_lower_ = evaluateEmpiricalCdf(limits.lower_bound);
    truncated_cdf_upper_ = evaluateEmpiricalCdf(limits.upper_bound);
  } else if (input_type_ == InputType::bins) {
    truncated_cdf_lower_ = evaluateBinsCdf(limits.lower_bound);
    truncated_cdf_upper_ = evaluateBinsCdf(limits.upper_bound);
  }

  // No probability mass between the bounds: nothing could ever be sampled
  if (truncated_cdf_lower_ >= truncated_cdf_upper_) {
    const std::string lower_text = std::to_string(limits.lower_bound);
    const std::string upper_text = std::to_string(limits.upper_bound);
    throw SixteException("Truncation bounds [" + lower_text + ", " +
                         upper_text +
                         "] are outside the distribution support");
  }
}

// Checks the weights and rescales them in place so that they sum to one.
//
// weights: modified in place; every entry must be finite and non-negative.
//
// Throws SixteException if an entry is non-finite or negative (the message
// names the offending index and value), if the sum is not positive, or if
// the sum overflows to infinity.
void InverseTransformSampler::validateAndNormalizeWeights(
    std::vector<double>& weights) {
  for (size_t ii = 0; ii < weights.size(); ++ii) {
    if (!std::isfinite(weights[ii]) || weights[ii] < 0.0) {
      throw SixteException("Weights must be finite and non-negative (weights[" +
                           std::to_string(ii) +
                           "] = " + std::to_string(weights[ii]) + ")");
    }
  }

  const double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
  if (sum <= 0.0) {
    throw SixteException("Sum of weights must be positive");
  }
  // Individually finite weights can still overflow when added up; dividing by
  // an infinite sum would silently turn every weight into zero.
  if (!std::isfinite(sum)) {
    throw SixteException("Sum of weights must be finite");
  }

  // Normalize
  for (double& weight : weights) {
    weight /= sum;
  }
}

void InverseTransformSampler::mergeDuplicatePoints(
    std::vector<double>& x_values, std::vector<double>& weights) {
  if (x_values.size() <= 1) return;

  std::map<double, double> merged_map;
  for (size_t ii = 0; ii < x_values.size(); ++ii) {
    merged_map[x_values[ii]] += weights[ii];
  }

  if (merged_map.size() < x_values.size()) {
    x_values.clear();
    weights.clear();
    x_values.reserve(merged_map.size());
    weights.reserve(merged_map.size());

    for (const auto& pair : merged_map) {
      x_values.push_back(pair.first);
      weights.push_back(pair.second);
    }
  }
}

// Verifies that the bins are given in ascending order without overlap, i.e.
// that every bin starts at or after the end of its predecessor.
//
// Throws SixteException naming the first bin that starts before the previous
// bin ends.
void InverseTransformSampler::validateBinsSorted(
    const std::vector<double>& lower_bins,
    const std::vector<double>& upper_bins) {
  for (size_t ii = 1; ii < lower_bins.size(); ++ii) {
    if (lower_bins[ii] < upper_bins[ii - 1]) {
      throw SixteException("Bins must be sorted and non-overlapping (bin " +
                           std::to_string(ii) +
                           " starts before previous bin ends)");
    }
  }
}

// Evaluates the CDF of the discrete-point distribution at x.
//
// Returns 0 when x lies below every support point (or no points exist). In
// interpolate mode the value is linearly interpolated between the support
// point at or below x and its right-hand neighbour, if there is one;
// otherwise the cumulative probability of the point at or below x is
// returned.
double InverseTransformSampler::evaluateDiscretePointsCdf(double x) const {
  const auto& xs = discrete_x_values_;
  const auto& cdf = discrete_cumulative_probabilities_;

  if (xs.empty()) return 0.0;

  // First support point strictly greater than x
  const auto above = std::upper_bound(xs.begin(), xs.end(), x);
  if (above == xs.begin()) {
    return 0.0;  // every support point lies above x
  }

  // Index of the last support point that is <= x
  const size_t left = std::distance(xs.begin(), above) - 1;
  const size_t right = left + 1;

  const bool interpolate_to_neighbour =
      sampling_mode_ == SamplingMode::interpolate && right < xs.size();
  if (!interpolate_to_neighbour) {
    return cdf[left];
  }

  return linearInterpolate(x, xs[left], cdf[left], xs[right], cdf[right]);
}

// Evaluates the CDF of the empirical distribution at x.
//
// Discrete mode: the fraction of samples that are <= x.
//
// Interpolate mode: the piecewise-linear CDF that is the exact inverse of the
// quantile function used by sampleFromEmpiricalData(), which maps u = k/N to
// the k-th sorted sample (0-based) and interpolates linearly in between.
// Hence the CDF at the k-th sample is k/N, not (k+1)/N. Using (k+1)/N here
// would shift every truncated draw by one sample and make the final clipping
// step pile the draws up on the truncation bound.
//
// Returns 0 below the smallest sample and 1 at or above the largest one.
double InverseTransformSampler::evaluateEmpiricalCdf(double x) const {
  if (empirical_data_.empty()) return 0.0;

  auto it = std::upper_bound(empirical_data_.begin(), empirical_data_.end(), x);
  // Number of samples <= x
  size_t count = std::distance(empirical_data_.begin(), it);
  const double total = static_cast<double>(empirical_data_.size());

  if (sampling_mode_ == SamplingMode::interpolate && count > 0 &&
      count < empirical_data_.size()) {
    // x lies in [data[count - 1], data[count]); the sampler assigns the
    // quantiles (count - 1)/N and count/N to these two samples.
    double x1 = empirical_data_[count - 1];
    double cdf1 = static_cast<double>(count - 1) / total;
    double x2 = empirical_data_[count];
    double cdf2 = static_cast<double>(count) / total;

    return linearInterpolate(x, x1, cdf1, x2, cdf2);
  }

  return static_cast<double>(count) / total;
}

double InverseTransformSampler::evaluateBinsCdf(double x) const {
  if (bin_intervals_.empty()) return 0.0;

  for (size_t ii = 0; ii < bin_intervals_.size(); ++ii) {
    if (x >= bin_intervals_[ii].lower_bound &&
        x <= bin_intervals_[ii].upper_bound) {
      double prev_cdf =
          (ii == 0) ? 0.0 : bin_intervals_[ii - 1].cumulative_probability;
      double curr_cdf = bin_intervals_[ii].cumulative_probability;
      double bin_width =
          bin_intervals_[ii].upper_bound - bin_intervals_[ii].lower_bound;

      if (bin_width == 0.0) {
        return curr_cdf;
      }

      // Linear interpolation within the bin
      double relative_pos = (x - bin_intervals_[ii].lower_bound) / bin_width;
      return prev_cdf + relative_pos * (curr_cdf - prev_cdf);
    } else if (x < bin_intervals_[ii].lower_bound) {
      // x is before this bin
      return (ii == 0) ? 0.0 : bin_intervals_[ii - 1].cumulative_probability;
    }
  }

  // x is after all bins
  return bin_intervals_.back().cumulative_probability;
}

// Draws one value from the discrete-point distribution.
//
// A uniform random number (mapped into the truncated CDF range when
// truncation is active) is located in the cumulative probabilities. In
// discrete mode the matching support point is returned; in interpolate mode
// the value is linearly interpolated between the neighbouring support points,
// with a backwards extrapolation below the first point. Candidates outside
// the range of support points or outside the active truncation bounds are
// rejected and redrawn, up to 500 attempts; after that a SixteException is
// thrown.
double InverseTransformSampler::sampleFromDiscretePoints() const {
  const auto& xs = discrete_x_values_;
  const auto& cdf = discrete_cumulative_probabilities_;

  // A single support point leaves nothing to draw
  if (xs.size() == 1) {
    return xs[0];
  }

  const int MAX_RETRIES = 500;

  int attempts_left = MAX_RETRIES;
  while (attempts_left-- > 0) {
    double u = getUniformRandomNumber();

    if (truncation_bounds_.isActive()) {
      const double cdf_span = truncated_cdf_upper_ - truncated_cdf_lower_;
      u = truncated_cdf_lower_ + u * cdf_span;
    }

    // First support point whose cumulative probability is >= u
    const auto hit = std::lower_bound(cdf.begin(), cdf.end(), u);
    const size_t pos = std::distance(cdf.begin(), hit);
    if (pos >= xs.size()) {
      return xs.back();
    }

    double candidate;

    if (sampling_mode_ == SamplingMode::discrete) {
      candidate = xs[pos];
    } else if (pos != 0) {
      // Interpolate between the two support points enclosing u
      candidate =
          linearInterpolate(u, cdf[pos - 1], xs[pos - 1], cdf[pos], xs[pos]);
    } else {
      // Below the first point: extend the line through the first two points
      // backwards
      const double first_x = xs[0];
      const double first_cdf = cdf[0];
      const double second_x = xs[1];
      const double second_cdf = cdf[1];

      if (std::abs(second_cdf - first_cdf) <
          std::numeric_limits<double>::epsilon()) {
        // Flat CDF between the first two points: no slope to extend
        candidate = first_x;
      } else {
        // x position at which the extended line reaches CDF = 0
        const double x_at_zero = first_x - first_cdf * (second_x - first_x) /
                                               (second_cdf - first_cdf);
        candidate = linearInterpolate(u, 0.0, x_at_zero, first_cdf, first_x);
      }
    }

    // Reject candidates that left the range of support points
    const bool in_support = candidate >= xs.front() && candidate <= xs.back();
    if (!in_support) {
      continue;
    }

    // Reject candidates outside the active truncation bounds
    if (truncation_bounds_.isActive()) {
      const bool in_bounds = candidate >= truncation_bounds_.lower_bound &&
                             candidate <= truncation_bounds_.upper_bound;
      if (!in_bounds) {
        continue;
      }
    }

    return candidate;
  }

  throw SixteException("Failed to sample value within bounds");
}

// Draws one value from the empirical distribution.
//
// A uniform random number (mapped into the truncated CDF range when
// truncation is active) is scaled to a fractional position in the sorted
// samples. Discrete mode returns the sample at the integer part of that
// position; interpolate mode blends it linearly with the next sample and
// clips the result to the active truncation bounds.
double InverseTransformSampler::sampleFromEmpiricalData() const {
  double u = getUniformRandomNumber();

  if (truncation_bounds_.isActive()) {
    const double cdf_span = truncated_cdf_upper_ - truncated_cdf_lower_;
    u = truncated_cdf_lower_ + u * cdf_span;
  }

  const size_t sample_count = empirical_data_.size();

  // Fractional position of u within the sorted samples
  const double position = u * static_cast<double>(sample_count);
  const auto slot = static_cast<size_t>(std::floor(position));

  // Positions past the end map to the largest sample
  if (slot >= sample_count) {
    return empirical_data_.back();
  }

  // No interpolation requested, or no right-hand neighbour available
  const bool has_next = slot + 1 < sample_count;
  if (sampling_mode_ == SamplingMode::discrete || !has_next) {
    return empirical_data_[slot];
  }

  // Blend the sample with its right-hand neighbour
  const double blend = position - static_cast<double>(slot);
  double value = empirical_data_[slot] +
                 blend * (empirical_data_[slot + 1] - empirical_data_[slot]);

  // Guard against numerical drift past the truncation bounds
  if (truncation_bounds_.isActive()) {
    value = std::max(truncation_bounds_.lower_bound,
                     std::min(truncation_bounds_.upper_bound, value));
  }

  return value;
}

// Draws one value from the binned distribution by inverting its
// piecewise-linear CDF.
//
// A single uniform random number u (mapped into the truncated CDF range when
// truncation is active) selects the bin AND the position inside it: the
// position is the fraction of the bin's probability that lies below u. This
// is the exact inverse of evaluateBinsCdf(), so truncation bounds that cut
// through a bin are honoured without distorting the distribution.
//
// Drawing the in-bin position from a second, independent random number and
// clipping afterwards would move every draw that falls outside the bounds
// onto the bounds themselves, producing artificial spikes there.
//
// Note: exactly one random number is consumed per call.
double InverseTransformSampler::sampleFromBins() const {
  double u = getUniformRandomNumber();

  if (truncation_bounds_.isActive()) {
    u = truncated_cdf_lower_ +
        u * (truncated_cdf_upper_ - truncated_cdf_lower_);
  }

  // First bin whose cumulative probability exceeds u. Using "exceeds" rather
  // than "reaches" guarantees the selected bin carries probability, so bins
  // with zero weight are never chosen.
  auto it = std::upper_bound(bin_intervals_.begin(), bin_intervals_.end(), u,
                             [](double value, const BinInterval& interval) {
                               return value < interval.cumulative_probability;
                             });

  if (it == bin_intervals_.end()) {
    // u is at or above the final cumulative probability (rounding in the
    // normalised weights): use the last bin that carries probability.
    --it;
    while (it != bin_intervals_.begin() &&
           it->cumulative_probability <= (it - 1)->cumulative_probability) {
      --it;
    }
  }

  const double prev_cdf = (it == bin_intervals_.begin())
                              ? 0.0
                              : (it - 1)->cumulative_probability;
  const double bin_probability = it->cumulative_probability - prev_cdf;

  // Fraction of the bin's probability below u, kept inside [0, 1]
  double fraction = 0.0;
  if (bin_probability > 0.0) {
    fraction = std::max(0.0, std::min(1.0, (u - prev_cdf) / bin_probability));
  }

  double result =
      it->lower_bound + fraction * (it->upper_bound - it->lower_bound);

  // Clip to bounds for numerical safety
  if (truncation_bounds_.isActive()) {
    result = std::max(truncation_bounds_.lower_bound,
                      std::min(truncation_bounds_.upper_bound, result));
  }

  return result;
}

// Linearly interpolates y at position x on the line through (x1, y1) and
// (x2, y2).
//
// Returns y1 when x1 and x2 coincide, since no slope is defined then.
//
// Only an exactly vanishing denominator is treated as degenerate. Comparing
// against std::numeric_limits<double>::epsilon() would be scale-dependent:
// epsilon is the spacing of doubles near 1.0, so any pair of abscissae closer
// than ~2.2e-16 in absolute terms (perfectly distinct values when the data
// live at a small scale) would never be interpolated.
double InverseTransformSampler::linearInterpolate(double x, double x1,
                                                  double y1, double x2,
                                                  double y2) {
  const double denominator = x2 - x1;
  if (denominator == 0.0) {
    return y1;  // Coincident points: fall back to y1
  }
  return y1 + (x - x1) * (y2 - y1) / denominator;
}

// Reads whitespace-separated (x, weight) pairs from a text file, one pair per
// line.
//
// filename: path of the file to read.
// x_values: output; cleared, then filled with the first column.
// weights:  output; cleared, then filled with the second column.
//
// Blank lines (including whitespace-only lines and the lone carriage return
// left by CRLF line endings) and lines whose first non-blank character is '#'
// are skipped. Anything following the second column on a line is ignored.
//
// Throws SixteException if the file cannot be opened, a data line does not
// start with two numbers, a value is non-finite, or no data line was found.
void InverseTransformSampler::parseDataFile(const std::string& filename,
                                            std::vector<double>& x_values,
                                            std::vector<double>& weights) {
  std::ifstream file(filename);
  if (!file.is_open()) {
    throw SixteException("Cannot open file: " + filename);
  }

  x_values.clear();
  weights.clear();

  std::string line;
  size_t line_number = 0;

  while (std::getline(file, line)) {
    ++line_number;

    // Skip blank lines and comments. Looking at the first non-blank character
    // (instead of line[0]) also covers indented comments and lines that only
    // hold whitespace or a trailing '\r'.
    const std::string::size_type first_char =
        line.find_first_not_of(" \t\r\n\v\f");
    if (first_char == std::string::npos || line[first_char] == '#') {
      continue;
    }

    std::istringstream iss(line);
    double x_val, weight;

    if (!(iss >> x_val >> weight)) {
      throw SixteException("Invalid data format in file '" + filename +
                           "' at line " + std::to_string(line_number) + ": " +
                           line);
    }

    // Validate values
    if (!std::isfinite(x_val) || !std::isfinite(weight)) {
      throw SixteException("Non-finite values in file '" + filename +
                           "' at line " + std::to_string(line_number));
    }

    x_values.push_back(x_val);
    weights.push_back(weight);
  }

  if (x_values.empty()) {
    throw SixteException("No valid data found in file: " + filename);
  }
}

}  // namespace sixte