#!/usr/bin/env python3

import argparse
import atexit
import os
import signal
import subprocess
import sys
from argparse import Namespace


def initialize_parser():
    """Build the command-line parser for ``ramalama-run-core``.

    ``--ctx-size`` and ``--temp`` have no ``type=``, so a value supplied on the
    command line is kept as a ``str``. Their defaults are therefore strings too,
    so the parsed value has the same type whether or not the option was given.
    Otherwise the defaults would be a float/int, which cannot be placed in an
    argv list for a child process.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="Run ramalama run core")
    parser.add_argument(
        "-c",
        "--ctx-size",
        dest="context",
        default="2048",
        help="size of the prompt context (0 = loaded from model)",
    )
    parser.add_argument("--jinja", action="store_true", help="enable jinja")
    parser.add_argument("--temp", default="0.8", help="temperature of the response from the AI model")
    parser.add_argument(
        "--ngl",
        dest="ngl",
        type=int,
        help="number of layers to offload to the gpu, if available",
    )
    parser.add_argument(
        "-t",
        "--threads",
        type=int,
        help="number of cpu threads to use",
    )
    parser.add_argument(
        "-v",
        help="verbose",
        action="store_true",
    )
    parser.add_argument("MODEL", type=str, help="Path to the model")  # positional argument
    parser.add_argument(
        "ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
    )

    return parser


def initialize_args():
    """Parse ``sys.argv`` and choose a port for the model server.

    Returns:
        tuple: ``(namespace, port)``, where *namespace* comes from
        ``initialize_parser().parse_args()`` and *port* is whatever
        ``ramalama.model.get_available_port_if_any(False)`` returns.
    """
    from ramalama.model import get_available_port_if_any

    namespace = initialize_parser().parse_args()
    return namespace, get_available_port_if_any(False)


def main(args):
    """Launch a ``ramalama serve`` server in a child process and run the chat client.

    The process forks. The child blocks SIGINT and runs
    ``ramalama --nocontainer serve ...`` via ``exec_cmd`` with
    ``stdout2null``/``stderr2null`` set. The parent runs
    ``ramalama-client-core`` (path derived from ``args[0]`` by ``build_args``),
    pointed at ``http://127.0.0.1:<port>`` and given the child's PID via
    ``--kill-server``.

    Args:
        args: The raw ``sys.argv`` list; only ``args[0]`` is used, to locate
            the client executable.

    Returns:
        0 (only reached if ``exec_cmd`` returns instead of replacing the process).
    """
    # Make a ``ramalama`` package in the current directory importable.
    sys.path.append('./')
    from ramalama.common import exec_cmd

    parsed_args, port = initialize_args()
    # Every argv element must be a str; option defaults may not be.
    temp = str(parsed_args.temp)
    context = str(parsed_args.context)
    # NOTE(review): --ngl, --jinja and -v are parsed but not forwarded to
    # either process; confirm whether that is intended.

    pid = os.fork()
    if pid == 0:
        # Child: block SIGINT (the mask is inherited across exec) so a Ctrl-C
        # aimed at the interactive client does not reach the server; the
        # client is given this PID via --kill-server, presumably to stop it.
        signal.pthread_sigmask(signal.SIG_BLOCK, {signal.SIGINT})
        serve_cmd = ["ramalama", "--nocontainer", "serve"]
        # Forward --threads only when set; str(None) would pass a literal "None".
        if parsed_args.threads is not None:
            serve_cmd += ["--threads", str(parsed_args.threads)]
        serve_cmd += ["--temp", temp, "--port", str(port), parsed_args.MODEL]
        # We are already inside a container at this point, so --nocontainer always applies.
        exec_cmd(serve_cmd, stdout2null=True, stderr2null=True)

        return 0

    client_args = build_args(args, "ramalama-client-core")
    client_args += ["-c", context, "--temp", temp, "--kill-server", str(pid), f"http://127.0.0.1:{port}"]
    client_args += parsed_args.ARGS
    exec_cmd(client_args)

    return 0

def build_args(args, new_command):
    """Return a one-element argv naming a sibling executable.

    ``ramalama-run-core`` in ``args[0]`` is replaced with *new_command*, so
    any directory prefix on ``args[0]`` is kept. All other entries of *args*
    are dropped, and *args* itself is not modified.

    Args:
        args: The original argv sequence; must be non-empty.
        new_command: Executable name to substitute for ``ramalama-run-core``.

    Returns:
        list[str]: ``[rewritten args[0]]``.

    Raises:
        IndexError: if *args* is empty.
    """
    return [args[0].replace("ramalama-run-core", new_command)]

if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main(sys.argv))
