/*
   This file is part of SIXTE.

   SIXTE is free software: you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   any later version.

   SIXTE is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   For a copy of the GNU General Public License see
   <http://www.gnu.org/licenses/>.


   Copyright 2024 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
                  Erlangen-Nuernberg
*/

#include <cstdio>
#include <string>
#include <vector>

#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_floating_point.hpp>

#include "PatternAnalysis.h"

using namespace sixte;

/**
 * Compare an event file produced by the pattern analysis against a reference
 * event file.
 *
 * Both files are opened with sixteOpenFITSFileRead and extension 1 of each is
 * used. The row counts of the two tables must agree. Then the scalar columns
 * TIME, FRAME, PHA, SIGNAL, RAWX, RAWY, TYPE, NPIXELS and PILEUP are read from
 * both tables and compared row by row:
 *   - the floating-point columns (TIME, SIGNAL) must agree within an absolute
 *     tolerance of 1e-6;
 *   - the integer columns must match exactly.
 * Any mismatch fails the current Catch2 test via REQUIRE / REQUIRE_THAT.
 *
 * @param ref_eventfile_name    Path of the reference event file.
 * @param output_eventfile_name Path of the event file under test.
 */
void refdata_comparison(const std::string& ref_eventfile_name, const std::string& output_eventfile_name) {
  auto ref_eventfile = sixteOpenFITSFileRead(ref_eventfile_name, "eventfile");
  auto& ref_eventfile_table = ref_eventfile->extension(1);

  auto output_eventfile = sixteOpenFITSFileRead(output_eventfile_name, "eventfile");
  auto& output_eventfile_table = output_eventfile->extension(1);

  // Each count must come from its own table; reading both from the reference
  // table would make this check pass unconditionally.
  const auto nrows_ref = static_cast<long>(ref_eventfile_table.rows());
  const auto nrows_output = static_cast<long>(output_eventfile_table.rows());
  REQUIRE(nrows_ref == nrows_output);
  const long nrows = nrows_ref;

  // Absolute tolerance for the floating-point columns.
  constexpr double abs_tolerance = 1e-6;

  //TODO: PH_ID, SRC_ID, SIGNALS, PHAS (vector columns) + keywords
  std::vector<double> time_ref, time_output;
  std::vector<long> frame_ref, frame_output;
  std::vector<long> pha_ref, pha_output;
  std::vector<double> signal_ref, signal_output;
  std::vector<int> rawx_ref, rawx_output;
  std::vector<int> rawy_ref, rawy_output;
  std::vector<int> type_ref, type_output;
  std::vector<long> npixels_ref, npixels_output;
  std::vector<int> pileup_ref, pileup_output;

  // Reads rows 1..nrows of column `name` from the reference and the output table.
  auto read_column_pair = [&](const char* name, auto& ref_values, auto& output_values) {
    ref_eventfile_table.column(name).read(ref_values, 1, nrows);
    output_eventfile_table.column(name).read(output_values, 1, nrows);
  };

  read_column_pair("TIME", time_ref, time_output);
  read_column_pair("FRAME", frame_ref, frame_output);
  read_column_pair("PHA", pha_ref, pha_output);
  read_column_pair("SIGNAL", signal_ref, signal_output);
  read_column_pair("RAWX", rawx_ref, rawx_output);
  read_column_pair("RAWY", rawy_ref, rawy_output);
  read_column_pair("TYPE", type_ref, type_output);
  read_column_pair("NPIXELS", npixels_ref, npixels_output);
  read_column_pair("PILEUP", pileup_ref, pileup_output);

  for (long ii = 0; ii < nrows; ++ii) {
    REQUIRE_THAT(time_output[ii], Catch::Matchers::WithinAbs(time_ref[ii], abs_tolerance));
    REQUIRE(frame_output[ii] == frame_ref[ii]);
    REQUIRE(pha_output[ii] == pha_ref[ii]);
    REQUIRE_THAT(signal_output[ii], Catch::Matchers::WithinAbs(signal_ref[ii], abs_tolerance));
    REQUIRE(rawx_output[ii] == rawx_ref[ii]);
    REQUIRE(rawy_output[ii] == rawy_ref[ii]);
    REQUIRE(type_output[ii] == type_ref[ii]);
    REQUIRE(npixels_output[ii] == npixels_ref[ii]);
    REQUIRE(pileup_output[ii] == pileup_ref[ii]);
  }
}

