/*
    This file is part of SIXTE.

    SIXTE is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    any later version.

    SIXTE is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    For a copy of the GNU General Public License see
    <http://www.gnu.org/licenses/>.


    Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
    Erlangen-Nuernberg
*/

#include "LobsterEyeOptic.h"
#include "raytracing/surface/SurfaceModel.h"

namespace {
// Embree geometry masks. With Embree ray masking enabled, a ray only tests a
// geometry when (ray mask & geometry mask) != 0.
constexpr unsigned int kMaskOptics = 1u;    // assigned to the optic mesh
constexpr unsigned int kMaskTerminal = 2u;  // assigned to sensors and spider

// Sets `mask` on geometry `geom_id` of `scene` and re-commits that geometry.
// Does nothing for RTC_INVALID_GEOMETRY_ID or when rtcGetGeometry returns null.
void set_geometry_mask(RTCScene scene, unsigned int geom_id, unsigned int mask) {
  if (geom_id != RTC_INVALID_GEOMETRY_ID) {
    if (RTCGeometry geometry = rtcGetGeometry(scene, geom_id)) {
      rtcSetGeometryMask(geometry, mask);
      rtcCommitGeometry(geometry);
    }
  }
}
}  // namespace

LobsterEyeOptic::LobsterEyeOptic(const sixte::XMLData &xml_data)
  : xml_path_(xml_data.dirname()),
    raytracer_node_(xml_data.child("telescope").child("raytracer")) {
  // One Embree device, shared by both scenes built at the end.
  device_ = EmbreeScene::initializeDevice();

  // Stored in mm: the XML value is multiplied by 1000 (presumably metres).
  const auto focal_length_xml =
      xml_data.child("telescope").child("focallength").attributeAsDouble("value");
  focal_length_ = focal_length_xml * 1000;

  // Components must exist before the scenes that reference their meshes.
  LobsterEyeOptic::create();
  full_scene_ = initializeFullScene(device_);
  optics_scene_ = initializeOpticsScene(device_);
}

std::optional<Ray> LobsterEyeOptic::ray_trace(Ray &ray) {
  // Full trace through the combined scene with a depth budget of 5.
  // On failure embree_ray_trace has already stored the fail reason on `ray`.
  constexpr int kMaxDepth = 5;
  if (!embree_ray_trace(ray, kMaxDepth)) {
    return std::nullopt;
  }
  ray.set_fail_reason(RayFailReason::NONE);
  return ray;
}

std::optional<Ray> LobsterEyeOptic::ray_trace_optics(Ray &ray) {
  // Optics-only stage (MPO mesh + pore model) with a depth budget of 5.
  // On failure embree_ray_trace_optics has already stored the fail reason.
  constexpr int kMaxDepth = 5;
  if (!embree_ray_trace_optics(ray, kMaxDepth)) {
    return std::nullopt;
  }
  ray.set_fail_reason(RayFailReason::NONE);
  return ray;
}

std::optional<Ray> LobsterEyeOptic::ray_trace_terminal(Ray &ray) {
  // Terminal stage (spider + sensors). A failed trace leaves its fail reason
  // on `ray`; a successful one is reported with RayFailReason::NONE.
  if (!embree_ray_trace_terminal(ray)) {
    return std::nullopt;
  }
  ray.set_fail_reason(RayFailReason::NONE);
  return ray;
}

std::optional<size_t> LobsterEyeOptic::FindSensorIndexByGeomID(size_t geom_id) {
  // Index into mesh_sensor_ of the first sensor whose Embree geometry ID is
  // `geom_id`; std::nullopt when no mesh sensor owns that ID.
  for (size_t index = 0; index < mesh_sensor_.size(); ++index) {
    if (mesh_sensor_[index].geomID() == geom_id) {
      return index;
    }
  }
  return std::nullopt;
}

