/*
   This file is part of the RELXILL model code.

   RELXILL is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   RELXILL is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.
   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.

    Copyright 2022 Thomas Dauser, Remeis Observatory & ECAP
*/
#include "Relbase.h"
#include "Xillspec.h"
#include "IonGradient.h"


extern "C" {
#include "writeOutfiles.h"
}

#include <utility>

// new CACHE routines

/** cache of computed relbase profiles (relbase_profile() prepends new nodes);
 *  released by free_cache() */
cnode *cache_relbase = nullptr;


/** index of the bin of the convolution energy grid that contains 1.0 (i.e., 1keV, despite the
 *  "1eV" in the name); updated by fftw_conv_spectrum() whenever it no longer matches the grid.
 *  A value of 0 always triggers a new search. */
int save_1eV_pos = 0;


/** global FFT convolution cache, created on first use by init_global_specCache() */
specCache *global_spec_cache = nullptr;


/** @brief allocate a new FFT convolution cache with room for n_cache zones
 *  @details
 *   - all spectral buffers have N_ENER_CONV bins and are value-initialized (zero), such that a
 *     convolution re-using a not yet computed transform (re_rel or re_xill == 0 on the first call)
 *     or reading beyond the n/2+1 coefficients written by a real-to-complex FFT operates on
 *     well-defined data instead of uninitialized memory
 *   - everything is allocated with new / new[] and has to be released by free_specCache()
 *  @param n_cache number of zones to reserve
 *  @param status  unused (new[] throws std::bad_alloc instead of returning nullptr)
 *  @return the new cache
 */
static specCache *new_specCache(int n_cache, int *status) {
  (void) status;

  auto *spec = new specCache;

  spec->n_cache = n_cache;
  spec->nzones = 0;
  spec->n_ener = N_ENER_CONV;

  spec->conversion_factor_energyflux = nullptr;

  spec->fft_xill = new double **[n_cache];
  spec->fft_rel = new double **[n_cache];

  spec->fftw_xill = new fftw_complex *[n_cache];
  spec->fftw_rel = new fftw_complex *[n_cache];

  spec->fftw_backwards_input = new fftw_complex[spec->n_ener]();
  spec->fftw_output = new double[spec->n_ener]();

  // the backward transform always has the fixed length n_ener, see fftw_conv_spectrum
  spec->plan_c2r = fftw_plan_dft_c2r_1d(spec->n_ener, spec->fftw_backwards_input, spec->fftw_output, FFTW_ESTIMATE);

  spec->xill_spec = new xillSpec *[n_cache];

  // number of real-valued work buffers per zone (in this file only index 0 is used)
  const int m = 2;
  for (int ii = 0; ii < n_cache; ii++) {
    spec->fft_xill[ii] = new double *[m];
    spec->fft_rel[ii] = new double *[m];

    spec->fftw_xill[ii] = new fftw_complex[spec->n_ener]();
    spec->fftw_rel[ii] = new fftw_complex[spec->n_ener]();

    for (int jj = 0; jj < m; jj++) {
      spec->fft_xill[ii][jj] = new double[spec->n_ener]();
      spec->fft_rel[ii][jj] = new double[spec->n_ener]();
    }
    spec->xill_spec[ii] = nullptr;
  }
  spec->out_spec = nullptr;

  return spec;
}

/** create the cache *spec with n_zones zones, unless it already exists (then nothing is done) */
static void init_specCache(specCache **spec, const int n_zones, int *status) {
  if (*spec != nullptr) {
    return;
  }
  *spec = new_specCache(n_zones, status);
}


/** @brief return the global FFT convolution cache, creating it with N_ZONES_MAX zones on first use */
specCache *init_global_specCache(int *status) {
  specCache **cache_slot = &global_spec_cache;
  init_specCache(cache_slot, N_ZONES_MAX, status);
  CHECK_RELXILL_ERROR("failed initializing Relconv Spec Cache", status);
  return *cache_slot;
}

