// SPDX-FileCopyrightText: 2024 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

#include "command.h"

#include "modelrepo.h"
#include "backendloader.h"
#include "modelinfo.h"
#include "inferenceplugin.h"
#include "llmproxy.h"
#include "embeddingproxy.h"
#include "modelserver.h"
#include "runtimestate.h"
#include "modelrunner.h"
#include "util.h"

#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonArray>
#include <QVariantHash>

#include <QDebug>

#include <cerrno>
#include <cstdio>
#include <cstring>
#include <iostream>

#include <signal.h>
#include <sys/types.h>

GLOBAL_USE_NAMESPACE

// Constructs the command-line front end and registers every supported option
// up front, so processCmd() can parse the arguments immediately.
Command::Command(QObject *parent)
    : QObject(parent)
    , QCommandLineParser()
{
    initOptions();
}

// Parses the command line of `app` and runs exactly one action.
// Actions are tested in this order and the first one set wins:
// list, embeddings, run, model, version (-v/--version), stop.
// With no action option, help is printed and the process exits with code 0.
// Returns the exit code of the selected action.
int Command::processCmd(const QCoreApplication &app)
{
    process(app);

    if (isSet("list"))
        return listHandler();
    if (isSet("embeddings"))
        return embeddings();
    if (isSet("run"))
        return runServer();
    if (isSet("model"))
        return llmGenerate();
    if (isSet("v") || isSet("version")) {
        std::cout << qApp->applicationVersion().toStdString() << std::endl;
        return 0;
    }
    if (isSet("stop"))
        return stopServer();

    // showHelp() terminates the process; the return only satisfies the signature.
    showHelp(0);
    return 0;
}

int Command::listHandler()
{
    bool info = isSet("info");
    auto target = value("list");
    QVariantHash vh;
    if (target == "model") {
        ModelRepo repo;
        auto models = repo.list();
        vh.insert("model", models);
    } else if (target == "backend"){
        BackendLoader loader;
        loader.readBackends();
        QStringList names;
        for (BackendMetaObjectPointer obj : loader.backends()) {
            names.append(obj->name());
        }
        vh.insert("backend", names);
    } else if (target == "server") {
        auto tmp = RuntimeState::listAll();
        if (info) {
            QVariantList infos;
            for (auto tvh :tmp)
                infos.append(tvh);
            vh.insert("serverinfo", infos);
        } else {
            QStringList names;
            for (auto v :tmp) {
                QString m = v.value("model").toString();
                if (!m.isEmpty())
                    names.append(m);
            }
            vh.insert("server", names);
        }
    } else {
        showHelp(1);
    }

    QJsonDocument doc(QJsonObject::fromVariantHash(vh));
    QString out = QString::fromUtf8(doc.toJson(QJsonDocument::Compact));

    std::cout << out.toStdString() << std::endl;
    return 0;
}

// Streaming callback handed to LLMProxy::generate() by llmGenerate(): writes
// each generated text fragment to stdout and flushes, so output appears as it
// is produced.
// `llm` is the user-data pointer passed to generate() (the LLMProxy itself);
// it is not needed here.
// Always returns true; presumably this tells the backend to keep generating.
// TODO confirm against LLMProxy.
bool Command::llmStreamOutput(const std::string &text, void *llm)
{
    (void)llm;
    // fwrite honours text.size(); printf("%s") would silently stop at the
    // first embedded NUL byte in a fragment.
    fwrite(text.data(), 1, text.size(), stdout);
    fflush(stdout);
    return true;
}

int Command::llmGenerate()
{
    QString modelName = value("model");
    QString prompt = value("prompt");

    if (modelName.isEmpty() || prompt.isEmpty()) {
        showHelp(1);
    }

    ModelRunner mr;
    int ret = loadModel(modelName, &mr);
    if (ret != 0)
        return ret;

    if (auto llm = dynamic_cast<LLMProxy *>(mr.modelProxy.data())) {
       auto tokens = llm->tokenize(prompt.toStdString());
       auto token = llm->generate(tokens, {}, llmStreamOutput, llm);
    } else {
        std::cerr << QString("%0 do not support to generate").arg(modelName).toStdString() << std::endl;
    }

    return 0;
}

int Command::embeddings()
{
    QString modelName = value("model");
    QString prompt = value("prompt");

    if (modelName.isEmpty() || prompt.isEmpty()) {
        showHelp(1);
    }

    ModelRunner mr;
    int ret = loadModel(modelName, &mr);
    if (ret != 0)
        return ret;
    if (auto emb = dynamic_cast<EmbeddingProxy *>(mr.modelProxy.data())) {
        std::list<std::vector<int32_t>> tokens = emb->tokenize({prompt.toStdString()});
        std::list<std::vector<float>> out = emb->embedding(tokens);

        // return json as openai embeddings api.
        QJsonObject root;
        QJsonArray arry;
        int i = 0;
        for (auto it = out.begin(); it != out.end(); ++it) {
            QJsonObject embObj;
            embObj.insert("object", "embedding");
            embObj.insert("index", i++);
            QJsonArray embValue;
            for (const float &v : *it)
                embValue.append(v);
            embObj.insert("embedding", embValue);
            arry.append(embObj);
        }

        root.insert("data", arry);
        root.insert("model", modelName);
        root.insert("object", "list");
        std::cout << QJsonDocument(root).toJson(QJsonDocument::Compact).toStdString() << std::endl;
    } else {
        std::cerr << QString("%0 do not support to embedding").arg(modelName).toStdString() << std::endl;
    }

    return 0;
}

