/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2023 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#pragma once


#include "ArrayGeometry.h"
#include "Signal.h"
#include "SixteLinterp.h"
#include <algorithm>
#include <cmath>
#include <deque>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using linterp::InterpMultilinear;

namespace sixte {

/** Generic, abstract interface describing a crosstalk mechanism.
 *  Concrete mechanisms (e.g. TimeEnergyDepCrosstalk, TimeDepCrosstalk)
 *  implement all pure virtual functions below.
 */
class CrosstalkType {
  public:
    /** Compute the energy shift caused by crosstalk.
     *
     * @param i_vic        Victim Pixel ID
     * @param i_perp       Perpetrator Pixel ID
     * @param e_vic        Victim Signal Energy [keV]
     * @param e_perp       Perpetrator Signal Energy [keV]
     * @param perp_delay   Delay of perpetrator signal wrt. victim [s]
     * @param grade_id     ID of the victim grade
     *
     * @return             crosstalk energy shift [keV]
     */
    [[nodiscard]] virtual double computeEnergyShift(
        T_PixId i_vic, T_PixId i_perp,
        double e_vic, double e_perp,
        double perp_delay, unsigned int grade_id) = 0;

    virtual ~CrosstalkType() = default;

    /** Check whether two pixels are coupled by this crosstalk type
     *
     * @param i_vic   Victim Pixel ID
     * @param i_perp  Perpetrator Pixel ID
     *
     * @return        True if pixels are coupled, False if not
     */
    [[nodiscard]] virtual bool isCoupled(T_PixId i_vic, T_PixId i_perp) = 0;

    /** Check whether an event of a given signal can cause crosstalk in
     *  this victim pixel
     *
     *  @param i_vic        Victim Pixel ID
     *  @param i_perp       Perpetrator Pixel ID
     *  @param perp_energy  Perpetrator Signal Energy [keV]
     *
     *  @return             True if it can trigger, false if not
     */
    [[nodiscard]] virtual bool canTrigger(
        T_PixId i_vic, T_PixId i_perp,
        double perp_energy) = 0;

    /** Get all the victim pixels of a particular pixel.
     *  The victims are calculated on each call, so only call this function
     *  once per perpetrator and keep the result, as e.g. the
     *  CrossTalkHandler does.
     *
     *  @param i_perp  Perpetrator Pixel ID
     *
     *  @return        Vector of Victim IDs
     */
    [[nodiscard]] virtual std::vector<T_PixId> getVictims(T_PixId i_perp) = 0;

    /** Get the maximum perpetrator delay for which this type can cause
     * crosstalk
     *
     * @return  Maximum perpetrator delay [s] (same units as perp_delay)
     */
    [[nodiscard]] virtual double max_xt_delay() = 0;
};

class CrosstalkProxy {
  public:
    /** Constructor
     *
     * @param i_perp   Perpetrator Pixel ID
     * @param perpsig  Perpetrator Signal [keV]
     */
    CrosstalkProxy(const T_PixId i_perp, const Signal perpsig)
      : i_perp_(i_perp), perpsig_(perpsig)
    {}

    [[nodiscard]] T_PixId i_perp() {
      return i_perp_;
    }

    [[nodiscard]] Signal perpsig() {
      return perpsig_;
    }

    /** Add additional Signal to a proxy. Mostly used for pileup.
     *  The Signals are added internally, so Photon Metainfo is also tracked.
     *
     *  @param to_add  Signal to add
     */
    void addSignal(const Signal& to_add) {
      perpsig_.add(to_add);
    }

  private:
    T_PixId i_perp_;
    Signal perpsig_; // contains the energy, time and photon meta info
    // unsigned int xt_type; //?
    // std::vector<unsigned int> xt_types; //?
};

/** Records which crosstalk mechanisms were requested for this run.
 *  All flags default to disabled; the constructor enables them from the
 *  option string.
 */
class CrosstalkOptions {
  public:
    /** Parse the option string and set the matching flags below.
     *
     *  @param opt_str  Option string, usually from the command line
     */
    CrosstalkOptions(const std::string opt_str);

    bool therm = false;         ///< thermal crosstalk (see loadThermalCrosstalks)
    bool tdm_prop_any = false;  ///< NOTE(review): presumably set if any tdm_prop* flag is set — confirm in parser
    bool tdm_prop1 = false;     ///< proportional TDM crosstalk affecting row N+1
    bool tdm_prop2 = false;     ///< proportional TDM crosstalk affecting all pixels in the channel
    bool tdm_prop3 = false;     ///< proportional TDM crosstalk affecting row N+2
    bool tdm_der = false;       ///< derivative TDM crosstalk (see loadDerivCrosstalk)
};

/** Class to determine perpetrator to victim coupling in crosstalk.
 *  This particular class represents the connection via a vector of maps.
 */
class VictimMap {
  public:
    /** Constructor.
     * Initializes a vector of maps with zero connections
     *
     * @param num_pix  Number of Pixel coupling maps to prepare
     */
    VictimMap(unsigned int num_pix);

