/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2022 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#pragma once

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include "LineInfo.h"
#include "ObsInfo.h"
#include "Signal.h"
#include "SimulationParameters.h"
#include "XMLData.h"

namespace sixte {

/// Abstract interface of a detector pixel that stores a Signal.
///
/// Concrete pixel types decide how an incoming signal is combined with the
/// signal already held by the pixel.
class Pixel {
 public:
  virtual ~Pixel() = default;

  // --- queries --------------------------------------------------------------

  /// True if the pixel currently holds signal (implementations may also
  /// count internally pending charge).
  [[nodiscard]] virtual bool anySignal() const = 0;

  /// Read-only view of the stored signal; empty optional if there is none.
  [[nodiscard]] virtual const std::optional<Signal>& getSignal() const = 0;

  // --- mutators -------------------------------------------------------------

  /// Replace the stored signal by a copy of @p input_signal.
  virtual void setSignal(const Signal& input_signal) = 0;
  /// Replace the stored signal, taking ownership of @p input_signal.
  virtual void setSignal(Signal&& input_signal) = 0;

  /// Combine @p input_signal (copied) with the stored signal.
  virtual void addSignal(const Signal& input_signal) = 0;
  /// Combine @p input_signal (moved) with the stored signal.
  virtual void addSignal(Signal&& input_signal) = 0;

  /// Multiply the stored signal by @p scaling_factor.
  virtual void scaleSignal(double scaling_factor) = 0;

  /// Drop the stored signal (the exact reset semantics are implementation
  /// defined).
  virtual void clearSignal() = 0;

  // --- ownership transfer ---------------------------------------------------

  /// Move the stored signal out of the pixel and return it.
  virtual std::optional<Signal> releaseSignal() = 0;
};

class SimplePixel: public Pixel {
 public:
  void addSignal(const Signal& s) override {
    addSignalImpl(s);
  }

  void addSignal(Signal&& s) override {
    addSignalImpl(std::move(s));
  }

  void setSignal(const Signal& input_signal) override {
    if (signal_) {
      signal_.reset();
    }
    signal_ = input_signal;
  }

  void setSignal(Signal&& input_signal) override {
    if (signal_) {
      signal_.reset();
    }
    signal_.emplace(std::move(input_signal));
  }

  void clearSignal() override {
    signal_.reset();
  }

  void scaleSignal(double scaling_factor) override {
    if (signal_) {
      signal_->scale(scaling_factor);
    }
  }

  [[nodiscard]] const std::optional<Signal>& getSignal() const override {
    return signal_;
  }

  [[nodiscard]] bool anySignal() const override {
    return signal_.has_value();
  }

  std::optional<Signal> releaseSignal() noexcept override {
    if (!signal_) return std::nullopt;
    std::optional<Signal> out_signal{std::move(*signal_)};
    signal_.reset();
    return out_signal;
  }

 private:
  template<class Sig>
  void addSignalImpl(Sig&& input_signal) {
    if (!signal_) {
      signal_.emplace(std::forward<Sig>(input_signal));
    } else {
      signal_->add(std::forward<Sig>(input_signal));
    }
  }

  std::optional<Signal> signal_;
};


/// Pixel model for a DEPFET detector with a multi-phase readout cycle.
///
/// Each frame consists of a normal exposure phase followed by the readout
/// sequence: first settling, first integration, clear, second settling and
/// second integration. Depending on when a photon arrives within this cycle,
/// only part of its charge is attributed to the current readout (signal_),
/// while the remainder is kept as a carry (ccarry_) that becomes the pixel
/// signal after the next clearSignal().
class DEPFETPixel: public Pixel {
 public:
  /// @param xml_data       detector description; timing attributes are read
  ///                       from the <detector><depfet .../> element.
  /// @param frame_duration duration of one frame, in the same time unit as
  ///                       the depfet timing attributes (presumably seconds —
  ///                       TODO confirm).
  /// @param pix_id         id of this pixel, used to query @p line_info.
  /// @param line_info      shared readout-line bookkeeping that provides the
  ///                       last readout time of this pixel.
  explicit DEPFETPixel(XMLData& xml_data, double frame_duration, T_PixId pix_id,
                       std::shared_ptr<LineInfo> line_info)
      : pix_id_(pix_id),
        frame_duration_{frame_duration},
        line_info_(std::move(line_info)) {
    auto depfet = xml_data.child("detector").child("depfet");
    t_integration_ = depfet.attributeAsDouble("integration");
    t_clear_ = depfet.attributeAsDouble("clear");
    // A single "settling" attribute applies to both settling phases.
    if (depfet.hasAttribute("settling")) {
      t_settling_1_ = depfet.attributeAsDouble("settling");
      t_settling_2_ = t_settling_1_;
    } else {
      t_settling_1_ = depfet.attributeAsDouble("settling_1");
      t_settling_2_ = depfet.attributeAsDouble("settling_2");
    }

    t_readout_ = t_settling_1_ + t_settling_2_ + 2.*t_integration_ + t_clear_;
    // NOTE(review): not validated; negative if the readout sequence is
    // longer than frame_duration.
    t_not_readout_ = frame_duration_ - t_readout_;
  }

