/*
    This file is part of SIXTE.

    SIXTE is free software: you can redistribute it and/or modify it
    under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    any later version.

    SIXTE is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    GNU General Public License for more details.

    For a copy of the GNU General Public License see
    <http://www.gnu.org/licenses/>.


    Copyright 2025 Remeis-Sternwarte, Friedrich-Alexander-Universitaet
    Erlangen-Nuernberg
*/


#include <cmath>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include "Pore.h"

#include "NewSIXT.h"

/**
 * Default constructor: performs no member setup of its own.
 *
 * NOTE(review): width, length and the wall planes are only given values by
 * the parameterised constructor (and the setters) visible in this file;
 * confirm whether Pore.h provides in-class defaults before relying on a
 * default-constructed Pore.
 */
Pore::Pore() = default;

/**
 * Build a square pore of side @p pwidth and depth @p plength.
 *
 * @param pwidth         Side length of the square pore cross-section.
 * @param plength        Pore depth along its local z axis.
 * @param protation      Stored as-is in `rotation`.
 * @param ptranslation   Stored as-is in `translation`.
 * @param material_path  Forwarded to sixte::SurfaceElement for the coating.
 * @param material       Forwarded to sixte::SurfaceElement for the coating.
 * @param surface_model  Optional surface model; ownership is shared (moved in).
 */
Pore::Pore(double pwidth,
           double plength,
           Vec3fa protation,
           Vec3fa ptranslation,
           std::string material_path,
           std::string material,
           std::shared_ptr<SurfaceModel> surface_model) {
    // Geometry and placement.
    width = pwidth;
    length = plength;
    rotation = protation;
    translation = ptranslation;

    // Reflective coating and the (possibly null) surface model.
    coating = sixte::SurfaceElement(material_path, material);
    surface_model_ = std::move(surface_model);

    // The five pore faces in the local frame.  NOTE(review): presumably
    // Plane(nx, ny, nz, d, ...) is the plane n.p + d = 0; the labels below
    // follow from that and from the bounds checked in findInterection.
    wall1 = Plane(0, 1, 0, 0, 0, 0);        // y = 0
    wall2 = Plane(1, 0, 0, -width, 0, 0);   // x = width
    wall3 = Plane(0, 1, 0, -width, 0, 0);   // y = width
    wall4 = Plane(1, 0, 0, 0, 0, 0);        // x = 0
    floor = Plane(0, 0, 1, 0, 0, 0);        // z = 0 (exit plane)
}

/// Replace the stored rotation vector (no derived state is rebuilt here).
void Pore::set_rotation(Vec3fa protation) { this->rotation = protation; }

/// Replace the stored translation vector (no derived state is rebuilt here).
void Pore::set_translation(Vec3fa ptranslation) { this->translation = ptranslation; }

void Pore::set_width(double pwidth) {
    width = pwidth;
}

void Pore::set_length(double plength) {
    length = plength;
}

/**
 * Find the nearest pore face hit by a ray expressed in the pore's local frame.
 *
 * The faces are tested in order: wall1 (hit bounded in x and z), wall2
 * (y and z), wall3 (x and z), wall4 (y and z) and the floor (x and y).
 * A candidate is kept only when:
 *  - every bounded coordinate of its hit point lies in [0, width]
 *    (lateral) or [0, length] (z);
 *  - its distance satisfies tnear <= t <= best distance so far.
 * Ties go to the later face.
 *
 * The checks are written as positive range tests so that a NaN distance or
 * hit coordinate (e.g. from a 0/0 plane intersection) is rejected.  A
 * negated "out of range" test would let NaN through, because every
 * comparison with NaN is false.
 *
 * Side effects: ray.rayhit.ray.tfar is set to the nearest distance
 * (infinity when nothing is hit).  On a hit, ray.set_normal() receives
 * that face's normal.
 *
 * NOTE(review): a ray that has just reflected starts on a wall, so that
 * wall's distance is ~0; rejecting it relies on reset_rayhit() giving tnear
 * a positive epsilon — confirm.
 *
 * @param ray Ray in pore-local coordinates.
 * @return 1-4 for the side walls, 5 for the floor, -1 if no face is hit.
 */