    /** Add a new crosstalk coupling
     *
     * @param i_perp  Perpetrator Pixel ID
     * @param i_vic   Victim Pixel ID
     * @param weight  Coupling weight
     */
    void addCoupling(const T_PixId i_perp, const T_PixId i_vic, const double weight);

    /** Check whether two pixels are coupled
     *
     * @param i_vic   Victim Pixel ID
     * @param i_perp  Perpetrator Pixel ID
     *
     * @return        True if pixels are coupled, false if not
     */
    [[nodiscard]] bool isCoupled(T_PixId i_vic, T_PixId i_perp);

    /** Get the coupling weight of two pixels. Returns 0 if there's no coupling
     *
     * @param i_vic   Victim Pixel ID
     * @param i_perp  Perpetrator Pixel ID
     *
     * @return        Coupling weight
     */
    [[nodiscard]] double getCoupling(T_PixId i_vic, T_PixId i_perp);

    /** Get all the victim pixels of a particular pixel.
     *  The victims are calculated on each call, so only call this function
     *  once per perpetrator and keep the result, as e.g. the
     *  CrossTalkHandler does.
     *
     *  @param i_perp  Perpetrator Pixel ID
     *
     *  @return        Vector of Victim IDs
     */
    [[nodiscard]] std::vector<T_PixId> getVictims(T_PixId i_perp);

    /** The maximum coupling between any pixels.
     *
     *  @return        Maximum coupling strength (absolute value)
     */
    [[nodiscard]] double getMaxCoupling() const { return max_coupling_; }


  private:
    /// One map of (pixel ID -> coupling weight) per pixel.
    /// NOTE(review): presumably indexed by perpetrator ID with victim IDs
    /// as keys (cf. addCoupling(i_perp, i_vic, ...)) — confirm in the .cpp.
    std::vector<std::map<T_PixId, double>> data_;
    double max_coupling_{0}; ///< absolute value of the maximum coupling strength
};

/** Class that stores all crosstalk information, and handles its calculation */
class CrossTalkHandler {

  public:
    /** Constructor
     * Parse all crosstalk information and initialize the desired crosstalk types
     *
     * @param xml_data       XML Data to parse
     * @param crosstalk_opt  Option string specifying which crosstalk to run
     * @param geometry       Geometry of the array, used to initialize e.g.
     *                       thermal crosstalk
     */
    CrossTalkHandler(XMLData& xml_data, std::string crosstalk_opt, const ArrayGeometry* const geometry);

    /** Create CrosstalkProxies for all victims of this perpetrator
     *
     * @param i_perp   Perpetrator Pixel ID
     * @param perpsig  Perpetrator Signal [keV]
     */
    void createProxies(T_PixId i_perp, const Signal& perpsig);

    /** Clean all crosstalk proxies that could no longer affect
     *  a signal entering at a given cutoff time
     *
     * @param cutoff_time   Signal time to use
     */
    void cleanProxiesGlobal(double cutoff_time);

    /** Combine the two most recent CrosstalkProxies of this perpetrator into
     *  one. This is usually done when photons pile up
     *
     *  @param i_perp  Perpetrator Pixel ID
     */
    void pileupProxies(T_PixId i_perp);

    /** Calculate the total crosstalk that a given signal receives
     *
     * @param i_pix       Pixel ID (i.e., the victim)
     * @param pix_signal  Pixel Signal (includes time information!)
     * @param grade_id    Grading of the output record
     *
     * @return            Pair of (unsigned int, double).
     *                    NOTE(review): presumably the second element is the
     *                    total energy shift [keV] and the first a crosstalk
     *                    count/flag — confirm against the implementation.
     */
    [[nodiscard]] std::pair<unsigned int, double> calcTotalCrosstalk(
        T_PixId i_pix, const Signal& pix_signal, unsigned int grade_id
        );

    /** Check whether the most recent proxies of this perpetrator can cause
     *  "fake" triggering.
     *
     *  @param i_perp  Perpetrator Pixel ID
     *
     *  @return        Vector of pixels where we trigger due to crosstalk
     */
    std::vector<T_PixId> checkProxyTriggers(T_PixId i_perp);

  private:
    /// Victims of a perpetrator; presumably served from victim_lists_
    /// so each CrosstalkType::getVictims is only evaluated once.
    [[nodiscard]] const std::vector<T_PixId>& getVictims(T_PixId i_perp);