bool LobsterEyeOptic::embree_ray_trace_optics(Ray &ray, int depth) {
  // Intersects `ray` with the optics-only scene, at most `depth` times.
  // - miss: MISSED_OPTIC, false
  // - MPO mesh hit (optical_geom_id_optics_): the ray is moved to the hit
  //   point and handed to pore_.ray_trace with the remaining depth; the result
  //   is true iff the pore trace succeeds (else LOST_IN_PORE)
  // - any other hit costs one depth level; running out gives MAX_DEPTH.
  // Every hit is appended to ray.raytracing_history before it is handled.
  for (; depth > 0; --depth) {
    rtcIntersect1(optics_scene_, &ray.rayhit);

    const auto &hit = ray.rayhit.hit;
    Vec3fa geometric_normal(hit.Ng_x, hit.Ng_y, hit.Ng_z);
    ray.set_normal(normalize(geometric_normal));

    if (hit.geomID == RTC_INVALID_GEOMETRY_ID) {
      ray.set_fail_reason(RayFailReason::MISSED_OPTIC);
      return false;
    }

    ray.raytracing_history.emplace_back(static_cast<short>(hit.geomID),
                                        ray.position(),
                                        ray.direction());

    if (hit.geomID != optical_geom_id_optics_) {
      continue;  // spend one depth level and intersect again
    }

    ray.set_position(ray.position() + ray.rayhit.ray.tfar * ray.direction());
    if (!pore_.ray_trace(ray, depth)) {
      ray.set_fail_reason(RayFailReason::LOST_IN_PORE);
      return false;
    }
    return true;
  }

  ray.set_fail_reason(RayFailReason::MAX_DEPTH);
  return false;
}

// Terminal stage run after the optics trace. It first checks spider shadowing
// of the incoming path, then intersects the outgoing ray with the full scene
// using kMaskTerminal (spider and sensors only; the optic carries kMaskOptics).
// Returns true when the nearest terminal hit is a mesh sensor: the ray is moved
// to the hit point and ray.sensor_position / ray.hitID are filled in.
// Otherwise a fail reason is stored on `ray` and false is returned.
// Unlike embree_ray_trace, it appends nothing to raytracing_history and does
// not set RayFailReason::NONE itself.
bool LobsterEyeOptic::embree_ray_trace_terminal(Ray &ray) {
  // Applies terminal geometry exactly once:
  //   - spider shadowing (incoming path)
  //   - spider obstruction + active sensor mosaic / gaps (outgoing path)
  // No reroll happens here.

  // --- Incoming path spider check (before the optic) ---
  // NOTE(review): assumes the first history entry holds the ray's original
  // origin/direction (embree_ray_trace_optics pushes position() before moving
  // the ray). The probe ray's extent is not limited to the optic entrance here
  // (tfar presumably defaults to "infinite" in Ray's constructor — confirm), so
  // any spider hit along the whole incoming line counts as shadowing.
  if (!spider_.filename.empty() && !ray.raytracing_history.empty()) {
    Vec3fa in_dir = ray.raytracing_history.front().direction;
    Vec3fa in_org = ray.raytracing_history.front().origin;
    Ray incoming(in_org, in_dir, ray.energy);
    incoming.rayhit.ray.mask = kMaskTerminal;
    rtcIntersect1(full_scene_, &incoming.rayhit);
    if (incoming.rayhit.hit.geomID == spider_.geomID) {
      ray.set_fail_reason(RayFailReason::BLOCKED_SPIDER);
      return false;
    }
  }

  // --- Outgoing path: from optics exit to sensors ---
  // reset_rayhit() presumably re-initialises the Embree ray/hit from the ray's
  // current position and direction — confirm.
  ray.reset_rayhit();
  ray.rayhit.ray.mask = kMaskTerminal;

  rtcIntersect1(full_scene_, &ray.rayhit);

  // The normal is stored before the miss check; Embree does not write the hit
  // normal on a miss, so on that path this normalises stale data.
  Vec3fa normal = Vec3fa(ray.rayhit.hit.Ng_x,
                         ray.rayhit.hit.Ng_y,
                         ray.rayhit.hit.Ng_z);
  ray.set_normal(normalize(normal));

  if (ray.rayhit.hit.geomID == RTC_INVALID_GEOMETRY_ID) {
    ray.set_fail_reason(RayFailReason::MISSED_SENSOR);
    return false;
  }

  const size_t hitGeomID = ray.rayhit.hit.geomID;

  // Spider is the nearest terminal geometry: the outgoing ray is obstructed.
  if (!spider_.filename.empty() && hitGeomID == spider_.geomID) {
    ray.set_fail_reason(RayFailReason::BLOCKED_SPIDER);
    return false;
  }

  auto sensor_index = FindSensorIndexByGeomID(hitGeomID);
  if (sensor_index.has_value()) {
    Sensor& sensor = mesh_sensor_[*sensor_index];
    // Advance to the hit point before converting to sensor coordinates.
    ray.set_position(ray.position() + (ray.rayhit.ray.tfar * ray.direction()));
    ray.sensor_position = sensor.worldToSensor(ray);
    ray.hitID = hitGeomID;
    return true;
  }

  // Any other fail is also treated as a terminal loss.
  // NOTE(review): FindSensorIndexByGeomID only searches mesh_sensor_, so a hit
  // on the plane sensor (user geometry, also kMaskTerminal) always ends here as
  // MISSED_SENSOR — confirm the terminal stage is only used with mesh sensors.
  ray.set_fail_reason(RayFailReason::MISSED_SENSOR);
  return false;
}