  void addSignal(const Signal& input_signal) override {
    addSignalImpl(input_signal);
  }

  void addSignal(Signal&& input_signal) override {
    addSignalImpl(std::move(input_signal));
  }

  /// Replace the current readout signal; a pending carry is left untouched.
  void setSignal(const Signal& signal) override {
    signal_.reset();
    signal_ = signal;
  }

  void setSignal(Signal&& signal) override {
    signal_.reset();
    signal_.emplace(std::move(signal));
  }

  /// Finish the current readout: the pending carry (possibly empty) becomes
  /// the new signal and the carry is emptied.
  void clearSignal() override {
    signal_.swap(ccarry_);
    ccarry_.reset();
  }

  /// Scale the current readout signal; does nothing if there is none.
  void scaleSignal(double scaling_factor) override {
    // Dereferencing an empty std::optional is undefined behaviour.
    if (signal_) {
      signal_->scale(scaling_factor);
    }
  }

  [[nodiscard]] const std::optional<Signal>& getSignal() const override {
    return signal_;
  }

  /// True if either a readout signal or a pending carry exists.
  /// NOTE(review): may be true while getSignal() is empty and
  /// releaseSignal() returns nullopt (only a carry pending).
  [[nodiscard]] bool anySignal() const override {
    return signal_.has_value() || ccarry_.has_value();
  }

  /// Move the current readout signal out, then promote the carry via
  /// clearSignal(). Returns nullopt (and leaves the carry pending) if there
  /// is no readout signal.
  std::optional<Signal> releaseSignal() noexcept override {
    if (!signal_) return std::nullopt;
    std::optional out{std::move(*signal_)};
    clearSignal();
    return out;
  }

 private:
  /// Phases of one frame, in chronological order.
  enum class time_intervals {
    normal_exposure,
    first_settling,
    first_integration,
    clear_interval,
    second_settling,
    second_integration
  };

  /// Classify an impact by the phase of the frame it falls into.
  /// @return the phase and the time elapsed since the start of that phase.
  /// @throws SixteException if the impact lies beyond the end of the frame.
  [[nodiscard]] std::pair<time_intervals, double> getTimeInterval(double impact_time,
                                                                  double last_readout_time) const {
    // Time since the start of the readout cycle (specific impact time); each
    // step below subtracts the length of the phase that has been passed.
    double t = impact_time - last_readout_time;
    if (t <= t_not_readout_) return {time_intervals::normal_exposure, t};

    t -= t_not_readout_;
    if (t <= t_settling_1_) return {time_intervals::first_settling, t};

    t -= t_settling_1_;
    if (t <= t_integration_) return {time_intervals::first_integration, t};

    t -= t_integration_;
    if (t <= t_clear_) return {time_intervals::clear_interval, t};

    t -= t_clear_;
    if (t <= t_settling_2_) return {time_intervals::second_settling, t};

    t -= t_settling_2_;
    if (t <= t_integration_) return {time_intervals::second_integration, t};

    throw SixteException("Invalid time interval");
  }

  /// Portion of @p energy assigned by a clear that has run for @p time,
  /// assuming the clear progresses linearly over t_clear_.
  [[nodiscard]] double linearClearSignal(double time, double energy) const {
    return (energy * time / t_clear_);
  }

  /// Store @p incoming in @p target, or add it to the signal already there.
  static void mergeInto(std::optional<Signal>& target, Signal&& incoming) {
    if (target) {
      target->add(std::move(incoming));
    } else {
      target.emplace(std::move(incoming));
    }
  }

  /// Split an incoming signal between the current readout and the carry,
  /// depending on the phase of the frame it arrives in.
  template <class Sig>
  void addSignalImpl(Sig&& input_signal) {
    const double creation_time = input_signal.creation_time();
    const auto [phase, dt] =
        getTimeInterval(creation_time, line_info_->lastReadoutTime(pix_id_));
    const auto val = input_signal.val();

    double signal_to_add = 0.0;
    double carry_to_add = 0.0;
    bool anycarry = false;

    switch (phase) {
      case time_intervals::normal_exposure:
      case time_intervals::first_settling:
        // Arrived before the first integration: counts fully.
        signal_to_add = val;
        break;
      case time_intervals::first_integration:
        // Only the not-yet-integrated fraction of the first integration counts.
        signal_to_add = val * (t_integration_ - dt) / t_integration_;
        break;
      case time_intervals::clear_interval: {
        // The cleared part is subtracted from this readout and kept as carry.
        anycarry = true;
        const double rem = linearClearSignal(dt, val);
        signal_to_add = -rem;
        carry_to_add = rem;
        break;
      }
      case time_intervals::second_settling:
        anycarry = true;
        signal_to_add = -1. * val;
        carry_to_add = val;
        break;
      case time_intervals::second_integration:
        anycarry = true;
        signal_to_add = -1. * val * (t_integration_ - dt) / t_integration_;
        carry_to_add = val;
        break;
      default:
        throw SixteException("Time interval not valid");
    }

    Signal new_signal(signal_to_add,
                      std::forward<Sig>(input_signal).photon_metainfo(),
                      creation_time);

    if (anycarry) {
      // The carry belongs to this photon only, so its metainfo is taken from
      // new_signal *before* that is merged into signal_ (which may already
      // hold metainfo of earlier, fully read-out photons).
      Signal new_carry(carry_to_add, new_signal.photon_metainfo(), creation_time);
      mergeInto(ccarry_, std::move(new_carry));
    }

    mergeInto(signal_, std::move(new_signal));
  }