/** @brief conversion factor from a bin-integrated spectrum to a flux density times energy
 *  @details factor[ii] = 0.5*(ener[ii]+ener[ii+1]) / (ener[ii+1]-ener[ii]), i.e., bin center over bin width
 *  @param ener   bin edges, length n_ener+1 (strictly increasing, otherwise a bin width is zero)
 *  @param n_ener number of bins
 *  @param status unused: plain new[] throws std::bad_alloc on failure and never returns nullptr,
 *                so there is no error condition to report here
 *  @return newly allocated array of length n_ener (release with delete[])
 */
static double *calculate_energyflux_conversion(const double *ener, int n_ener, int *status) {
  (void) status;

  auto *factor = new double[n_ener];
  for (int ii = 0; ii < n_ener; ii++) {
    const double bin_width = ener[ii + 1] - ener[ii];
    factor[ii] = 0.5 * (ener[ii] + ener[ii + 1]) / bin_width;
  }

  return factor;
}








/** @brief FFTW VERSION: convolve the (bin-integrated) spectra f1 and f2 (which need to have a certain binning)
 *  @details fout: gives the output
 *  f1 input (reflection) specrum
 *  f2 filter
 *  ener has length n+1 and is the energy array
 *  requirements: needs "specCache" to be set up
 * **/
void fftw_conv_spectrum(double *ener, const double *fxill, const double *frel, double *fout, int n,
                       int re_rel, int re_xill, int izone, specCache *cache, int *status) {

  CHECK_STATUS_VOID(*status);

  // needs spec cache to be set up
  assert(cache != nullptr);

  if (cache->conversion_factor_energyflux == nullptr){
    cache->conversion_factor_energyflux = calculate_energyflux_conversion(ener, n, status);
  }


  /* need to find out where the 1keV for the filter is, which defines if energies are blue or redshifted*/
  if (save_1eV_pos == 0 ||
      (!((ener[save_1eV_pos] <= 1.0) &&
          (ener[save_1eV_pos + 1] > 1.0)))) {
    save_1eV_pos = binary_search(ener, n + 1, 1.0);
  }

  int ii;
  int irot;

  /**********************************************************************/
  /** cache either the m_cache_relat. or the xillver part, as only one of the
   * two changes most of the time (reduce time by 1/3 for convolution) **/
  /**********************************************************************/

  /** #1: for the xillver part **/
  if (re_xill != 0) {
    for (ii = 0; ii < n; ii++) {
      cache->fft_xill[izone][0][ii] = fxill[ii] * cache->conversion_factor_energyflux[ii] ;
    }

    fftw_plan plan_xill = fftw_plan_dft_r2c_1d(n, cache->fft_xill[izone][0], cache->fftw_xill[izone],FFTW_ESTIMATE);
    fftw_execute(plan_xill);
    fftw_destroy_plan(plan_xill);
  }

  /** #2: for the m_cache_relat. part **/
  if (re_rel != 0) {
    for (ii = 0; ii < n; ii++) {
      irot = (ii - save_1eV_pos + n) % n;
      cache->fft_rel[izone][0][irot] = frel[ii] * cache->conversion_factor_energyflux[ii];
    }

    fftw_plan plan_rel = fftw_plan_dft_r2c_1d(n, cache->fft_rel[izone][0], cache->fftw_rel[izone],FFTW_ESTIMATE);
    fftw_execute(plan_rel);
    fftw_destroy_plan(plan_rel);
  }

  // complex multiplication
  for (ii = 0; ii < n; ii++) {
    cache->fftw_backwards_input[ii][0] =
        cache->fftw_xill[izone][ii][0] * cache->fftw_rel[izone][ii][0] -
        cache->fftw_xill[izone][ii][1] * cache->fftw_rel[izone][ii][1];

    cache->fftw_backwards_input[ii][1] =
        cache->fftw_xill[izone][ii][0] * cache->fftw_rel[izone][ii][1] +
            cache->fftw_xill[izone][ii][1] * cache->fftw_rel[izone][ii][0];

  }

  fftw_execute(cache->plan_c2r);

  for (ii = 0; ii < n; ii++) {
    fout[ii] = cache->fftw_output[ii] /  cache->conversion_factor_energyflux[ii]; 
  }

}