    /// Drop the proxies of one victim that can no longer affect a signal
    /// read out at readout_time.
    void cleanProxies(T_PixId i_vic, double readout_time);

    /// Crosstalk mechanisms enabled for this run.
    std::vector<std::unique_ptr<CrosstalkType>> xt_types_;
    /// Cached victim lists. NOTE(review): presumably indexed by perpetrator ID.
    std::vector<std::vector<T_PixId>> victim_lists_;

    double max_xt_delay_{0.}; ///< maximum perpetrator delay where crosstalk
                              ///< can still be caused

    /// Pending proxies. NOTE(review): presumably one deque per victim pixel
    /// (cf. cleanProxies(i_vic, ...)) — confirm.
    std::vector<std::deque<CrosstalkProxy>> proxy_list_;

};

/** Crosstalk whose energy shift depends on the perpetrator delay and on
 *  both the victim and perpetrator energies, looked up via one 3D
 *  interpolator per grade.
 */
class TimeEnergyDepCrosstalk: public CrosstalkType {
  public:
    /**
     * Constructor
     * Populates interpolators for lookup, but leaves pixels uncoupled
     *
     * @param num_pix            Number of pixels
     * @param grades_npost       Post length of each grade
     * @param dt                 Detector sampling time
     * @param weight_file        FITS extension containing the
     *                           dE(dt, e_vic, e_perp) for each grade
     * @param trigger_threshold  If perpetrator energy times coupling weight
     *                           exceeds this value, crosstalk can cause a
     *                           fake trigger
     */
    TimeEnergyDepCrosstalk(
        unsigned int num_pix,
        std::vector<unsigned int> grades_npost,
        double dt,
        std::string weight_file,
        double trigger_threshold);

    void addCoupling(const T_PixId i_perp, const T_PixId i_vic,
        const double weight);

    [[nodiscard]] double computeEnergyShift(
        T_PixId i_vic, T_PixId i_perp,
        double e_vic, double e_perp,
        double perp_delay, unsigned int grade_id) override;

    [[nodiscard]] bool isCoupled(T_PixId i_vic, T_PixId i_perp) override;

    [[nodiscard]] bool canTrigger(
        T_PixId i_vic, T_PixId i_perp,
        double perp_energy) override;

    [[nodiscard]] std::vector<T_PixId> getVictims(T_PixId i_perp) override;

    /** Absolute value of the smallest lower delay bound (dt_lo_) over all
     *  grades. Returns 0 if no grades were loaded.
     *  NOTE(review): assumes dt_lo_ entries are <= 0 (perpetrator before
     *  victim) — confirm.
     */
    [[nodiscard]] double max_xt_delay() override {
      if (dt_lo_.empty()) {
        return 0.;  // dereferencing min_element of an empty range is UB
      }
      // std::abs (not ::abs): the C int overload would truncate sub-second
      // delays to 0.
      return std::abs(*std::min_element(dt_lo_.begin(), dt_lo_.end()));
    }

  private:
    /// for each pixel, contains a map of victims and associated
    /// crosstalk weights
    /// Could also use a matrix here, but crosstalk coupling should
    /// be sparse (so most entries in a matrix would be zero!)
    VictimMap coupling_info_;

    [[nodiscard]] double getCoupling(T_PixId i_vic, T_PixId i_perp);

    /// 3D interpolation objects (one per grade)
    std::vector<InterpMultilinear<3,double>> interpolators_;

    /// Lower / upper delay bounds, one entry per grade
    std::vector<double> dt_lo_;
    std::vector<double> dt_hi_;

    /// NOTE(review): presumably the upper ends of the victim / perpetrator
    /// energy grids of the lookup table — confirm in the .cpp.
    double max_evic_{0.};
    double max_eperp_{0.};

    /// if the perpetrator energy times the coupling weight exceeds this
    /// value, crosstalk can cause a fake trigger
    double trigger_threshold_{0.};

};

/** Crosstalk whose energy shift depends only on the perpetrator delay,
 *  looked up via one 1D interpolator per grade.
 */
class TimeDepCrosstalk: public CrosstalkType {
  public:
    /**
     * Constructor
     * Populates interpolators for lookup, but leaves pixels uncoupled
     *
     * @param num_pix            Number of pixels
     * @param grades_npost       Post length of each grade
     * @param weight_file        FITS extension containing the delay
     *                           dependent crosstalk table for each grade
     *                           (1D lookup, presumably dE(dt) — confirm)
     * @param trigger_threshold  If perpetrator energy times coupling weight
     *                           exceeds this value, crosstalk can cause a
     *                           fake trigger
     */
    TimeDepCrosstalk(
        unsigned int num_pix,
        std::vector<unsigned int> grades_npost,
        std::string weight_file,
        double trigger_threshold);

