// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "llamacppplugin.h"
#include "llamacppmodelconfig.h"
#include "llamallmproxy.h"
#include "llamaembproxy.h"

#include "modelinfo.h"

#include <QDir>
#include <QDebug>

#include "llama.h"
#include "common.h"

GLOBAL_USE_NAMESPACE

// Case-insensitive equality test between two architecture names.
// The whole expansion (and each argument) is parenthesized so the macro keeps
// its meaning when embedded in a larger expression such as
// `expected == COMPARE_ARCH(a, b)`; without the outer parentheses the trailing
// `== 0` would bind to the surrounding operators instead.
#define COMPARE_ARCH(src, trgt) (QString::compare((src), (trgt), Qt::CaseInsensitive) == 0)

// Log callback that forwards a ggml log message and its level to qDebug().
// The user-data pointer is not used.
static void hook_ggml_log_callback(enum ggml_log_level level, const char *text, void * /*user_data*/)
{
    auto logLine = qDebug();
    logLine << "level :";
    logLine << level;
    logLine << text;
}

// Constructs the plugin as a child of `parent`. No llama.cpp state is touched
// here; the backend is brought up later by initialize().
LlamacppPlugin::LlamacppPlugin(QObject *parent)
    : QObject(parent)
    , InferencePlugin()
{
}

// Releases the llama backend, but only if initialize() actually brought it up.
LlamacppPlugin::~LlamacppPlugin()
{
    if (!inited)
        return;

    llama_backend_free();
}

// Brings up the llama backend on the first call; later calls do nothing.
// `params` is currently not consulted. Always returns true.
bool LlamacppPlugin::initialize(const QVariantHash &params)
{
    if (!inited) {
        // Turn off llama.log output.
        log_disable();

        // Log redirection is intentionally left disabled for now:
        //llama_log_set(hook_ggml_log_callback, nullptr);

        llama_backend_init();

        // TODO: NUMA strategy is hard-coded to "disabled".
        llama_numa_init(GGML_NUMA_STRATEGY_DISABLED);

        inited = true;
    }

    return inited;
}

/**
 * @brief Loads the model stored in @p imgDir and returns a proxy for it.
 *
 * Reads "config.json" from @p imgDir, creates the wrapper that matches the
 * configured architectures and initializes it with the model binary named by
 * the configuration.
 *
 * @param name   Model name handed to the created wrapper.
 * @param imgDir Directory containing "config.json" and the model binary.
 * @param params Caller-supplied parameters; a value is used only for keys the
 *               model configuration does not already define.
 * @return The loaded model as a ModelProxy, or nullptr when no wrapper matches
 *         the configured architectures, when initialization fails, or when
 *         the wrapper cannot be viewed as a ModelProxy.
 */
ModelProxy *LlamacppPlugin::loadModel(const QString &name, const QString &imgDir, const QVariantHash &params)
{
    const QDir dir(imgDir);
    LlamacppModelConfig cfg(dir.absoluteFilePath("config.json"));

    // createModelWrapper() yields nullptr when none of the architectures is
    // recognized; bail out instead of dereferencing a null pointer below.
    LlamaModelWrapper *llamaModel = createModelWrapper(name, cfg.architectures());
    if (!llamaModel)
        return nullptr;

    // Values from the configuration win; caller parameters only fill the gaps.
    auto mParams = cfg.params();
    for (auto it = params.begin(); it != params.end(); ++it) {
        if (!mParams.contains(it.key()))
            mParams.insert(it.key(), it.value());
    }

    // Default the number of offloaded layers when GPU offload is available
    // and neither the configuration nor the caller specified it.
    if (llama_supports_gpu_offload() && !mParams.contains("--n-gpu-layers"))
        mParams.insert("--n-gpu-layers", "60"); // gpu

    if (!llamaModel->initialize(dir.absoluteFilePath(cfg.bin()), mParams)) {
        delete llamaModel;
        return nullptr;
    }

    // If the cast fails the wrapper would otherwise be unreachable and leak.
    ModelProxy *proxy = dynamic_cast<ModelProxy *>(llamaModel);
    if (!proxy)
        delete llamaModel;

    return proxy;
}

/**
 * @brief Creates the model wrapper that corresponds to the given architectures.
 *
 * Recognized architecture names (compared case-insensitively) are "LLM" and
 * "Embedding". When several entries are recognized, the last one in @p archs
 * determines the wrapper type.
 *
 * @param name  Model name passed to the wrapper's constructor.
 * @param archs Architecture names taken from the model configuration.
 * @return A newly allocated wrapper, or nullptr if no entry is recognized.
 */
LlamaModelWrapper *LlamacppPlugin::createModelWrapper(const QString &name, const QStringList &archs)
{
    // Scan from the back and stop at the first recognized entry. This keeps
    // the "last recognized entry wins" result while allocating exactly one
    // wrapper; allocating on every match would leak all but the last one.
    for (auto i = archs.size(); i-- > 0;) {
        const QString &arch = archs.at(i);
        if (COMPARE_ARCH(arch, "LLM"))
            return new LlamaLLMProxy(name.toStdString());
        if (COMPARE_ARCH(arch, "Embedding"))
            return new LlamaEmbProxy(name.toStdString());
    }

    return nullptr;
}