/**
 * @Function: calcFFTNormFactor
 * @Synopsis: calculate the normalization of the FFT, which is defined to keep the normalization of the
 *           input spectrum and the relativistic smearing
 * The sums are only taken over bins with ener[jj] >= EMIN_XILLVER and ener[jj+1] < EMAX_XILLVER, to
 * avoid problems at the border of the FFT convolution.
 * @return (sum frel) * (sum fxill) / (sum fout) over these bins; 0.0 if the convolved flux sums to
 *         zero there (e.g., for a vanishing input spectrum), as 0/0 or x/0 would turn the complete
 *         normalized output into NaN/inf
 */
double calcFFTNormFactor(const double *ener, const double *fxill, const double *frel, const double *fout, int n) {

  double sum_relline = 0.0;
  double sum_xillver = 0.0;
  double sum_conv = 0.0;
  for (int jj = 0; jj < n; jj++) {
    if (ener[jj] >= EMIN_XILLVER && ener[jj + 1] < EMAX_XILLVER) {
      sum_xillver += fxill[jj];
      sum_relline += frel[jj];
      sum_conv += fout[jj];
    }
  }

  if (sum_conv == 0.0) {
    return 0.0;
  }

  return sum_relline * sum_xillver / sum_conv;
}

/** @brief rescale the (unnormalized) FFT convolution result fout in place by calcFFTNormFactor */
void normalizeFFTOutput(const double *ener, const double *fxill, const double *frel, double *fout, int n) {
  const double scale = calcFFTNormFactor(ener, fxill, frel, fout, n);

  for (int k = n; k-- > 0;) {
    fout[k] *= scale;
  }
}
void convolveSpectrumFFTNormalized(double *ener, const double *fxill, const double *frel, double *fout, int n,
                                   int re_rel, int re_xill, int izone, specCache *spec_cache_ptr, int *status) {

  fftw_conv_spectrum(ener, fxill, frel, fout, n, re_rel, re_xill, izone, spec_cache_ptr, status);

  normalizeFFTOutput(ener, fxill, frel, fout, n);

}


/** @brief set all bins of spec lying completely outside [emin, emax] to zero
 *  @details bin ii is outside if its upper edge ener[ii+1] < emin or its lower edge ener[ii] > emax;
 *   in a debug run, a warning stating the range actually applied (emin, emax) is printed once
 *  @param ener   bin edges, length n_ener+1
 *  @param spec   spectrum, length n_ener (modified in place)
 */
void set_flux_outside_defined_range_to_zero(const double* ener, double* spec, int n_ener, double emin, double emax){
  int warned = 0;
  for (int ii = 0; ii < n_ener; ii++) {
    if (ener[ii + 1] < emin || ener[ii] > emax) {
      if (is_debug_run() && warned == 0) {
        // report the limits that are actually applied, not a hard-coded range
        printf(" *** warning: relconv applied outside the allowed energy range %.2f-%.0f\n",
               emin, emax);
        printf("     -> values outside are set to zero\n\n");
        warned = 1;
      }
      spec[ii] = 0.0;
    }
  }
}


/**
 * @brief basic relconv function: convolve any input spectrum with the relbase kernel
 * @description
 *   it is only defined in the energy range of 0.01-1000 keV (see RELCONV_EMIN, RELCONV_EMAX variables)
 *   and zero outside this range;
 *   the input spectrum is rebinned onto the fixed relxill convolution energy grid, convolved (FFT)
 *   with the single-zone relbase profile, and rebinned back, overwriting spec_inp
 * @param double[n_ener_inp+1] ener_inp
 * @param double[n_ener_inp] spec_inp (input and output)
 *  **/