int Pore::findInterection(Ray &ray) {
    Vec3fa dir = ray.direction();
    Vec3fa pos = ray.position();

    double t = std::numeric_limits<double>::infinity();
    int wall_number = -1;

    // True iff value lies in the closed interval [0, upper]; false for NaN.
    const auto in_range = [](double value, double upper) {
        return value >= 0 && value <= upper;
    };
    // True iff t_candidate is not before tnear and not beyond the best hit
    // found so far (reads the current t by reference); false for NaN.
    const auto improves = [&ray, &t](double t_candidate) {
        return t_candidate >= ray.rayhit.ray.tnear && t_candidate <= t;
    };

    double t_temp = wall1.planeIntersect(ray);
    Vec3fa hit = pos + t_temp * dir;
    if (improves(t_temp) && in_range(hit.z, length) && in_range(hit.x, width)) {
        t = t_temp;
        wall_number = 1;
    }

    t_temp = wall2.planeIntersect(ray);
    hit = pos + t_temp * dir;
    if (improves(t_temp) && in_range(hit.z, length) && in_range(hit.y, width)) {
        t = t_temp;
        wall_number = 2;
    }

    t_temp = wall3.planeIntersect(ray);
    hit = pos + t_temp * dir;
    if (improves(t_temp) && in_range(hit.z, length) && in_range(hit.x, width)) {
        t = t_temp;
        wall_number = 3;
    }

    t_temp = wall4.planeIntersect(ray);
    hit = pos + t_temp * dir;
    if (improves(t_temp) && in_range(hit.z, length) && in_range(hit.y, width)) {
        t = t_temp;
        wall_number = 4;
    }

    t_temp = floor.planeIntersect(ray);
    hit = pos + t_temp * dir;
    if (improves(t_temp) && in_range(hit.x, width) && in_range(hit.y, width)) {
        t = t_temp;
        wall_number = 5;
    }

    ray.rayhit.ray.tfar = static_cast<float>(t);
    switch (wall_number) {
        case 1:
            ray.set_normal(Vec3fa(0, 1, 0));
            break;
        case 2:
            ray.set_normal(Vec3fa(-1, 0, 0));
            break;
        case 3:
            ray.set_normal(Vec3fa(0, -1, 0));
            break;
        case 4:
            ray.set_normal(Vec3fa(1, 0, 0));
            break;
        case 5:
            ray.set_normal(Vec3fa(0, 0, 1));
            break;
        default:
            // No face hit: leave the normal untouched.
            break;
    }
    return wall_number;
}

/**
 * Map one draw of sixte::getUniformRandomNumber() linearly onto [m, n].
 *
 * The result is m + (n - m) * u, so the bounds may be given in either
 * order.  Whether the endpoints are inclusive depends on the range of
 * getUniformRandomNumber() — NOTE(review): confirm.
 */
double Pore::generateRandomDouble(double m, double n) {
    const double span = n - m;
    return m + span * sixte::getUniformRandomNumber();
}

struct OrthonormalBasis {
  Vec3fa u; // local x axis
  Vec3fa v; // local y axis
  Vec3fa w; // local z axis (pore axis)

  static OrthonormalBasis from_normal_fixed_twist(const Vec3fa& n_in) {
    OrthonormalBasis b;
    b.w = normalize(n_in);  // pore axis

    // --- choose a global reference direction ---
    const Vec3fa globalX(1.0f, 0.0f, 0.0f);
    const Vec3fa globalY(0.0f, 1.0f, 0.0f);

    Vec3fa ref = globalX;
    // If w is almost parallel to global X, fall back to Y to avoid degeneracy
    if (std::abs(dot(b.w, ref)) > 0.999f)
      ref = globalY;

    // --- project the reference vector into the tangent plane of w ---
    // u is "global X as seen in the plane perpendicular to w"
    b.u = ref - dot(ref, b.w) * b.w;
    b.u = normalize(b.u);

    // v completes a right-handed orthonormal basis
    b.v = cross(b.w, b.u);

    return b;
  }

  // world -> pore (local)
  Vec3fa to_local(const Vec3fa& v_world) const {
    return Vec3fa(dot(v_world, u),
                  dot(v_world, v),
                  dot(v_world, w));
  }

  // pore (local) -> world
  Vec3fa to_world(const Vec3fa& v_local) const {
    return v_local.x * u +
           v_local.y * v +
           v_local.z * w;
  }
};

