// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "llamamodelwrapper.h"

#include "llama.h"
#include "common/common.h"

#include <QVariantHash>

#include <iostream>
#include <string>
#include <tuple>
#include <vector>

GLOBAL_USE_NAMESPACE

// Default construction performs no work of its own; the model is loaded
// later through initialize().
LlamaModelWrapper::LlamaModelWrapper() = default;

// Releases the llama context, the model and the parsed parameters, then
// resets every handle to nullptr.
LlamaModelWrapper::~LlamaModelWrapper()
{
    // The context is produced together with the model by
    // llama_init_from_gpt_params(); release it first so it never outlives
    // the model it refers to (same order as the llama.cpp examples).
    if (gCtx)
        llama_free(gCtx);
    gCtx = nullptr;

    if (gModel)
        llama_free_model(gModel);
    gModel = nullptr;

    // delete on a null pointer is a no-op, so no guard is needed.
    delete gParams;
    gParams = nullptr;
}

bool LlamaModelWrapper::initialize(const QString &bin, const QVariantHash &params)
{
    if (gModel)
        return false;
    gParams = new gpt_params;
    int argc = 1; // 从第二个开始
    char argv[128][128] = {0};
    char *ptr[128] = {0};
    for (auto it = params.begin(); it != params.end() && argc < 128; ++it) {
        auto str = it.key().toStdString();
        memcpy(argv[argc], str.c_str(), str.length());
        ptr[argc] = argv[argc];
        argc++;
        str = it.value().toString().toStdString();
        if (!str.empty()) {
            memcpy(argv[argc], str.c_str(), str.length());
            ptr[argc] = argv[argc];
            argc++;
        }
    }

    std::string sysInfo = llama_print_system_info();
    sysInfo += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | ";
    sysInfo += "CLAST = " + std::to_string(ggml_cpu_has_clblast()) + " | ";

    std::cerr << "system info: "<< sysInfo << std::endl;

    // todo covert params to argv
    if (!gpt_params_parse(argc, ptr, *gParams))
        return false;


    gParams->model = bin.toStdString();
    std::tie(gModel, gCtx) = llama_init_from_gpt_params(*gParams);
    return gModel;
}