void relconv_kernel(double *ener_inp, double *spec_inp, int n_ener_inp, relParam *rel_param, int *status) {

  // get the (fixed!) energy grid for a RELLINE for a convolution
  // -> as we do a simple FFT, we need the number of bins to be 2^N
  EnerGrid *ener_grid = get_relxill_conv_energy_grid();

  relline_spec_multizone *rel_profile = relbase(ener_grid->ener, ener_grid->nbins, rel_param, status);
  // relbase returns nullptr if the system parameters could not be calculated
  CHECK_STATUS_VOID(*status);

  // simple convolution only makes sense for 1 zone !
  assert(rel_profile->n_zones == 1);

  // std::vector releases the work buffers on every return path
  std::vector<double> rebin_flux(ener_grid->nbins);
  _rebin_spectrum(ener_grid->ener, rebin_flux.data(), ener_grid->nbins, ener_inp, spec_inp, n_ener_inp);

  specCache *spec_cache = init_global_specCache(status);
  CHECK_STATUS_VOID(*status);

  std::vector<double> conv_out(ener_grid->nbins);
  convolveSpectrumFFTNormalized(ener_grid->ener, rebin_flux.data(), rel_profile->flux[0], conv_out.data(),
                                ener_grid->nbins, 1, 1, 0, spec_cache, status);
  CHECK_STATUS_VOID(*status);

  // rebin to the output grid
  _rebin_spectrum(ener_inp, spec_inp, n_ener_inp, ener_grid->ener, conv_out.data(), ener_grid->nbins);

  set_flux_outside_defined_range_to_zero(ener_inp, spec_inp, n_ener_inp, RELCONV_EMIN, RELCONV_EMAX);

}




/** @brief add the primary spectrum to the reflection spectrum flu (in place)
 *  @details
 *   - for xillver models, or if the emissivity is not due to a primary source, flu is only scaled by
 *     |refl_frac|
 *   - otherwise the primary spectrum is scaled by f_inf_rest/0.5 * g_inf^gam (times the squared Doppler
 *     factor if beta > 1e-4), and flu by |refl_frac| relative to the predicted reflection fraction;
 *     if interpret_reflfrac_as_boost is set, xill_input_param->refl_frac is multiplied IN PLACE by the
 *     predicted reflection fraction
 *   - the primary spectrum is only added if refl_frac >= 0
 *  @param ener      energy grid (presumably length n_ener+1 — TODO confirm)
 *  @param n_ener    number of bins of flu
 *  @param flu       reflection spectrum (input), reflection + primary spectrum (output)
 *  @param rel_param relativistic parameters (not dereferenced in this function for xillver models)
 *  @param sys_par   system parameters providing the photon fate fractions (relativistic case)
 */
