#!/usr/bin/env python3

import argparse
import cmd
import itertools
import json
import os
import signal
import sys
import time
import urllib.error
import urllib.request


def construct_request_data(conversation_history):
    """Build the JSON-serialisable request body for one chat request.

    The given history object is placed under "messages" as-is (not copied),
    and streaming is always requested.
    """
    return {"stream": True, "messages": conversation_history}


def should_colorize():
    """Return True when colour escape codes should be written to stdout.

    Colour is enabled only if the TERM environment variable is set to a
    non-empty value other than "dumb" and stdout is attached to a terminal.
    Always returns a real bool (never None or an empty string).
    """
    term = os.getenv("TERM")
    if not term or term == "dumb":
        return False

    return bool(sys.stdout.isatty())


def res(response, color):
    """Print a streamed reply to stdout as it arrives and return its full text.

    Args:
        response: iterable of bytes lines. Only lines that, once decoded and
            stripped, start with ``data: {`` are parsed; everything else
            (blank lines, ``data: [DONE]``, comments) is ignored.
        color: "always", "auto" or "never". With "auto", should_colorize()
            decides whether escape codes are emitted.

    Returns:
        The concatenation of every non-empty ``content`` fragment found under
        ``choices[0].delta`` in the parsed lines ("" if there were none).
    """
    color_default = ""
    color_yellow = ""
    if color == "always" or (color == "auto" and should_colorize()):
        color_default = "\033[0m"
        color_yellow = "\033[33m"

    prefix = "data: "

    # Return the cursor to column 0 so the text overwrites any spinner glyph.
    print("\r", end="")
    fragments = []
    for raw_line in response:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith(prefix + "{"):
            continue

        payload = json.loads(line[len(prefix):])

        # Some chunks carry no choices at all; there is nothing to print.
        choices = payload.get("choices")
        if not choices:
            continue

        # A delta may omit "content" or carry null/"" (e.g. a role-only
        # chunk); treating those as text would print "None" and then fail
        # when concatenating.
        content = (choices[0].get("delta") or {}).get("content")
        if not content:
            continue

        print(f"{color_yellow}{content}{color_default}", end="", flush=True)
        fragments.append(content)

    print("")
    return "".join(fragments)


def req(conversation_history, url, parsed_args):
    """POST the conversation to *url* and stream the reply to stdout.

    The request is retried with exponentially growing pauses (starting at
    0.01s, doubling each time) until it succeeds or the pause would exceed
    32 seconds. While retrying, a spinner is drawn if stdout is a terminal.

    Args:
        conversation_history: list of message dicts sent as the chat history.
        url: full endpoint URL to POST to.
        parsed_args: namespace providing ``color`` and ``pid2kill``.

    Returns:
        The reply text produced by res(), or None if no connection could be
        established (in which case do_kills() has been called).
    """
    data = construct_request_data(conversation_history)
    json_data = json.dumps(data).encode("utf-8")
    headers = {
        "Content-Type": "application/json",
    }

    # Create a request
    request = urllib.request.Request(url, data=json_data, headers=headers, method="POST")

    # Send request
    delay = 0.01
    response = None
    for c in itertools.cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']):
        try:
            response = urllib.request.urlopen(request)
            break
        except Exception:
            # Deliberately broad: any failure is treated as "not ready yet"
            # and retried until the back-off budget is exhausted.
            if sys.stdout.isatty():
                print(f"\r{c}", end="", flush=True)

            if delay > 32:
                break

            time.sleep(delay)
            delay *= 2

    if response is not None:
        # Close the HTTP response once the stream is consumed (or if
        # streaming is interrupted) instead of leaking the connection.
        with response:
            return res(response, parsed_args.color)

    from ramalama.common import perror
    perror(f"\rError: could not connect to: {url}")
    do_kills(parsed_args)

    return None

class RamaLamaShell(cmd.Cmd):
    """Interactive prompt that forwards each input line to a chat endpoint.

    Every line that is not a recognised cmd.Cmd command is sent, together
    with the accumulated conversation history, to
    ``<host>/v1/chat/completions`` via req().

    NOTE(review): cmd.Cmd still dispatches lines starting with "help" or "?"
    to its built-in do_help, so such prompts never reach the model -- confirm
    whether that is intended.
    """

    def __init__(self, parsed_args):
        """Set up the prompt, history and endpoint URL from *parsed_args*.

        The LLAMA_PROMPT_PREFIX environment variable, when set, takes
        precedence over ``parsed_args.prefix`` for the prompt string.
        """
        super().__init__()
        self.conversation_history = []
        self.parsed_args = parsed_args
        # True while a request is in flight; read by the caller's
        # KeyboardInterrupt handler to decide what to print.
        self.request_in_process = False
        self.prompt = os.environ.get("LLAMA_PROMPT_PREFIX", parsed_args.prefix)
        self.url = f"{parsed_args.host}/v1/chat/completions"

    def do_EOF(self, user_content):
        """Leave the shell on end-of-file (Ctrl+D)."""
        print("")
        return True

    def emptyline(self):
        """Ignore empty input lines.

        The inherited cmd.Cmd behaviour repeats the last command, which here
        would silently re-send the previous prompt.
        """
        return None

    def default(self, user_content):
        """Send *user_content* as a user message and record the reply.

        Returns True (which stops the command loop) when the user typed
        "/bye" or "exit", or when req() reported a failure by returning None.
        """
        if user_content in ["/bye", "exit"]:
            return True

        self.conversation_history.append({"role": "user", "content": user_content})
        self.request_in_process = True
        response = req(self.conversation_history, self.url, self.parsed_args)
        # Only None signals failure; an empty reply ("") is a valid answer
        # and must not terminate the session.
        if response is None:
            return True

        self.conversation_history.append(
            {"role": "assistant", "content": response}
        )
        self.request_in_process = False