TEST_CASE("Pattern Analysis Test", "[pattern_analysis]") {
  SECTION("wfi") {
    std::string ref_raweventfile_name = "./data/refdata/ref_wfi_raw_10mcrab.fits";
    std::string ref_eventfile_name = "./data/refdata/ref_wfi_evt_10mcrab.fits";
    std::string output_eventfile_name = "wfi_evt_10mcrab.fits";
    std::string xml_filename = "./data/instruments/athena-wfi/ld_wfi_ff_large.xml";

    pugi::xml_document xml_file;
    xml_file.load_file(xml_filename.c_str());
    XMLData xml_data(xml_file, "data/instruments/athena-wfi/");

    auto rmf_filename = xml_data.dirname() + xml_data.child("detector").child("rmf").attributeAsString("filename");

    ObsPointing obs_pointing("none", 0, 0, 0);
    ObsTime obs_time(55000., 0, 10);
    GTICollection gti_collection("", obs_time);
    ObsInfo obs_info(xml_data, obs_pointing, gti_collection);

    auto clobber{true};
    auto skip_invalids{true};
    auto rawymin{0};
    auto rawymax{511};

    SECTION("phpat") {
      NewRMF rmf(rmf_filename);
      auto ebounds = rmf.exportEbounds();
      NewEventfile ref_raweventfile(ref_raweventfile_name);
      NewEventfile output_eventfile(output_eventfile_name, clobber, xml_data, obs_info);
      PatternAnalysis pattern_analysis(skip_invalids, xml_data, rawymin, rawymax);

      pattern_analysis.phpat(ebounds, ref_raweventfile, output_eventfile);
    }

    SECTION("refdata comparison") {
      refdata_comparison(ref_eventfile_name, output_eventfile_name);
    }

    SECTION("clean-up") {
      if (std::remove(output_eventfile_name.c_str()) != 0) {
        throw SixteException("Failed to remove output eventfile");
      }
    }
  }

  SECTION("erosita") {
    std::string ref_raweventfile_name = "./data/refdata/ref_erosita_raw_100mcrab.fits";
    std::string ref_eventfile_name = "./data/refdata/ref_erosita_evt_100mcrab.fits";
    std::string output_eventfile_name = "erosita_evt_100mcrab.fits";
    std::string xml_filename = "./data/instruments/eRosita/erosita_1.xml";

    pugi::xml_document xml_file;
    xml_file.load_file(xml_filename.c_str());
    XMLData xml_data(xml_file, "data/instruments/eRosita/");

    auto rmf_filename = xml_data.dirname() + xml_data.child("detector").child("rmf").attributeAsString("filename");

    ObsPointing obs_pointing("none", 0, 0, 0);
    ObsTime obs_time(51543.875, 0, 20);
    GTICollection gti_collection("", obs_time);
    ObsInfo obs_info(xml_data, obs_pointing, gti_collection);

    auto clobber{true};
    auto skip_invalids{true};
    auto rawymin{0};
    auto rawymax{383};

    SECTION("phpat") {
      NewRMF rmf(rmf_filename);
      auto ebounds = rmf.exportEbounds();
      NewEventfile ref_raweventfile(ref_raweventfile_name);
      NewEventfile output_eventfile(output_eventfile_name, clobber, xml_data, obs_info);
      PatternAnalysis pattern_analysis(skip_invalids, xml_data, rawymin, rawymax);

      pattern_analysis.phpat(ebounds, ref_raweventfile, output_eventfile);
    }

    SECTION("refdata comparison") {
      refdata_comparison(ref_eventfile_name, output_eventfile_name);
    }

    SECTION("clean-up") {
      if (std::remove(output_eventfile_name.c_str()) != 0) {
        throw SixteException("Failed to remove output eventfile");
      }
    }
  }
}