void add_primary_component(double *ener, int n_ener, double *flu, relParam *rel_param, xillParam *xill_input_param,
                           RelSysPar *sys_par, int *status) {

  xillTableParam *xill_table_param = get_xilltab_param(xill_input_param, status);
  double *pl_flux = calc_normalized_primary_spectrum(ener, n_ener, rel_param, xill_table_param, status);
  free(xill_table_param);
  if (*status != EXIT_SUCCESS) {
    delete[] pl_flux;  // do not leak the primary spectrum if it was allocated before the error occurred
    return;
  }

  // For the non-relativistic model and if not a geometric corona, we simply multiply by the reflection fraction
  if (is_xill_model(xill_input_param->model_type) || !is_any_primary_source(rel_param->emis_type)) {
    for (int ii = 0; ii < n_ener; ii++) {
      flu[ii] *= fabs(xill_input_param->refl_frac);
    }

  } else { // we are in the LP geometry

    assert(rel_param != nullptr);

    lpReflFrac *struct_refl_frac = sys_par->emis->photon_fate_fractions;

    if (xill_input_param->interpret_reflfrac_as_boost) {
      // if set, it is given as boost, wrt predicted refl_frac
      xill_input_param->refl_frac *= struct_refl_frac->refl_frac;
    }

    double g_inf = energy_shift_source_obs(rel_param);
    double prim_fac = struct_refl_frac->f_inf_rest / 0.5 * pow(g_inf, xill_input_param->gam);

    if (rel_param->beta
        > 1e-4) { // flux boost of primary radiation taking into account here (therefore we need f_inf_rest above)
      prim_fac *= pow(doppler_factor_source_obs(rel_param), 2);
    }

    // if the user sets the refl_frac parameter manually, we need to calculate the ratio
    // to end up with the correct normalization
    double norm_fac_refl = (fabs(xill_input_param->refl_frac)) / struct_refl_frac->refl_frac;

    for (int ii = 0; ii < n_ener; ii++) {
      pl_flux[ii] *= prim_fac;
      flu[ii] *= norm_fac_refl;
    }

    /** 5 ** if desired, we output the reflection fraction and strength (as defined in Dauser+2016) **/
    if (shouldAuxInfoGetPrinted()) {
     // print_reflection_strength(ener, n_ener, flu, rel_param, xill_input_param, pl_flux, struct_refl_frac);
    }

  }

  // Finally, add the power law component if refl_frac >= 0
  if (xill_input_param->refl_frac >= 0) {
    for (int ii = 0; ii < n_ener; ii++) {
      flu[ii] += pl_flux[ii];
    }
  }

  delete[] pl_flux;

}


/**
 * @brief Check if any xillver parameters have changed between two parameter sets
 * @param cpar Current parameters
 * @param par New parameters
 * @return 1 if any parameter changed, 0 if all are the same
 */
int did_xill_param_change(const xillParam *cpar, const xillParam *par) {
  std::vector<std::pair<double, double>> param_comparison = {
    {par->afe, cpar->afe},
    {par->dens, cpar->dens},
    {par->ect, cpar->ect},
    {par->gam, cpar->gam},
    {par->lxi, cpar->lxi},
    {par->kTbb, cpar->kTbb},
    {par->frac_pl_bb, cpar->frac_pl_bb},
    {par->radzone, cpar->radzone},
    {par->z, cpar->z},
    {par->iongrad_index, cpar->iongrad_index},
    {par->distance, cpar->distance},
    {par->mass_msolar, cpar->mass_msolar},
    {par->luminosity_primary_source, cpar->luminosity_primary_source},
    {static_cast<double>(par->prim_type), static_cast<double>(cpar->prim_type)},
    {static_cast<double>(par->model_type), static_cast<double>(cpar->model_type)}
  };

  for (const auto& param : param_comparison)
  {
    if (are_values_different(param.first, param.second))
    {
      return 1;
    }
  }

  return 0;
}


/** @brief check if values, which need a re-computation of the xillver spectrum, have changed
 *  @param rel_param, xill_param        new parameters
 *  @param ca_rel_param, ca_xill_param  cached parameters (if either is nullptr, a re-computation is requested)
 *  @return 1 if xillver needs to be re-computed, 0 otherwise
 */