int Command::runServer()
{
    QString modelName = value("run");
    if (modelName.isEmpty())
        showHelp(1);

    ModelRunner mr;
    int ret = loadModel(modelName, &mr);
    if (ret != 0)
        return ret;

    ModelServer serve;
    {
        serve.setHost(isSet("host") ? value("host") : "");
        bool ok = false;
        int port = isSet("port") ? value("port").toInt(&ok) : -1;
        if (ok)
            serve.setPort(port);

        int pid = serve.instance(modelName);
        if (pid > 0) {
            std::cerr << QString("model: %0 is runned as process: %1").arg(modelName).arg(pid).toStdString() << std::endl;
            return 1;
        }
    }

    if (isSet("exit-idle")) {
        serve.setIdle(value("exit-idle").toInt());
    }
    if (!serve.run(&mr))
        return 1;

    // enter event loop;
    return qApp->exec();
}

int Command::stopServer()
{
    QString modelName = value("stop");
    if (modelName.isEmpty())
        showHelp(1);
    RuntimeState rs(modelName);
    int pid = rs.pid();
    if (pid > 0) {
        std::cerr << "stop server " << modelName.toStdString() << " pid " << pid << std::endl;
        system(QString("kill -3 %0").arg(pid).toStdString().c_str());
    }
    return 0;
}

// Looks up `modelName` in the model repository and selects a backend and
// model format for it. Then it loads and initializes that backend, loads the
// model image, and stores the results in *runner.
// Returns 0 on success. Returns 1 on any failure, after printing a diagnostic
// to stderr; *runner is not modified in that case.
int Command::loadModel(const QString &modelName, ModelRunner *runner)
{
    ModelRepo repo;
    ModelInfoPointer info = repo.modelInfo(modelName);
    if (info.isNull()) {
        std::cerr << QString("no such model:%0").arg(modelName).toStdString() << std::endl;
        return 1;
    }

    BackendLoader bkLoader;
    bkLoader.readBackends();

    // perfect() returns the chosen backend and writes the chosen model format
    // into `format`; that format selects the image path and chat template below.
    QString format;
    BackendMetaObjectPointer bmo = bkLoader.perfect(info, &format);
    if (bmo.isNull()) {
        std::cerr << QString("no backend for model: %1").arg(modelName).toStdString() << std::endl;
        return 1;
    }

    auto backend = bkLoader.load(bmo);
    if (backend.isNull()) {
        std::cerr << QString("unable to load backend: %0").arg(bmo->name()).toStdString() << std::endl;
        return 1;
    }

    backend->initialize({});

    ModelProxy *mp = backend->loadModel(info->name(), info->imagePath(format), {});
    if (!mp) {
        // Multi-arg form: placeholders in one argument cannot be re-substituted
        // by the next (as happens with chained .arg() calls).
        std::cerr << QString("backend %0 unable to load model: %1 with format %2")
                     .arg(bmo->name(), modelName, format).toStdString() << std::endl;
        // Use the same failure code as every other error path (was -1, i.e. exit 255).
        return 1;
    }

    // Take ownership immediately so the proxy cannot leak.
    runner->modelProxy.reset(mp);
    runner->modelInfo = info;
    runner->modelFormat = format;
    runner->chatTmpl = info->chatTemplate(format);
    runner->backendmo = bmo;
    runner->backendIns = backend;

    return 0;
}

// Registers every command-line option understood by processCmd(), plus the
// standard -h/--help option.
void Command::initOptions()
{
    QCommandLineOption listOption("list", "List {model|backend|server} [--info]", "option");
    QCommandLineOption runOption("run", "Run a model", "model");
    QCommandLineOption hostOption("host", "Server host", "host");
    QCommandLineOption portOption("port", "Server port", "port");
    QCommandLineOption stopOption("stop", "Stop a running model", "model");
    QCommandLineOption modelOption("model", "Model name", "name");

    QCommandLineOption embeddingOption("embeddings", "Get embeddings for inputs");
    QCommandLineOption promptOption("prompt", "Input prompt for LLM", "json");
    QCommandLineOption infoOption("info", "Option for list");
    QCommandLineOption idleOption("exit-idle", "Exit the process when idle time is up. The minimum value is 10 seconds", "seconds");
    QCommandLineOption versionOption({"v","version"}, "Display version");

    // Registration order is kept stable because it is the order the options
    // are listed in --help.
    addOption(listOption);
    addOption(infoOption);
    addOption(runOption);
    addOption(hostOption);
    addOption(portOption);
    addOption(stopOption);

    addOption(modelOption);
    addOption(promptOption);

    addOption(embeddingOption);
    addOption(idleOption);
    addOption(versionOption);

    addHelpOption();
}