// Traces `ray` through the full scene (optic, optional spider, sensors) for at
// most `depth` intersection steps.
//
// Each step intersects the full scene, stores the normalised geometric normal,
// and (on a hit) appends the hit to ray.raytracing_history. Then:
//   - a miss fails with MISSED_OPTIC if the history was empty, else MISSED_SENSOR;
//   - a mesh-sensor hit ends the trace. On the very first step it fails with
//     DIRECT_SENSOR_HIT. Later, the ray is moved to the hit point, converted to
//     sensor coordinates, and true is returned with RayFailReason::NONE;
//   - a spider hit fails with BLOCKED_SPIDER (only when a spider is configured);
//   - an optic hit moves the ray to the hit point and runs the pore model with
//     the remaining depth, failing with LOST_IN_PORE if that fails.
// Exhausting `depth` fails with MAX_DEPTH. Every false return leaves a fail
// reason on `ray`.
//
// NOTE(review): only mesh sensors are recognised. With the plane sensor (no
// mesh sensors) a plane hit is neither sensor nor optic, so it only costs a
// depth level — confirm that this path is not used with the plane sensor.
bool LobsterEyeOptic::embree_ray_trace(Ray &ray, int depth) {
  const int starting_depth = depth;

  while (depth > 0) {
    rtcIntersect1(full_scene_, &ray.rayhit);

    Vec3fa normal = Vec3fa(ray.rayhit.hit.Ng_x,
                           ray.rayhit.hit.Ng_y,
                           ray.rayhit.hit.Ng_z);
    ray.set_normal(normalize(normal));

    if (ray.rayhit.hit.geomID == RTC_INVALID_GEOMETRY_ID) {
      // If we never intersected anything, we missed the optic entirely.
      // If we already intersected something, we likely missed the detector.
      if (ray.raytracing_history.empty()) {
        ray.set_fail_reason(RayFailReason::MISSED_OPTIC);
      } else {
        ray.set_fail_reason(RayFailReason::MISSED_SENSOR);
      }
      return false;
    }

    ray.raytracing_history.emplace_back(static_cast<short>(ray.rayhit.hit.geomID),
                                        ray.position(),
                                        ray.direction());

    const size_t searchID = ray.rayhit.hit.geomID;

    const auto sensor_index = FindSensorIndexByGeomID(searchID);
    if (sensor_index.has_value()) {
      if (depth == starting_depth) {
        ray.set_fail_reason(RayFailReason::DIRECT_SENSOR_HIT);
        return false;
      }

      Sensor& sensor = mesh_sensor_[*sensor_index];
      ray.set_position(ray.position() + (ray.rayhit.ray.tfar * ray.direction()));
      ray.sensor_position = sensor.worldToSensor(ray);

      ray.hitID = searchID;
      ray.set_fail_reason(RayFailReason::NONE);
      return true;
    }

    // spider_.geomID is only assigned when a spider mesh was loaded, so only
    // compare against it in that case (same guard as the terminal trace).
    if (!spider_.filename.empty() && ray.rayhit.hit.geomID == spider_.geomID) {
      ray.set_fail_reason(RayFailReason::BLOCKED_SPIDER);
      return false;
    }

    if (ray.rayhit.hit.geomID == optical_geom_id_full_) {
      ray.set_position(ray.position() + ray.rayhit.ray.tfar * ray.direction());

      if (!pore_.ray_trace(ray, depth)) {
        ray.set_fail_reason(RayFailReason::LOST_IN_PORE);
        return false;
      }
    }

    depth--;
  }

  ray.set_fail_reason(RayFailReason::MAX_DEPTH);
  return false;
}

