#!/usr/bin/python
import os
import sys
import subprocess
import signal
import logging
import ctypes
from ctypes.util import find_library
from lutris.util.monitor import ProcessMonitor
from lutris.util.log import logger


# prctl(2) option code that marks the calling process as a "child subreaper";
# 36 is the value of PR_SET_CHILD_SUBREAPER in <linux/prctl.h>.
PR_SET_CHILD_SUBREAPER = 36


def set_child_subreaper():
    """Sets the current process to a subreaper (best effort).

    A subreaper fulfills the role of init(1) for its descendant
    processes: when a descendant becomes orphaned (its immediate parent
    terminates) it is reparented to the nearest still living ancestor
    subreaper. That subreaper then receives SIGCHLD when the orphan
    terminates and can wait(2) on it to collect its exit status.

    The setting is not inherited by children created by fork(2) and
    clone(2), but it is preserved across execve(2).

    This is what allows the wrapper to keep track of double-forked
    processes started by the command it launches.

    Failure is not fatal: a message (including the reason, when known)
    is printed and the function returns normally.
    """
    try:
        # use_errno=True makes ctypes keep a private copy of errno so the
        # reason for a failed prctl() call can be reported below.
        libc = ctypes.CDLL(find_library('c'), use_errno=True)
        # prctl(2) takes the option plus four unsigned long arguments.
        result = libc.prctl(
            PR_SET_CHILD_SUBREAPER,
            ctypes.c_ulong(1),
            ctypes.c_ulong(0),
            ctypes.c_ulong(0),
            ctypes.c_ulong(0),
        )
    except (OSError, AttributeError) as ex:
        # OSError: libc could not be loaded.
        # AttributeError: the loaded library does not export prctl.
        print("PR_SET_CHILD_SUBREAPER failed, process watching may fail (%s)" % ex)
        return
    if result == -1:
        reason = os.strerror(ctypes.get_errno())
        print("PR_SET_CHILD_SUBREAPER failed, process watching may fail (%s)" % reason)


def log(line):
    """Emits a single log line on stdout and flushes it immediately.

    This is the one place to change if the wrapper output should go
    somewhere else (a file, the logging module, ...).

    A BrokenPipeError raised while writing or flushing is ignored, so a
    closed stdout does not stop the wrapper.
    """
    message = line + "\n"
    try:
        sys.stdout.write(message)
        sys.stdout.flush()
    except BrokenPipeError:
        # Nobody is reading our output anymore; nothing useful to do.
        return

    # Example of logging to a file instead:
    # with open(os.path.expanduser("~/lutris.log"), "a") as logfile:
    #     logfile.write(line)
    #     logfile.write("\n")


def main():
    """Runs a command independently from the Lutris client

    Expected command line (after the script name):

        INCLUDE_COUNT EXCLUDE_COUNT [INCLUDE_PROC...] [EXCLUDE_PROC...] COMMAND [ARG...]

    INCLUDE_COUNT and EXCLUDE_COUNT say how many of the following
    arguments are process names handed to ProcessMonitor as its include
    and exclude lists; everything after them is the command to run.

    The first SIGTERM/SIGINT received is forwarded to the monitored
    children; any further one sends SIGKILL to all children, including
    the ignored ones. Once the command returns, the wrapper keeps
    reaping children until the monitor reports none left, then exits
    with the command's return code.

    Exits with a usage message on stderr when the arguments are
    malformed or no command is given.
    """
    set_child_subreaper()

    usage = (
        "usage: include_count exclude_count "
        "[include_procs...] [exclude_procs...] command [args...]"
    )
    try:
        _, include_proc_count, exclude_proc_count, *args = sys.argv
        include_proc_count = int(include_proc_count)
        exclude_proc_count = int(exclude_proc_count)
    except ValueError:
        # Too few arguments, or counts that are not integers.
        sys.exit(usage)
    if include_proc_count < 0 or exclude_proc_count < 0:
        sys.exit(usage)

    include_procs, args = args[:include_proc_count], args[include_proc_count:]
    exclude_procs, args = args[:exclude_proc_count], args[exclude_proc_count:]
    if not args:
        # Nothing left to run once the process lists are consumed.
        sys.exit(usage)

    os.environ.pop("PYTHONPATH", None)
    monitor = ProcessMonitor(include_procs, exclude_procs)

    def send_signal(pid, signum):
        """Sends signum to pid, ignoring processes that already exited."""
        try:
            os.kill(pid, signum)
        except ProcessLookupError:
            # The process went away between the monitor refresh and the
            # kill; letting this escape a signal handler would abort the
            # wrapper in the middle of its shutdown.
            pass

    def hard_sig_handler(signum, _frame):
        log("Caught another signal, sending SIGKILL.")
        for _ in range(3):  # just in case we race a new process.
            monitor.refresh_process_status()
            children = monitor.children + monitor.ignored_children
            if not children:
                break
            for child in children:
                send_signal(child.pid, signal.SIGKILL)
        log("--killed processes--")

    def sig_handler(signum, _frame):
        log("Caught signal %s" % signum)
        # From now on, any further signal escalates to SIGKILL.
        signal.signal(signal.SIGTERM, hard_sig_handler)
        signal.signal(signal.SIGINT, hard_sig_handler)
        monitor.refresh_process_status()
        for child in monitor.children:
            log("passing along signal to PID %s" % child.pid)
            send_signal(child.pid, signum)
        log("--terminated processes--")

    signal.signal(signal.SIGTERM, sig_handler)
    signal.signal(signal.SIGINT, sig_handler)

    log("Running %s" % " ".join(args))
    returncode = subprocess.run(args).returncode
    try:
        # The command itself has returned, but processes it left behind
        # may still be running; keep reaping until the monitor sees none.
        while True:
            log("Waiting on children")
            os.wait3(0)
            if not monitor.refresh_process_status():
                log("All children gone")
                break
    except ChildProcessError:
        # If the game itself has quit then
        # this process has no children
        pass
    log("Exit with returncode %s" % returncode)
    sys.exit(returncode)


if __name__ == "__main__":
    # Directory holding this script, with symlinks resolved.
    LAUNCH_PATH = os.path.dirname(os.path.realpath(__file__))
    if not LAUNCH_PATH.startswith("/usr"):
        # Not launched from under /usr: enable debug logging, skip writing
        # bytecode files and expose the parent directory on sys.path.
        # NOTE(review): the lutris imports at the top of this file have
        # already been resolved by now, so the sys.path entry only
        # affects imports performed later - confirm this is intended.
        sys.dont_write_bytecode = True
        logger.setLevel(logging.DEBUG)
        SOURCE_PATH = os.path.normpath(os.path.join(LAUNCH_PATH, os.pardir))
        sys.path.insert(0, SOURCE_PATH)

    main()
