/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2023 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include "NewRMF.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iterator>

#include "SixteException.h"
#include "sixte_random.h"

namespace sixte {

/**
 * Builds a full RMF (response matrix plus channel boundaries) from a file.
 *
 * Besides loading the C struct through loadRMFWrapper(), this prepares the
 * scratch state used by sampleChannel():
 *  - a zeroed per-channel buffer for the cumulative response,
 *  - a list of touched channels (pre-reserved for 1000 entries), and
 *  - a copy of the upper energy-bin edges for binary search.
 *
 * @param rmf_filename Path of the RMF file; also stored as the source name.
 * @throws Whatever loadRMFWrapper() throws if the file cannot be loaded.
 */
NewRMF::NewRMF(const std::string& rmf_filename)
    : source_filename_(rmf_filename),
      rmf_(loadRMFWrapper(rmf_filename)) {
  has_matrix_ = true;

  const long n_bins = numberEnergyBins();
  energy_grid_high_.resize(n_bins);
  for (long bin = 0; bin < n_bins; ++bin) {
    energy_grid_high_[bin] = rmf_->HighEnergy[bin];
  }

  used_channels_.reserve(1000);
  cumulative_response_buffer_.resize(numberChannels());
}

/**
 * Creates an RMF object holding only a freshly initialized C struct from
 * getRMFWrapper(): no response matrix and no channel boundaries yet.
 * Callers fill it afterwards, e.g. loadEboundsOnly() calls sixteLoadEbounds().
 *
 * @throws Whatever getRMFWrapper() throws if the C struct cannot be created.
 */
NewRMF::NewRMF()
    : rmf_(getRMFWrapper()) {
  // No response matrix is ever loaded through this constructor. Set the flag
  // explicitly, as the other constructors do, so that sampleChannel() reliably
  // rejects this object. Otherwise it would depend on a default member
  // initializer that may not exist.
  has_matrix_ = false;
}

/**
 * Builds an EBOUNDS-only RMF (no response matrix) from an Ebounds object.
 *
 * The channel edges are written both to the channel arrays and to the
 * energy-bin arrays of the C struct, so NumberChannels and NumberEnergyBins
 * agree. All matrix pointers stay null and has_matrix_ is false, so
 * sampleChannel() throws for this object.
 *
 * @param eb Channel boundaries; edges are narrowed from double to float.
 * @throws SixteException if a non-empty boundary array cannot be allocated.
 */
NewRMF::NewRMF(const Ebounds& eb)
    : rmf_(getRMFWrapper()) {
  has_matrix_ = false;

  const size_t num_channels = eb.numberChannels();
  const long first_channel = eb.firstChannel();

  // Basic layout for an EBOUNDS-only RMF
  rmf_->NumberChannels = static_cast<long>(num_channels);
  rmf_->NumberEnergyBins = static_cast<long>(num_channels);  // keep axes consistent
  rmf_->FirstChannel = first_channel;
  rmf_->NumberTotalGroups = 0;
  rmf_->NumberTotalElements = 0;
  rmf_->isOrder = 0;

  // Matrix-related pointers stay null
  rmf_->NumberGroups = nullptr;
  rmf_->FirstGroup = nullptr;
  rmf_->FirstChannelGroup = nullptr;
  rmf_->NumberChannelGroups = nullptr;
  rmf_->FirstElement = nullptr;
  rmf_->OrderGroup = nullptr;
  rmf_->Matrix = nullptr;

  // The arrays are malloc'ed (not new[]) because they are handed to the C
  // struct. NOTE(review): presumably freeRMF(), reached through RMFDeleter
  // when rmf_ is destroyed, releases them with free(). Confirm this, since
  // it is also what cleans up the earlier arrays if a later allocation throws.
  auto alloc_array = [num_channels](float*& ptr) {
    ptr = static_cast<float*>(std::malloc(num_channels * sizeof(float)));
    // malloc(0) may legitimately return nullptr, so a null result is only a
    // failure when memory was actually requested.
    if (!ptr && num_channels > 0) {
      throw SixteException("NewRMF(Ebounds): allocation failed");
    }
  };

  alloc_array(rmf_->ChannelLowEnergy);
  alloc_array(rmf_->ChannelHighEnergy);
  alloc_array(rmf_->LowEnergy);
  alloc_array(rmf_->HighEnergy);

  // Channel edges double as energy-bin edges.
  for (size_t ii = 0; ii < num_channels; ++ii) {
    const double lo = eb.energyLo(ii);
    const double hi = eb.energyHi(ii);
    rmf_->ChannelLowEnergy[ii] = static_cast<float>(lo);
    rmf_->ChannelHighEnergy[ii] = static_cast<float>(hi);
    rmf_->LowEnergy[ii] = static_cast<float>(lo);
    rmf_->HighEnergy[ii] = static_cast<float>(hi);
  }

  rmf_->AreaScaling = 1.0f;
  rmf_->ResponseThreshold = 0.0f;

  // Blank all header/keyword strings.
  std::memset(rmf_->ChannelType, 0, sizeof(rmf_->ChannelType));
  std::memset(rmf_->RMFVersion, 0, sizeof(rmf_->RMFVersion));
  std::memset(rmf_->EBDVersion, 0, sizeof(rmf_->EBDVersion));
  std::memset(rmf_->Telescope, 0, sizeof(rmf_->Telescope));
  std::memset(rmf_->Instrument, 0, sizeof(rmf_->Instrument));
  std::memset(rmf_->Detector, 0, sizeof(rmf_->Detector));
  std::memset(rmf_->Filter, 0, sizeof(rmf_->Filter));
  std::memset(rmf_->RMFType, 0, sizeof(rmf_->RMFType));
  std::memset(rmf_->RMFExtensionName, 0, sizeof(rmf_->RMFExtensionName));
  std::memset(rmf_->EBDExtensionName, 0, sizeof(rmf_->EBDExtensionName));
}

/**
 * Loads channel boundaries (EBOUNDS) from `filename` into the wrapped C
 * struct via loadEbounds(), and records `filename` as the source file name.
 *
 * @param filename Path of the file to read.
 * @throws Whatever checkStatusThrow() raises if loadEbounds() reports an
 *         error ("Failed to load EBOUNDS").
 */
void NewRMF::sixteLoadEbounds(const std::string& filename) {
  int status = EXIT_SUCCESS;
  source_filename_ = filename;
  // loadEbounds() takes a non-const char*. Give it a private mutable copy
  // rather than const_cast-ing the caller's buffer, so any write by the C
  // routine cannot touch memory it does not own (as loadRMFWrapper does).
  std::string filename_buffer(filename);
  loadEbounds(rmf_.get(), &filename_buffer[0], &status);
  checkStatusThrow(status, "Failed to load EBOUNDS");
}

/**
 * Looks up the EBOUNDS channel for an energy using the C routine
 * getEBOUNDSChannel(). The energy is narrowed to float before the lookup.
 *
 * @param energy Energy to classify.
 * @return Whatever getEBOUNDSChannel() returns for that energy.
 */
long NewRMF::sixteGetEBOUNDSChannel(double energy) const {
  const float energy_as_float = static_cast<float>(energy);
  const struct RMF* const raw_rmf = rawPtr();
  return getEBOUNDSChannel(energy_as_float, raw_rmf);
}

/**
 * Returns the lower and upper EBOUNDS energy edges of a channel, as reported
 * by the C routine getEBOUNDSEnergyLoHi().
 *
 * @param channel Channel number passed unchanged to getEBOUNDSEnergyLoHi().
 * @return {low edge, high edge}, widened from float to double.
 * @throws Whatever checkStatusThrow() raises if the lookup reports an error
 *         ("Failed to get channel energies").
 */
std::pair<double, double> NewRMF::sixteGetEBOUNDSEnergyLoHi(long channel) const {
  float edge_low = 0.0f;
  float edge_high = 0.0f;
  int status = EXIT_SUCCESS;

  getEBOUNDSEnergyLoHi(channel, rawPtr(), &edge_low, &edge_high, &status);
  checkStatusThrow(status, "Failed to get channel energies");

  return {edge_low, edge_high};
}

/**
 * Draws a random energy within the EBOUNDS edges of `channel`.
 *
 * One number u from getUniformRandomNumber() blends the two edges as
 * u * low + (1 - u) * high.
 *
 * NOTE(review): if u is drawn from [0, 1), the result lies in (low, high].
 * The upper edge itself is reachable (u == 0) and may belong to the next
 * channel under a half-open [low, high) convention. Confirm against
 * getEBOUNDSChannel() before relying on round-tripping.
 *
 * @param channel Channel whose edges bound the sampled energy.
 * @return The sampled energy.
 * @throws Whatever sixteGetEBOUNDSEnergyLoHi() throws.
 */
double NewRMF::sixteGetEBOUNDSEnergy(long channel) const {
  const std::pair<double, double> edges = sixteGetEBOUNDSEnergyLoHi(channel);
  const auto u = getUniformRandomNumber();
  return u * edges.first + (1.0 - u) * edges.second;
}

long NewRMF::sampleChannel(double energy) const {
  if (!has_matrix_) {
    throw SixteException("NewRMF::sampleChannel called on EBOUNDS-only RMF");
  }

  // Check if energy is outside response range
  if (energy < rmf_->LowEnergy[0] || energy > rmf_->HighEnergy[numberEnergyBins() - 1]) {
    return -1;
  }

  // Find energy bin using binary search
  auto it = std::lower_bound(energy_grid_high_.begin(), energy_grid_high_.end(), energy);
  if (it == energy_grid_high_.end()) {
    throw SixteException("Binary search failed to find energy bin");
  }
  const long energy_bin = std::distance(energy_grid_high_.begin(), it);

  // Clear only previously used channels in cumulative_response_buffer_
  for (long channel : used_channels_) {
    cumulative_response_buffer_[channel] = 0.0;
  }
  used_channels_.clear();

  // Build cumulative response for this specific energy bin
  const size_t number_channels = numberChannels();
  const long first_channel = firstChannel();
  for (long ii = 0; ii < rmf_->NumberGroups[energy_bin]; ++ii) {
    const long igrp = ii + rmf_->FirstGroup[energy_bin];

    for (long jj = 0; jj < rmf_->NumberChannelGroups[igrp]; ++jj) {
      const long ichan = jj + rmf_->FirstChannelGroup[igrp] - first_channel;
      if (ichan >= 0 && ichan < static_cast<long>(number_channels)) {
        cumulative_response_buffer_[ichan] = rmf_->Matrix[jj + rmf_->FirstElement[igrp]];
        used_channels_.push_back(ichan);
      }
    }
  }

  if (used_channels_.empty()) {
    return -1;  // no channels contribute in this energy bin
  }

  if (!std::is_sorted(used_channels_.begin(), used_channels_.end())) {
    std::sort(used_channels_.begin(), used_channels_.end());
  }

  // Sanity check (overlapping channel groups within a row)
  auto dup = std::adjacent_find(used_channels_.begin(), used_channels_.end());
  if (dup != used_channels_.end()) {
    throw SixteException("RMF row has overlapping channel groups");
  }
  
  double cumulative_sum = 0.0;
  for (long channel : used_channels_) {
    cumulative_sum += cumulative_response_buffer_[channel];
    cumulative_response_buffer_[channel] = cumulative_sum;
  }

  if (cumulative_sum <= 0.0) {
    return -1;
  }

  double random_number = getUniformRandomNumber();

  // Check if random number exceeds total response
  if (random_number > cumulative_sum) {
    return -1;
  }

  // Find channel using binary search on used channels with cumulative response
  auto channel_it = std::lower_bound(used_channels_.begin(), used_channels_.end(),
                                     random_number,
                                     [this](long channel, double value) {
                                       return cumulative_response_buffer_[channel] < value;
                                     });
  // This should never happen given the prior check, but still guard against it
  if (channel_it == used_channels_.end()) {
    throw SixteException("Binary search failed to find channel");
  }
  const long channel_index = *channel_it;

  // Correct for the first channel number in use
  return channel_index + first_channel;
}

/// @return The FirstChannel number stored in the wrapped C struct.
long NewRMF::firstChannel() const {
  const long first = rmf_->FirstChannel;
  return first;
}

/// @return NumberChannels of the wrapped C struct, converted to size_t.
size_t NewRMF::numberChannels() const {
  return static_cast<size_t>(rmf_->NumberChannels);
}

/// @return NumberEnergyBins of the wrapped C struct.
long NewRMF::numberEnergyBins() const {
  const long bin_count = rmf_->NumberEnergyBins;
  return bin_count;
}

/**
 * Copies the channel boundaries of the wrapped C struct into an Ebounds
 * value, together with the first channel number and the source file name.
 * The float edges are widened to double.
 *
 * @return A new Ebounds holding numberChannels() low/high edge pairs.
 */
Ebounds NewRMF::exportEbounds() const {
  const size_t channel_count = numberChannels();
  const float* const low_edges = rmf_->ChannelLowEnergy;
  const float* const high_edges = rmf_->ChannelHighEnergy;

  std::vector<double> lo(low_edges, low_edges + channel_count);
  std::vector<double> hi(high_edges, high_edges + channel_count);

  return {std::move(lo), std::move(hi), firstChannel(), source_filename_};
}

std::shared_ptr<Ebounds> loadEboundsOnly(const std::string& abs_path) {
  NewRMF rmf;
  rmf.sixteLoadEbounds(abs_path);
  return std::make_shared<Ebounds>(rmf.exportEbounds());
}

/// @return Non-owning read-only pointer to the wrapped C struct, for passing
///         to the C API. Valid only while this NewRMF is alive.
const struct RMF* NewRMF::rawPtr() const {
  const struct RMF* const observed = rmf_.get();
  return observed;
}

// Releases a C `struct RMF` by forwarding it to the C library's freeRMF().
// Presumably used as the deleter of NewRMF::rmf_ (TODO confirm in header).
void RMFDeleter::operator()(struct RMF* rmf) {
  freeRMF(rmf);
}

/**
 * Loads a full RMF from `rmf_filename` with the C routine loadRMF().
 *
 * NOTE(review): if loadRMF() returns a partially built struct together with
 * an error status, checkStatusThrow() throws before anyone takes ownership,
 * and that struct leaks. Confirm loadRMF()'s failure contract.
 *
 * @param rmf_filename Path of the RMF file.
 * @return The struct returned by loadRMF(); callers in this file hand it to
 *         rmf_ as its owner.
 * @throws Whatever checkStatusThrow() raises on failure
 *         ("Failed to load RMF from <file>").
 */
struct RMF* loadRMFWrapper(const std::string& rmf_filename) {
  // loadRMF() is handed a mutable char buffer, so pass a private copy of the
  // name rather than the caller's string.
  std::string writable_name(rmf_filename);
  int load_status = EXIT_SUCCESS;

  struct RMF* const loaded_rmf = loadRMF(writable_name.data(), &load_status);
  checkStatusThrow(load_status, "Failed to load RMF from " + rmf_filename);

  return loaded_rmf;
}

/**
 * Creates a new C `struct RMF` with the C routine getRMF().
 *
 * @return The struct returned by getRMF(); callers in this file hand it to
 *         rmf_ as its owner.
 * @throws Whatever checkStatusThrow() raises if getRMF() reports an error
 *         ("Failed to initialize RMF").
 */
struct RMF* getRMFWrapper() {
  int init_status = EXIT_SUCCESS;
  struct RMF* const fresh_rmf = getRMF(&init_status);
  checkStatusThrow(init_status, "Failed to initialize RMF");
  return fresh_rmf;
}

}