// Reads the <spider> element of the raytracer configuration. A Spider is only
// built when the element's "spider" attribute is exactly "true". Otherwise
// spider_ is left unchanged; other code treats an empty spider_.filename as
// "no spider". The position and path attributes are read only for an enabled
// spider, so a disabled spider does not need to provide them.
void LobsterEyeOptic::createSpider() {
  auto spider_node = raytracer_node_.child("spider");
  if (spider_node.attributeAsString("spider") != "true") {
    return;
  }

  Vec3fa spider_position = {};
  spider_position.x = static_cast<float>(spider_node.attributeAsDouble("position_x"));
  spider_position.y = static_cast<float>(spider_node.attributeAsDouble("position_y"));
  spider_position.z = static_cast<float>(spider_node.attributeAsDouble("position_z"));

  // Mesh path is given relative to the directory of the XML file.
  std::string spider_path = xml_path_ + spider_node.attributeAsString("path");
  spider_ = Spider(spider_path, spider_position);
}

void LobsterEyeOptic::createSurface() {
  // surface_model_ is built from the <surface> child of the raytracer node.
  auto surface_node = raytracer_node_.child("surface");
  surface_model_ = SurfaceModel::from_xml(surface_node);
}

void LobsterEyeOptic::createOptic() {
  // Optic mesh: placement offset plus mesh path (appended to the XML
  // directory), both taken from the <optical> element.
  auto optical_node = raytracer_node_.child("optical");

  Vec3fa optical_position = {};
  optical_position.x = static_cast<float>(optical_node.attributeAsDouble("position_x"));
  optical_position.y = static_cast<float>(optical_node.attributeAsDouble("position_y"));
  optical_position.z = static_cast<float>(optical_node.attributeAsDouble("position_z"));

  std::string optical_path = xml_path_ + optical_node.attributeAsString("path");
  opticalMesh_ = OpticalMesh(optical_path, optical_position);
}

void LobsterEyeOptic::createMPOs() {
  // Pore dimensions come from <type>, material data from <surface>; the pore
  // also receives the surface model created by createSurface().
  auto type_node = raytracer_node_.child("type");
  double pore_width = type_node.attributeAsDouble("pore_width");
  double pore_length = type_node.attributeAsDouble("pore_length");

  auto surface_node = raytracer_node_.child("surface");
  std::string material_path = xml_path_ + surface_node.attributeAsString("material_path");
  std::string material = surface_node.attributeAsString("material");

  pore_ = Pore(pore_width, pore_length,
               Vec3fa(0, 0, 0), Vec3fa(0, 0, 0),
               material_path, material, surface_model_);
}

Vec3fa LobsterEyeOptic::readSensorPosition(const sixte::XMLNode& node) {
  // Position from the sensor_x / sensor_y / sensor_z attributes of `node`.
  // Braced initialisation keeps the x, y, z read order.
  const auto read_axis = [&node](const char *attribute) -> float {
    return node.attributeAsFloat(attribute);
  };
  return Vec3fa{read_axis("sensor_x"), read_axis("sensor_y"), read_axis("sensor_z")};
}

void LobsterEyeOptic::createSensor() {
  std::string sensor_mesh = raytracer_node_.child("sensor").attributeAsString("mesh");

  if (sensor_mesh == "true") {
    auto sensors = raytracer_node_.child("sensor").allChildren();
    if (sensors.empty()) {
      Vec3fa sensor_position = readSensorPosition(raytracer_node_.child("sensor"));
      std::string sensor_path = xml_path_ + raytracer_node_.child("sensor").attributeAsString("path");
      mesh_sensor_.emplace_back(sensor_path, sensor_position);
    } else {
      for (auto sensor_element : sensors) {
        std::string sensor_path = xml_path_ + sensor_element.attributeAsString("path");

        Vec3fa chip_position = readSensorPosition(sensor_element);
        mesh_sensor_.emplace_back(sensor_path, chip_position);
      }
    }
  } else {
    Vec3fa sensor_position = readSensorPosition(raytracer_node_.child("sensor"));
    double sensor_offset = raytracer_node_.child("sensor").attributeAsDouble("offset");
    sensor_ = Plane(0,0,1, sensor_offset, sensor_position.x, sensor_position.y);
  }
}

void LobsterEyeOptic::create() {
  // Order matters: createMPOs() hands surface_model_ (set by createSurface())
  // to the pore model.
  createSpider();
  createSurface();
  createOptic();
  createMPOs();
  // TODO: move sensor creation into a different class.
  createSensor();
}