/**
 * Trace a ray that has hit the pore plate through a single square pore.
 *
 * On entry ray.position() is the world-space hit point on the plate.  The
 * pore axis w is normalize(hit point) — NOTE(review): this assumes the plate
 * is a sphere centred on the world origin; confirm against the plate setup.
 *
 * The direction is rotated into the pore frame, where the pore spans
 * 0 <= x,y <= width and 0 <= z <= length.  The ray is then started at a
 * random (x, y) on the entrance face z = length.  It is intersected with
 * the pore faces until it reaches the exit plane z = 0 or runs out of
 * bounces.
 *
 * On success the exit point is mapped back to world space with the frame
 * anchored so that the local entry point (x, y, length) coincides with the
 * world hit point.  The outgoing position thus differs from the entry only
 * by the ray's actual displacement inside the pore.  The previous code
 * anchored local (0, 0, length) there instead, shifting every transmitted
 * ray by x*u + y*v.
 *
 * @param ray   Ray to trace; modified in place.
 * @param depth Not used: the bounce budget inside the pore is fixed
 *              (kMaxIntersections), as in the original implementation.
 * @return true if the ray leaves through the exit plane; the ray then holds
 *         the world-space exit point and direction and its hit state is
 *         reset via reset_rayhit(20.0f).
 *
 *         false if:
 *          - no face is hit;
 *          - get_angle(-direction, normal) < 1e-8 (NOTE(review): normal vs
 *            grazing incidence depends on get_angle's convention);
 *          - the surface model or coating rejects the reflection;
 *          - the bounce budget is exhausted.
 *         On false the ray is left in the pore's local frame.
 */
bool Pore::ray_trace(Ray &ray, int depth) {
  // The caller's depth is intentionally ignored; see the doc comment.
  static_cast<void>(depth);
  // Maximum number of face intersections tried before giving up.
  constexpr int kMaxIntersections = 10;

  // World-space hit point on the plate; its direction is the pore axis.
  Vec3fa old_position = ray.position();
  Vec3fa normal_exact = normalize(old_position);

  // Build an ONB aligned with the pore axis
  OrthonormalBasis pore_frame =
      OrthonormalBasis::from_normal_fixed_twist(normal_exact);

  // Transform incoming ray direction into pore/local coordinates
  ray.set_direction(normalize(pore_frame.to_local(ray.direction())));

  // Start the ray at a random point of the entrance face in LOCAL coordinates
  double x = generateRandomDouble(width, 0);
  double y = generateRandomDouble(width, 0);
  Vec3fa entry_local(x, y, length);
  ray.set_position(entry_local);

  for (int attempt = 0; attempt < kMaxIntersections; ++attempt) {
    int wall_number = findInterection(ray);

    if (wall_number == -1) {
      return false;
    }

    // Record the face id (offset by 10) with the ray's current local-frame
    // position and direction, i.e. before advancing to the hit point.
    ray.raytracing_history.emplace_back(
        static_cast<short>(wall_number + 10),
        ray.position(),
        ray.direction()
    );

    // Successful transmission through the pore
    if (wall_number == 5) {
      // Local exit point: advance to the actual hit on z = 0
      Vec3fa local_exit = ray.position() + ray.rayhit.ray.tfar * ray.direction();

      // World position of the local origin, chosen so that the local entry
      // point (x, y, length) maps onto the world-space hit point.
      Vec3fa world_origin = old_position - pore_frame.to_world(entry_local);
      ray.set_position(world_origin + pore_frame.to_world(local_exit));

      // Transform outgoing direction back to world coordinates
      ray.set_direction(normalize(pore_frame.to_world(ray.direction())));

      // Reset Embree ray state
      ray.reset_rayhit(20.0f);

      return true;
    }

    if (get_angle(-1 * ray.direction(), ray.normal()) < 1e-8)
      return false;

    // The surface model may alter the ray (e.g. its normal), so the angle is
    // recomputed below rather than reused from the check above.
    if (surface_model_ != nullptr)
      if (!surface_model_->simulate_surface(ray))
        return false;

    if (!coating.doesReflect(ray.energy,
                             get_angle(-1 * ray.direction(), ray.normal()))) {
      return false;
    }

    reflect_ray(ray);
  }
  return false;
}

/**
 * Move the ray to its current hit point and mirror its direction about
 * ray.normal(), then reset its hit state via reset_rayhit().
 *
 * @return Always true.
 */
bool Pore::reflect_ray(Ray &ray) {
    Vec3fa hit_point = ray.position() + ray.rayhit.ray.tfar * ray.direction();
    ray.set_position(hit_point);

    Vec3fa mirrored = reflect(ray.direction(), ray.normal());
    ray.set_direction(mirrored);

    ray.reset_rayhit();
    return true;
}