int redo_xillver_calc(const relParam *rel_param, const xillParam *xill_param,
                      const relParam *ca_rel_param, const xillParam *ca_xill_param) {

  const bool nothing_cached = (ca_rel_param == nullptr) || (ca_xill_param == nullptr);
  if (nothing_cached) {
    return 1;
  }

  if (did_xill_param_change(ca_xill_param, xill_param)) {
    return 1;
  }

  // xillver needs to be re-computed as well, as Ecut changes for the following parameters
  const bool geometry_changed =
      are_values_different(rel_param->a, ca_rel_param->a)
          || are_values_different(rel_param->height, ca_rel_param->height)
          || are_values_different(rel_param->r_src, ca_rel_param->r_src)
          || are_values_different(rel_param->theta, ca_rel_param->theta)
          || are_values_different(rel_param->beta, ca_rel_param->beta);
  if (geometry_changed) {
    return 1;
  }

  // special case for the alpha model: the reflection fraction determines the incident flux and
  // therefore also the ionization
  if (is_alpha_model(xill_param->model_type)
      && are_values_different(xill_param->refl_frac, ca_xill_param->refl_frac)) {
    return 1;
  }

  return 0;
}

/** @return 1 if any relativistic parameter differs from the cached ones (relbase has to be re-computed), else 0 */
int redo_relbase_calc(const relParam *rel_param, const relParam *ca_rel_param) {
  const bool changed = did_rel_param_change(ca_rel_param, rel_param);
  return changed ? 1 : 0;
}

/** print the main relativistic parameters of pa to stdout (e.g., when the relline calculation failed) */
void write_output_rel_param(relParam *pa) {
  const relParam &par = *pa;

  printf(" - a = %e \n", par.a);
  printf(" - height = %e\n", par.height);
  printf(" - Rin = %e\n", par.rin);
  printf(" - Rout = %e\n", par.rout);
  printf(" - incl = %e\n", par.incl);
  printf(" - beta = %e\n", par.beta);
  printf(" - gamma = %e\n", par.gamma);
}

/** @brief relbase function calculating the basic relativistic line shape for a given parameter setup
 *  @details
 *    - assuming a 1keV line, by a grid given in keV!
 *    - it is cached: for parameters found in cache_relbase the stored spectrum is returned, otherwise
 *      the profile is calculated for nzones zones and prepended to the cache
 *    - the returned spectrum is stored in the cache (do not free it)
 *    - sets param->num_zones = nzones when a new profile is calculated
 *    - throws std::exception if the calculation of the relline profile fails
 * input: ener(n_ener), param
 * input: RelSysPar
 * input: radialZones (nzones+1 zone boundaries)
 * optional input: xillver grid
 * output: photar(n_ener)  [photons/bin]
**/
relline_spec_multizone* relbase_profile(double *ener, int n_ener, relParam *param,
                                       RelSysPar *sysPar,
                                       xillTable *xill_tab,
                                       const VecD& radialZones,  // TODO: make this a vector as well
                                       int nzones,
                                       int *status) {

  assert(radialZones.size() == static_cast<std::size_t>(nzones) + 1);

  inpar* inp = get_inputvals_struct(ener, n_ener, param, status);
  cache_info *ca_info = cli_check_cache(cache_relbase, inp, check_cache_relpar, status);
  relline_spec_multizone *spec = nullptr;

  // set a pointer to the spectrum
  if (is_relbase_cached(ca_info) == 0) {

    // init the spectra where we store the flux
    param->num_zones = nzones;
    init_relline_spec_multizone(&spec, param, xill_tab, radialZones, &ener, n_ener, status);

    calc_relline_profile(spec, sysPar, status); // returned units are 'photons/bin'

    if (*status != EXIT_SUCCESS) {
      printf(" *** error: calculation of relline profile failed \n");
      write_output_rel_param(param);
      // release everything allocated above, as the exception leaves this function
      free_rel_spec(spec);
      free(inp);
      free(ca_info);
      throw std::exception();
    }

    // normalize it and calculate the angular distribution (if necessary)
    renorm_relline_profile(spec, param, status);

    // last step: store parameters and cached relline_spec_multizone (this prepends a new node to the cache)
    add_relspec_to_cache(&cache_relbase, param, spec, status);
    if (is_debug_run() && *status == EXIT_SUCCESS) {
      printf(" DEBUG:  Adding new RELBASE eval to cache; the count is %i \n", cli_count_elements(cache_relbase));
    }
  } else {
    if (is_debug_run()) {
      printf(" DEBUG:  RELBASE-RelxillCache: re-using calculated values\n");
    }
    spec = ca_info->store->data->relbase_spec;
  }

  if (shouldOutfilesBeWritten()) {
    save_emis_profiles(sysPar);
    save_relline_profile(spec);
  }

  free(inp);
  free(ca_info);

  return spec;
}


