#!/usr/bin/env python3

import argparse
import hashlib
import os
import sys
import uuid

import qdrant_client
from qdrant_client import models
import docling
from docling.chunking import HybridChunker
from docling.document_converter import DocumentConverter
from docling.datamodel.base_models import InputFormat

# Module-wide configuration.
# The dense and sparse embedding model names can be overridden through the
# environment; the collection name is fixed.
EMBED_MODEL = os.environ.get("EMBED_MODEL", "jinaai/jina-embeddings-v2-small-en")
SPARSE_MODEL = os.environ.get("SPARSE_MODEL", "prithivida/Splade_PP_en_v1")
COLLECTION_NAME = "rag"

class Converter:
    """A Class designed to handle all document conversions using Docling.

    Source paths are collected when the object is built (directories are
    expanded recursively). ``convert`` then converts the collected files,
    splits them into chunks and adds the chunk texts to the Qdrant
    collection named by ``COLLECTION_NAME``.
    """

    def __init__(self, output, targets):
        """Collect the source files and open the Qdrant database.

        Args:
            output: Value passed to ``QdrantClient(path=...)``.
            targets: Iterable of file or directory paths to ingest.
        """
        self.doc_converter = DocumentConverter()
        self.targets = []
        for target in targets:
            self.add(target)
        self.output = output
        self.client = qdrant_client.QdrantClient(path=output)
        self.client.set_model(EMBED_MODEL)
        self.client.set_sparse_model(SPARSE_MODEL)
        # optimizations to reduce ram: dense and sparse vectors are requested
        # with on_disk=True and dense vectors use INT8 scalar quantization.
        self.client.create_collection(
            collection_name=COLLECTION_NAME,
            vectors_config=self.client.get_fastembed_vector_params(on_disk=True),
            sparse_vectors_config=self.client.get_fastembed_sparse_vector_params(on_disk=True),
            quantization_config=models.ScalarQuantization(
                scalar=models.ScalarQuantizationConfig(
                    type=models.ScalarType.INT8,
                    always_ram=True,
                ),
            ),
        )

    def add(self, file_path):
        """Register ``file_path`` as a conversion target.

        A directory is walked recursively and every file found in it is
        registered; any other path is registered as given, without checking
        that it exists.
        """
        if os.path.isdir(file_path):
            self.walk(file_path)  # Walk directory and process all files
        else:
            self.targets.append(file_path)  # Process the single file

    def convert(self):
        """Convert all targets, chunk them and add the chunks to the collection.

        Returns:
            Whatever ``self.client.add`` returns for the uploaded chunks.
        """
        result = self.doc_converter.convert_all(self.targets)
        documents, ids = [], []
        chunker = HybridChunker(tokenizer=EMBED_MODEL, overlap=100, merge_peers=True)
        for file in result:
            for chunk in chunker.chunk(dl_doc=file.document):
                # Extract the enriched text from the chunk
                doc_text = chunker.contextualize(chunk=chunk)
                documents.append(doc_text)
                # The ID is derived from the text, so identical chunk texts
                # share the same ID.
                ids.append(self.generate_hash(doc_text))
        # NOTE(review): batch_size=1 presumably keeps peak memory low - confirm.
        return self.client.add(COLLECTION_NAME, documents=documents, ids=ids, batch_size=1)

    def walk(self, path):
        """Append every regular file found under ``path`` to ``self.targets``."""
        for root, _dirs, files in os.walk(path, topdown=True):
            for name in files:
                candidate = os.path.join(root, name)
                # Skip entries that are not regular files when checked.
                if os.path.isfile(candidate):
                    self.targets.append(candidate)

    def generate_hash(self, document: str) -> str:
        """Generate a deterministic UUID string for a document.

        The text is hashed with SHA-256 and the first 32 hex characters of
        the digest are formatted as a UUID.
        """
        sha256_hash = hashlib.sha256(document.encode('utf-8')).hexdigest()

        # Use the first 32 characters of the hash to create a UUID
        return str(uuid.UUID(sha256_hash[:32]))

def load():
    """Preload the models used by the converter.

    Sets the dense and sparse embedding models on a throwaway in-memory
    Qdrant client and initializes Docling's PDF pipeline. Nothing is stored
    or returned; presumably the point is to have the models fetched and
    cached ahead of a real run.
    """
    scratch_client = qdrant_client.QdrantClient(":memory:")
    scratch_client.set_model(EMBED_MODEL)
    scratch_client.set_sparse_model(SPARSE_MODEL)

    # The converter instance is only needed for this one initialization call.
    DocumentConverter().initialize_pipeline(InputFormat.PDF)

# Command-line interface: an optional target followed by any number of sources.
parser = argparse.ArgumentParser(prog="docling", description="process source files into RAG vector database")

# Zero or one value; None when omitted.
parser.add_argument(
    "target",
    nargs="?",
    help="Target database",
)
# Zero or more values; an empty list when omitted.
parser.add_argument(
    "source",
    nargs="*",
    help="Source files",
)

def perror(*args, **kwargs):
    """Print to standard error; all arguments are forwarded to ``print``."""
    print(*args, **kwargs, file=sys.stderr)


def eprint(e, exit_code):
    """Report ``e`` on standard error and terminate with ``exit_code``.

    The message is ``str(e)`` with any leading or trailing single or double
    quote characters removed, prefixed with ``Error: ``.
    """
    message = str(e).strip("'\"")
    perror(f"Error: {message}")
    raise SystemExit(exit_code)

# Script entry point: parse the command line, then either preload the models
# ("load" as target) or convert the given sources into the target database.
try:
    args = parser.parse_args()
    if args.target == "load":
        load()
    elif args.target is None:
        # "target" is declared with nargs="?", so it is None when the script
        # is run without arguments. Report a usage error (exit status 2)
        # instead of handing None to Converter as the database path.
        parser.error("the following arguments are required: target")
    else:
        converter = Converter(args.target, args.source)
        converter.convert()
except (docling.exceptions.ConversionError, FileNotFoundError, ValueError) as e:
    # Expected failures are reported as a single line and exit status 1.
    eprint(e, 1)
except KeyboardInterrupt:
    # Ctrl-C ends the run quietly, without a traceback.
    pass
