/*
This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#pragma once

#include <cmath>
#include <limits>
#include <string>
#include <vector>

#include "SixteException.h"

namespace sixte {
// Strategy an InverseTransformSampler uses when drawing values; chosen per
// sampler through its constructors and stored for later calls to sample().
enum class SamplingMode {
  discrete = 0,     // NOTE(review): presumably draws tabulated values only
  interpolate = 1,  // NOTE(review): presumably interpolates between them
};

// Form of the data an InverseTransformSampler was built from; recorded by the
// sampler so later operations know which stored representation to use.
enum class InputType {
  discrete_points = 0,  // tabulated (x, weight) pairs
  empirical_data = 1,   // raw sample values, used through their empirical CDF
  bins = 2,             // (lo, hi, weight) triples describing uniform bins
};

// Lower and upper limits used to truncate a sampled distribution.
// A non-finite limit (infinity or NaN) is treated as "no limit on that side";
// a default-constructed object has both limits infinite, i.e. no truncation.
struct TruncationBounds {
  double lower_bound{-std::numeric_limits<double>::infinity()};
  double upper_bound{std::numeric_limits<double>::infinity()};

  TruncationBounds() = default;
  TruncationBounds(double lower, double upper)
      : lower_bound{lower}, upper_bound{upper} {}

  // Returns true when at least one of the two limits is a finite number.
  [[nodiscard]] bool isActive() const {
    const bool has_finite_lower = std::isfinite(lower_bound);
    const bool has_finite_upper = std::isfinite(upper_bound);
    return has_finite_lower || has_finite_upper;
  }
};

// Draws random values from a tabulated distribution by inverse transform
// sampling.
//
// The distribution can be supplied as
//   * (x, weight) pairs, either from a data file or from two vectors,
//   * raw empirical sample data, or
//   * (lower, upper, weight) bins.
// The SamplingMode chosen at construction controls how values are drawn, and
// an optional TruncationBounds limits the sampled range (no truncation by
// default).
//
// NOTE(review): validation happens in the implementation file; invalid input
// is presumably reported via SixteException — confirm there.
class InverseTransformSampler {
 public:
  // Builds the sampler from (x, weight) pairs read from `filename`.
  InverseTransformSampler(const std::string& filename, SamplingMode mode,
                          const TruncationBounds& bounds = TruncationBounds{});

  // Builds the sampler from (x, weight) pairs given as parallel vectors
  // (x_values[i] has weight weights[i]).
  InverseTransformSampler(const std::vector<double>& x_values,
                          const std::vector<double>& weights, SamplingMode mode,
                          const TruncationBounds& bounds = TruncationBounds{});

  // Builds the sampler from raw sample data, using its empirical CDF.
  InverseTransformSampler(const std::vector<double>& sample_data,
                          SamplingMode mode,
                          const TruncationBounds& bounds = TruncationBounds{});

  // Builds the sampler from bins given as parallel vectors: bin i spans
  // lower_bins[i]..upper_bins[i] and carries weight weights[i].
  InverseTransformSampler(const std::vector<double>& lower_bins,
                          const std::vector<double>& upper_bins,
                          const std::vector<double>& weights, SamplingMode mode,
                          const TruncationBounds& bounds = TruncationBounds{});

  // Draws a single random value from the configured distribution.
  [[nodiscard]] double sample() const;

 private:
  // One bin together with its cumulative probability.
  struct BinInterval {
    double lower_bound{};
    double upper_bound{};
    double cumulative_probability{};
  };

  // Data; only the containers matching input_type_ are expected to be filled.
  std::vector<double> discrete_x_values_;
  std::vector<double> discrete_cumulative_probabilities_;
  std::vector<BinInterval> bin_intervals_;
  std::vector<double> empirical_data_;

  // Configuration. Default member initializers keep every member in a defined
  // state even before a constructor body assigns it; constructors that set
  // these members override the defaults.
  SamplingMode sampling_mode_{SamplingMode::discrete};
  InputType input_type_{InputType::discrete_points};
  TruncationBounds truncation_bounds_;
  // CDF values associated with the truncation bounds. The defaults describe
  // the full [0, 1] CDF range; presumably replaced by computeTruncatedBounds().
  double truncated_cdf_lower_{0.0};
  double truncated_cdf_upper_{1.0};

  // Initialization: fill the data members for the corresponding input form.
  void initializeFromDiscretePoints(const std::vector<double>& x_values,
                                    const std::vector<double>& weights);
  void initializeFromEmpiricalData(const std::vector<double>& sample_data);
  void initializeFromBins(const std::vector<double>& lower_bins,
                          const std::vector<double>& upper_bins,
                          const std::vector<double>& weights);
  void initializeFromFile(const std::string& filename);

  // Utility
  // Checks the stored configuration for consistency.
  void validateInput() const;
  // Presumably evaluates the CDF at truncation_bounds_ and stores the results
  // in truncated_cdf_lower_ / truncated_cdf_upper_.
  void computeTruncatedBounds();
  // Validates `weights` and normalizes them in place.
  static void validateAndNormalizeWeights(std::vector<double>& weights);
  // Combines entries that share the same x value, modifying both vectors.
  static void mergeDuplicatePoints(std::vector<double>& x_values,
                                   std::vector<double>& weights);
  // Checks that the bins given by the two boundary vectors are sorted.
  static void validateBinsSorted(const std::vector<double>& lower_bins,
                                 const std::vector<double>& upper_bins);

  // CDF evaluation at `x`, one function per input form.
  [[nodiscard]] double evaluateDiscretePointsCdf(double x) const;
  [[nodiscard]] double evaluateEmpiricalCdf(double x) const;
  [[nodiscard]] double evaluateBinsCdf(double x) const;

  // Sampling: draw one value, one function per input form.
  [[nodiscard]] double sampleFromDiscretePoints() const;
  [[nodiscard]] double sampleFromEmpiricalData() const;
  [[nodiscard]] double sampleFromBins() const;

  // Interpolation: value at `x` on the line through (x1, y1) and (x2, y2).
  [[nodiscard]] static double linearInterpolate(double x, double x1, double y1,
                                                double x2, double y2);

  // File parsing: reads (x, weight) pairs from `filename` into the two
  // output vectors.
  static void parseDataFile(const std::string& filename,
                            std::vector<double>& x_values,
                            std::vector<double>& weights);
};

}  // namespace sixte