/** @brief get a radial grid on the accretion disk in order to calculate a relline for each zone
 *  @details returns nzones+1 zone boundaries between rmin and rmax:
 *   - nzones == 1: just {rmin, rmax}
 *   - if h > rmin, a logarithmic grid is used up to the grid point returned by binary_search for h,
 *     and from this point on the grid is equidistant in 1/r up to rmax
 *   - otherwise the whole grid is equidistant in 1/r
 *  @param status unused
 **/
std::vector<double> get_rzone_grid(double rmin, double rmax, int nzones, double h, int *status) {

  std::vector<double> radii(nzones + 1);

  if (nzones == 1) {
    radii[0] = rmin;
    radii[1] = rmax;
    return radii;
  }

  int i_start = 0;
  double r_start = rmin;

  // if h > rmin we choose a log grid for r<h
  if (h > rmin) {
    get_log_grid(radii.data(), radii.size(), rmin, rmax);
    i_start = binary_search(radii.data(), radii.size(), h);
    r_start = radii[i_start];
  }

  if (i_start >= nzones) {
    return radii;
  }

  // equidistant in 1/r for larger radii
  const double inv_lo = 1.0 / r_start;
  const double inv_hi = 1.0 / rmax;
  for (int k = i_start; k <= nzones; k++) {
    const double inv_r = 1.0 * (k - i_start) / (nzones - i_start) * (inv_hi - inv_lo) + inv_lo;
    radii[k] = fabs(1.0 / inv_r);
  }

  return radii;
}



/** @brief relbase wrapper function, calculating the relativistic system params plus the relbase profile
 *  @details
 *    - for more details see relbase_profile function
 *    - uses param->num_zones zones on the standard relxill radial grid
 *    - not used for any relxill-case (xill_table=nullptr, as no angular dependency is taken into account)
 * input: ener(n_ener), param
 * output: photar(n_ener)  [photons/bin]; nullptr if the system parameters could not be calculated
**/
relline_spec_multizone *relbase(double *ener, const int n_ener, relParam *param, int *status) {

  // system parameter values (has an internal cache)
  RelSysPar *sys_par = get_system_parameters(param, status);
  CHECK_STATUS_RET(*status, nullptr);
  assert(sys_par != nullptr);

  const auto radial_grid = get_standard_relxill_radial_grid(*param);
  return relbase_profile(ener, n_ener, param, sys_par, nullptr, radial_grid.radius,
                         param->num_zones, status);
}



/** release a RelCosne structure (nullptr is ignored); its energy array is not owned and not freed */
void free_rel_cosne(RelCosne *spec) {
  if (spec == nullptr) {
    return;
  }

  // spec->ener is only a pointer to an existing grid, so it is not released here
  free(spec->cosne);

  if (spec->dist != nullptr) {
    for (int izone = 0; izone < spec->n_zones; izone++) {
      free(spec->dist[izone]);
    }
  }
  free(spec->dist);
  free(spec);
}

/** release a relline spectrum including its flux per zone and its RelCosne part (nullptr is ignored) */
void free_rel_spec(relline_spec_multizone *spec) {
  if (spec == nullptr) {
    return;
  }

  free(spec->ener);
  delete[] spec->rgrid;

  if (spec->flux != nullptr) {
    for (int izone = 0; izone < spec->n_zones; izone++) {
      free(spec->flux[izone]);  // free(nullptr) is a no-op
    }
  }
  free(spec->flux);

  free_rel_cosne(spec->rel_cosne);  // ignores nullptr
  delete spec;
}

