#! /usr/bin/env python3

import subprocess
import os
import sys
sys.path.append('./data/scripts/python_wrapper/sixtesoft/')
import sixte
import numpy as np

# NOTE(review): presumably aborts when the interpreter is older than
# Python 3.6 -- confirm in the sixte wrapper module. The f-strings used
# further down in this script need at least Python 3.6.
sixte.check_pythonversion(3, 6)

from astropy.io import fits

# Event file produced by the simulation and analyzed by the subtests below.
evtfile = "evt_dither.fits"

# External tools that prepare the event file, in execution order:
#   1. sixtesim simulates the two SIMPUT sources with the dummy instrument
#      and writes the event file.
#   2. radec2xy is run on that event file with the same reference
#      position (RA=30, Dec=30) as the simulation pointing.
setup_commands = [
    ["sixtesim",
     "Simput=./data/simputs/gauss_lo.simput,./data/simputs/gauss_hi.simput",
     "XMLFile=./data/instruments/dummy/default_inst_sixtesim.xml",
     "RA=30.000",
     "Dec=30.000",
     "Attitude=att_swivel.fits",
     f"EvtFile={evtfile}",
     "Background=n",
     "Exposure=100",
     "clobber=y",
     "chatter=0", "progressbar=n"],
    # also need to run radec2xy
    ["radec2xy",
     f"EvtFile={evtfile}",
     "RefRA=30", "RefDec=30", "chatter=0"],
]

print("*** running simulation to analyze ***")
for cmd in setup_commands:
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # The tool output is captured to keep the test log quiet, so it has to
    # be shown explicitly on failure; otherwise a broken simulation would
    # only surface later as an unrelated error in the analysis step.
    if proc.returncode != 0:
        print(f"*** {cmd[0]} failed with exit code {proc.returncode} ***")
        print(proc.stderr.decode(errors="replace"), file=sys.stderr)
        sys.exit(1)

# Subtest configuration: entry i of each list belongs to subtest i.
# 'NONE' is passed through to makespec unchanged for the unused filter.
# NOTE(review): presumably each GTI / region file isolates one of the two
# simulated sources (gauss_lo / gauss_hi) -- confirm against the test data.
gtis = ['gti_lo.fits',
        'gti_hi.fits',
        'NONE', 'NONE']
regs = ['NONE', 'NONE',
        'lo.reg',
        'hi.reg']
midval = 5  # bin index that splits the spectrum into a lower and upper part
# True:  counts are expected only in bins >= midval
# False: counts are expected only in bins <  midval
above = [False, True, False, True]

for i_test, (gtifile, regfile, expect_above) in enumerate(zip(gtis, regs, above)):
    # Name the subtest after whichever filter is actually in use.
    testnam = regfile if regfile != 'NONE' else gtifile

    print(f"*** testing {testnam} ***")
    outfile = f"test_{i_test}.pha"
    proc = subprocess.run(["makespec",
                           "usepha=y", "chatter=0",
                           f"EvtFile={evtfile}",
                           f"Spectrum={outfile}",
                           "RSPPATH=./data/instruments/dummy",
                           f"GTIfile={gtifile}",
                           f"regfile={regfile}"],
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE
                          )
    # Output is captured, so report it explicitly when makespec fails
    # instead of tripping over a missing spectrum file below.
    if proc.returncode != 0:
        print(f"*** makespec failed with exit code {proc.returncode} ***")
        print(proc.stderr.decode(errors="replace"), file=sys.stderr)
        sys.exit(1)

    # Reduce the spectrum to two sums while the file is still open, so
    # nothing references the (possibly memory-mapped) data after the file
    # has been closed and removed.
    with fits.open(outfile) as hdul:
        counts = hdul[1].data["COUNTS"]
        counts_below = np.sum(counts[:midval])
        counts_above = np.sum(counts[midval:])

    if expect_above:
        passed = counts_above != 0 and counts_below == 0
    else:
        passed = counts_above == 0 and counts_below != 0

    # Use truthiness here: the comparisons above yield numpy booleans, and
    # an identity check against the Python ``False`` singleton never
    # matches those, which would let every subtest pass.
    if not passed:
        print("*** failed! ***")
        sys.exit(1)
    print("*** passed! ***")

    os.remove(outfile)

# clean output
os.remove(evtfile)