    void addCoupling(const T_PixId i_perp, const T_PixId i_vic,
        const double weight);

    [[nodiscard]] double computeEnergyShift(
        T_PixId i_vic, T_PixId i_perp,
        double e_vic, double e_perp,
        double perp_delay, unsigned int grade_id) override;

    [[nodiscard]] bool isCoupled(T_PixId i_vic, T_PixId i_perp) override;

    [[nodiscard]] bool canTrigger(
        T_PixId i_vic, T_PixId i_perp,
        double perp_energy) override;

    [[nodiscard]] std::vector<T_PixId> getVictims(T_PixId i_perp) override;

    /** Absolute value of the smallest lower delay bound (dt_lo_) over all
     *  grades. Returns 0 if no grades were loaded.
     *  NOTE(review): assumes dt_lo_ entries are <= 0 (perpetrator before
     *  victim) — confirm.
     */
    [[nodiscard]] double max_xt_delay() override {
      if (dt_lo_.empty()) {
        return 0.;  // dereferencing min_element of an empty range is UB
      }
      // std::abs (not ::abs): the C int overload would truncate sub-second
      // delays to 0.
      return std::abs(*std::min_element(dt_lo_.begin(), dt_lo_.end()));
    }

  private:

    [[nodiscard]] double getCoupling(T_PixId i_vic, T_PixId i_perp);

    /// 1D interpolation objects (one per grade)
    std::vector<InterpMultilinear<1,double>> interpolators_;

    /// for each pixel, contains a map of victims and associated
    /// crosstalk weights
    /// Could also use a matrix here, but crosstalk coupling should
    /// be sparse (so most entries in a matrix would be zero!)
    VictimMap coupling_info_;

    /// Lower / upper delay bounds, one entry per grade
    std::vector<double> dt_lo_;
    std::vector<double> dt_hi_;

    /// if the perpetrator energy times the coupling weight exceeds this
    /// value, crosstalk can cause a fake trigger
    double trigger_threshold_{0.};
};

/** Load proportional TDM crosstalk.
 *
 *  Depending on the options, victims are the pixels in row N+1 (prop1),
 *  row N+2 (prop3), or every pixel in the same channel (prop2).
 *
 *  @param xml_data      XML to read the crosstalk data from
 *  @param xt_node       XML node holding the crosstalk information
 *  @param num_pix       Number of detector pixels
 *  @param grades_npost  Post length of each grade
 *  @param dt            Detector sampling time
 *  @param xt_opt        CrosstalkOptions selecting the variants to apply
 *  @param ext_scaling   Extra factor multiplied onto every coupling
 *
 *  @return  Crosstalk initialized with the corresponding weights
 */
TimeEnergyDepCrosstalk loadPropCrosstalk(XMLData& xml_data, XMLNode& xt_node,
                                         unsigned int num_pix,
                                         std::vector<unsigned int> grades_npost,
                                         double dt, CrosstalkOptions xt_opt,
                                         double ext_scaling);

/** Load Derivative TDM Crosstalk
 * Affects pixels in row N-1 and N+1
 *
 * @param xml_data      XML to get data from
 * @param xt_node       XML node containing crosstalk info
 * @param num_pix       Number of detector pixels
 * @param grades_npost  Post length of a particular grade
 * @param dt            Detector sampling time
 * @param ext_scaling   Additional scaling to multiply coupling by
 *
 * @return  Crosstalk initialized to the corresponding weights
 */
TimeEnergyDepCrosstalk loadDerivCrosstalk(
    XMLData& xml_data,
    XMLNode& xt_node,
    unsigned int num_pix,
    std::vector<unsigned int> grades_npost,
    double dt,
    double ext_scaling);

/** Load Thermal Crosstalk
 * Pixels are coupled according to their distance in the array geometry;
 * one TimeDepCrosstalk is created per distance pairing specified in the XML.
 *
 * @param xml_data      XML to get data from
 * @param xt_node       XML node containing crosstalk info
 * @param num_pix       Number of detector pixels
 * @param grades_npost  Post length of a particular grade
 * @param geometry      Array geometry (for distance determination)
 * @param ext_scaling   Additional scaling to multiply coupling by
 *
 * @return  Multiple crosstalk types, one per specified distance pairing
 */
std::vector<TimeDepCrosstalk> loadThermalCrosstalks(
    XMLData& xml_data,
    XMLNode& xt_node,
    unsigned int num_pix,
    std::vector<unsigned int> grades_npost,
    const ArrayGeometry* const geometry,
    double ext_scaling
    );

}