/** release all globally cached tables and the global FFT convolution cache */
void free_cached_tables() {
  free_relprofile_cache();

  free_cached_relTable();
  free_cached_lpTable();
  free_all_cached_xillTables();
  // this function is not used, but just in case - here is the routine to free ext table(s)
  //free_extendedSourceTable

  // TODO: implement cache in a general way
  // free(cached_rel_param);
  // free(cached_xill_param);

  free_specCache(global_spec_cache);
  // reset the global pointer: otherwise init_global_specCache() would return the freed cache,
  // and a second call of this function would free it twice
  global_spec_cache = nullptr;

  // TODO, implement free of this global energy grid
  //free(global_energy_grid_relxill);
  // free(global_ener_xill);

}

/** free an [n1][n2] array of arrays whose memory comes from malloc/calloc (it is released with free());
 *  a nullptr at the top or first level is ignored */
void free_fft_cache(double ***sp, int n1, int n2) {
  if (sp == nullptr) {
    return;
  }

  for (int i = 0; i < n1; i++) {
    double **row = sp[i];
    if (row != nullptr) {
      for (int j = 0; j < n2; j++) {
        free(row[j]);
      }
    }
    free(row);
  }
  free(sp);
}

/** allocate a spectrum holding a copy of the first n_ener values of ener and zero flux
 *  (release with free_spectrum); status is unused */
spectrum *new_spectrum(int n_ener, const double *ener, int *status) {

  auto *result = new spectrum;
  result->n_ener = n_ener;
  result->ener = new double[n_ener];
  result->flux = new double[n_ener];

  for (int k = 0; k < n_ener; ++k) {
    result->ener[k] = ener[k];
    result->flux[k] = 0.0;
  }

  return result;
}

/** release a spectrum created by new_spectrum (nullptr is ignored) */
void free_spectrum(spectrum *spec) {
  if (spec == nullptr) {
    return;
  }
  delete[] spec->flux;
  delete[] spec->ener;
  delete spec;
}

/** release the n inner arrays of val with fftw_free (they have to come from fftw_malloc);
 *  the outer array val itself is not released, and a nullptr val is ignored */
void free_fftw_complex_cache(fftw_complex** val, int n){
  if (val == nullptr) {
    return;
  }
  for (int ii = 0; ii < n; ii++) {
    fftw_free(val[ii]);
  }
}

void free_specCache(specCache* spec_cache) {

  int ii;
  int m = 2;
  if (spec_cache != nullptr) {
    if (spec_cache->xill_spec != nullptr) {
      for (ii = 0; ii < spec_cache->n_cache; ii++) {
        if (spec_cache->xill_spec[ii] != nullptr) {
          free_xill_spec(spec_cache->xill_spec[ii]);
        }
      }
      free(spec_cache->xill_spec);
    }

    if (spec_cache->fft_xill != nullptr) {
      free_fft_cache(spec_cache->fft_xill, spec_cache->n_cache, m);
    }

    if (spec_cache->fftw_rel != nullptr) {
      free_fft_cache(spec_cache->fft_rel, spec_cache->n_cache, m);
    }

    free_fftw_complex_cache(spec_cache->fftw_rel, spec_cache->n_cache);
    free_fftw_complex_cache(spec_cache->fftw_xill, spec_cache->n_cache);
    fftw_free(spec_cache->fftw_backwards_input);
    fftw_destroy_plan(spec_cache->plan_c2r);
    free(spec_cache->fftw_output);

    if (spec_cache->conversion_factor_energyflux != nullptr){
      free(spec_cache->conversion_factor_energyflux);
    }

    free_spectrum(spec_cache->out_spec);

  }

  free(spec_cache);

}

/** free the CLI cache **/

void free_cache() {
  free_cache_syspar();
  cli_delete_list(&cache_relbase);
}