def do_kills(parsed_args):
    """Signal the process named by ``parsed_args.pid2kill``, if any.

    Sends SIGINT, SIGTERM and SIGKILL in that order. Does nothing when no
    pid was given. A process that no longer exists is not an error: this
    function can be reached more than once for the same pid.
    """
    pid = parsed_args.pid2kill
    # os.kill gives 0 and negative pids process-group / broadcast meaning;
    # never forward those.
    if not pid or pid < 0:
        return

    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGKILL):
        try:
            os.kill(pid, sig)
        except ProcessLookupError:
            # Target already exited; the remaining signals are pointless.
            return

def parse_arguments(args):
    """Parse the client's command-line arguments.

    Args:
        args: list of argument strings (without the program name).

    Returns:
        argparse.Namespace with ``color``, ``prefix``, ``host``, ``context``
        (int), ``jinja``, ``pid2kill`` (int or None), ``temp`` (float) and
        ``ARGS`` (list of str).
    """
    parser = argparse.ArgumentParser(description="Run ramalama client core")
    parser.add_argument(
        '--color',
        '--colour',
        default="auto",
        choices=['never', 'always', 'auto'],
        help='possible values are "never", "always" and "auto".',
    )
    parser.add_argument("--prefix", type=str, default="> ", help="prefix for the user prompt")
    parser.add_argument(
        "host", type=str, nargs="?", default="http://127.0.0.1:8080", help="the host to send requests to"
    )
    # type= keeps user-supplied values the same type as the numeric defaults
    # (without it argparse would hand back strings).
    parser.add_argument(
        "-c",
        "--ctx-size",
        dest="context",
        type=int,
        default=2048,
        help="size of the prompt context (0 = loaded from model)",
    )
    parser.add_argument("--jinja", action="store_true", help="enable jinja")
    parser.add_argument("--kill-server", dest="pid2kill", type=int, help="server process to kill on termination of client")
    parser.add_argument("--temp", type=float, default=0.8, help="temperature of the response from the AI model")
    parser.add_argument(
        "ARGS", nargs="*", help="overrides the default prompt, and the output is returned without entering the chatbot"
    )

    return parser.parse_args(args)

def handle_args(parsed_args, ramalama_shell):
    """Run a single one-shot prompt if one was given on the command line.

    When ``parsed_args.ARGS`` is non-empty, its words are joined with spaces,
    sent through the shell's default() handler, and do_kills() is invoked.

    Returns:
        True if a one-shot prompt was handled, False otherwise.
    """
    if not parsed_args.ARGS:
        return False

    prompt = " ".join(parsed_args.ARGS)
    ramalama_shell.default(prompt)
    do_kills(parsed_args)
    return True

def run_shell_loop(ramalama_shell):
    """Run the shell's command loop until it exits normally.

    A KeyboardInterrupt restarts the loop rather than ending the program. If
    the interrupt arrived while no request was in flight, a hint on how to
    quit is printed.
    """
    finished = False
    while not finished:
        # Reset before every (re)start so the handler below sees fresh state.
        ramalama_shell.request_in_process = False
        try:
            ramalama_shell.cmdloop()
        except KeyboardInterrupt:
            print("")
            idle = not ramalama_shell.request_in_process
            if idle:
                print("Use Ctrl + d or /bye or exit to quit.")
        else:
            finished = True

def main(args):
    """Entry point: run a one-shot prompt or the interactive shell.

    Args:
        args: command-line argument strings (without the program name).

    Returns:
        0 on both the one-shot and the interactive path.
    """
    # NOTE(review): presumably lets the lazy `ramalama.common` import in req()
    # resolve from the current directory -- confirm.
    sys.path.append('./')

    parsed_args = parse_arguments(args)
    ramalama_shell = RamaLamaShell(parsed_args)
    if handle_args(parsed_args, ramalama_shell):
        return 0

    run_shell_loop(ramalama_shell)
    do_kills(parsed_args)
    return 0


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv[1:]))