// Builds the combined Embree scene used by the full and terminal traces.
// It contains:
//   - the mesh sensors, or a single plane sensor (user geometry) when no mesh
//     sensors are configured;
//   - the optional spider;
//   - the optic mesh.
// Sensors and spider get kMaskTerminal and the optic gets kMaskOptics, so rays
// carrying only kMaskTerminal skip the optic (with Embree ray masking enabled).
// Records the geometry IDs in each Sensor, planeParameters.geomID,
// spider_.geomID (only if a spider is configured) and optical_geom_id_full_.
RTCScene LobsterEyeOptic::initializeFullScene(RTCDevice device) {
  RTCScene scene = rtcNewScene(device);
  rtcSetSceneFlags(scene, RTC_SCENE_FLAG_ROBUST);
  rtcSetSceneBuildQuality(scene, RTC_BUILD_QUALITY_HIGH);

  if (!mesh_sensor_.empty()) {
    for (auto &sensor : mesh_sensor_) {
      // Filled by addSTLMesh (presumably with the mesh vertices).
      std::vector<Vec3fa> points{};

      const size_t geomID_sensor =
          EmbreeScene::addSTLMesh(sensor.filename(), sensor.position(), scene, device, points);
      sensor.setGeometricID(geomID_sensor);

      // NOTE(review): the vertices used for origin/axes depend on the parity
      // of the geometry ID, presumably matching the vertex order of the sensor
      // STL files — confirm. at() reports a mesh with fewer than four vertices
      // as std::out_of_range instead of reading past the end.
      if (sensor.geomID() % 2 == 0) {
        sensor.setOriginPoint(points.at(1));
        sensor.setAxes(points.at(2), points.at(0));
      } else {
        sensor.setOriginPoint(points.at(3));
        sensor.setAxes(points.at(0), points.at(2));
      }

      set_geometry_mask(scene, geomID_sensor, kMaskTerminal);
    }
  } else {
    // Plane sensor as Embree user geometry; planeParameters is the user data
    // passed to the Plane callbacks.
    RTCGeometry geometry = rtcNewGeometry(device, RTC_GEOMETRY_TYPE_USER);
    auto *para = &sensor_.planeParameters;

    rtcSetGeometryUserPrimitiveCount(geometry, 1);
    rtcSetGeometryUserData(geometry, para);
    para->geometry = geometry;

    rtcSetGeometryBoundsFunction(geometry, Plane::planeBoundsFunc, nullptr);
    rtcSetGeometryIntersectFunction(geometry, Plane::planeIntersectFunc);
    rtcSetGeometryOccludedFunction(geometry, Plane::planeOccludedFunc);
    rtcSetGeometryMask(geometry, kMaskTerminal);

    // Commit the geometry and attach it to the scene.
    rtcCommitGeometry(geometry);
    para->geomID = rtcAttachGeometry(scene, geometry);
    rtcReleaseGeometry(geometry);
  }

  // The spider is optional. Its geometry ID is only meaningful once the mesh
  // has been added, so the mask is set only in that case.
  if (!spider_.filename.empty()) {
    spider_.geomID = EmbreeScene::addSTLMesh(spider_.filename, spider_.position, scene, device);
    set_geometry_mask(scene, spider_.geomID, kMaskTerminal);
  }

  optical_geom_id_full_ = EmbreeScene::addSTLMesh(opticalMesh_.filename, opticalMesh_.position, scene, device);
  set_geometry_mask(scene, optical_geom_id_full_, kMaskOptics);

  rtcCommitScene(scene);
  return scene;
}

RTCScene LobsterEyeOptic::initializeOpticsScene(RTCDevice device) {
  // Scene holding only the MPO entrance/plate surface mesh; its geometry ID is
  // recorded in optical_geom_id_optics_. No geometry mask is set here.
  RTCScene optics = rtcNewScene(device);
  rtcSetSceneFlags(optics, RTC_SCENE_FLAG_ROBUST);
  rtcSetSceneBuildQuality(optics, RTC_BUILD_QUALITY_HIGH);

  optical_geom_id_optics_ =
      EmbreeScene::addSTLMesh(opticalMesh_.filename, opticalMesh_.position, optics, device);

  rtcCommitScene(optics);
  return optics;
}