  T_PixId pix_id_;
  std::optional<Signal> signal_;  ///< Signal of the current readout.
  std::optional<Signal> ccarry_;  ///< Charge carried over to the next readout.

  double t_integration_; ///< Integration time
  double t_clear_; ///< Clear time
  double t_settling_1_; ///< First settling time
  double t_settling_2_; ///< Second settling time

  double frame_duration_; ///< Duration of one frame
  double t_readout_; ///< Duration of the active readout
  double t_not_readout_; ///< Time length from the beginning of the
                         ///< cycle to the beginning of the read-out.

  std::shared_ptr<LineInfo> line_info_;
};

/// Factory returning the Pixel implementation configured by @p xml_data.
///
/// NOTE(review): the selection logic lives in the implementation file and is
/// not visible here. Presumably frame_duration, pix_id and line_info are only
/// needed by pixel types that require them (DEPFETPixel's constructor takes
/// all three) and may be std::nullopt otherwise — TODO confirm.
std::unique_ptr<Pixel> createPixel(XMLData& xml_data, std::optional<double> frame_duration,
                                   std::optional<T_PixId> pix_id,
                                   std::optional<std::shared_ptr<LineInfo>> line_info);

/// Tracks dead time following accepted events.
///
/// After an event is accepted at time t, later events are rejected while
/// their time is before t + deadtime. In paralyzable mode every rejected
/// event restarts the dead period; in non-paralyzable mode rejected events
/// leave it unchanged.
class DeadTimeHandler {
 public:
  DeadTimeHandler(double deadtime, bool is_paralyzable)
      : is_paralyzable_(is_paralyzable), deadtime_(deadtime) {}

  /// Start (or restart) a dead period beginning at @p time.
  void makeDead(double time) {
    dead_until_ = time + deadtime_;
  }

  /// Decide whether an event at @p time is accepted.
  ///
  /// Not a pure query: an accepted event starts a new dead period, and in
  /// paralyzable mode a rejected event extends the current one.
  /// @return true if the event is accepted.
  bool canAdd(double time) {
    const bool still_dead = dead_until_ && time < *dead_until_;
    if (!still_dead) {
      makeDead(time);
      return true;
    }
    if (is_paralyzable_) {
      makeDead(time);
    }
    return false;
  }

 private:
  bool is_paralyzable_;               ///< Rejected events extend the dead period.
  double deadtime_;                   ///< Length of one dead period.
  std::optional<double> dead_until_;  ///< End of the dead period; empty before the first event.
};


class EvtPixel: public Pixel {
 public:

  explicit EvtPixel(XMLData& xml_data)
  {
    if (auto readout = xml_data.child("detector").optionalChild("readout")) {
      if (readout->hasAttribute("deadtime")) {
        std::string deadtype = readout->attributeAsString("deadtype");
        bool is_paralyzable = deadtype.compare("paralyzable") == 0;

        deadhand_ = DeadTimeHandler(
            readout->attributeAsDouble("deadtime"),
            is_paralyzable);
      }
    }
  }

  void addSignal(const Signal& s) override {
    addSignalImpl(s);
  }
  void addSignal(Signal&& s) override {
    addSignalImpl(std::move(s));
  }

  void setSignal(const Signal& input_signal) override {
    if (signal_) {
      signal_.reset();
    }
    signal_ = input_signal;
  }

  void setSignal(Signal&& input_signal) override {
    if (signal_) {
      signal_.reset();
    }
    signal_.emplace(std::move(input_signal));
  }

  void clearSignal() override {
    signal_.reset();
  }

  void scaleSignal(double scaling_factor) override {
    if (signal_) {
      signal_->scale(scaling_factor);
    }
  }

  [[nodiscard]] const std::optional<Signal>& getSignal() const override {
    return signal_;
  }

  [[nodiscard]] bool anySignal() const override {
    return signal_.has_value();
  }

  std::optional<Signal> releaseSignal() noexcept override {
    if (!signal_) return std::nullopt;
    std::optional out{std::move(*signal_)};
    clearSignal();
    return out;
  }

 private:
  template<class Sig>
  void addSignalImpl(Sig&& input_signal) {
    bool not_dead =
         !deadhand_.has_value()
      || deadhand_->canAdd(input_signal.creation_time());

    if (not_dead) {
      if (!signal_) {
        signal_.emplace(std::forward<Sig>(input_signal));
      } else {
        signal_->add(std::forward<Sig>(input_signal));
      }
    }
  }

  std::optional<Signal> signal_;

  std::optional<DeadTimeHandler> deadhand_;
};

}
