// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "luau.pb.h"

static const std::string kNames[] = {
    "_G",           "_VERSION",   "__add",      "__call",         "__concat",    "__div",        "__eq",      "__idiv",     "__index",
    "__iter",       "__le",       "__len",      "__lt",           "__mod",       "__mode",       "__mul",     "__namecall", "__newindex",
    "__pow",        "__sub",      "__type",     "__unm",          "abs",         "acos",         "arshift",   "asin",       "assert",
    "atan",         "atan2",      "band",       "bit32",          "bnot",        "boolean",      "bor",       "btest",      "buffer",
    "bxor",         "byte",       "ceil",       "char",           "charpattern", "clamp",        "clear",     "clock",      "clone",
    "close",        "codepoint",  "codes",      "collectgarbage", "concat",      "copy",         "coroutine", "cos",        "cosh",
    "countlz",      "countrz",    "create",     "date",           "debug",       "deg",          "difftime",  "error",      "exp",
    "extract",      "fill",       "find",       "floor",          "fmod",        "foreach",      "foreachi",  "format",     "freeze",
    "frexp",        "fromstring", "function",   "gcinfo",         "getfenv",     "getmetatable", "getn",      "gmatch",     "gsub",
    "huge",         "info",       "insert",     "ipairs",         "isfrozen",    "isyieldable",  "ldexp",     "len",        "loadstring",
    "log",          "log10",      "lower",      "lrotate",        "lshift",      "match",        "math",      "max",        "maxn",
    "min",          "modf",       "move",       "newproxy",       "next",        "nil",          "noise",     "number",     "offset",
    "os",           "pack",       "packsize",   "pairs",          "pcall",       "pi",           "pow",       "print",      "rad",
    "random",       "randomseed", "rawequal",   "rawget",         "rawlen",      "rawset",       "readf32",   "readf64",    "readi16",
    "readi32",      "readi8",     "readstring", "readu16",        "readu32",     "readu8",       "remove",    "rep",        "replace",
    "require",      "resume",     "reverse",    "round",          "rrotate",     "rshift",       "running",   "select",     "setfenv",
    "setmetatable", "sign",       "sin",        "sinh",           "sort",        "split",        "sqrt",      "status",     "string",
    "sub",          "table",      "tan",        "tanh",           "thread",      "time",         "tonumber",  "tostring",   "tostring",
    "traceback",    "type",       "typeof",     "unpack",         "upper",       "userdata",     "utf8",      "vector",     "wrap",
    "writef32",     "writef64",   "writei16",   "writei32",       "writei8",     "writestring",  "writeu16",  "writeu32",   "writeu8",
    "xpcall",       "yield",
};

static const std::string kTypes[] = {
    "any",
    "boolean",
    "buffer",
    "nil",
    "number",
    "string",
    "thread",
    "vector",
};

static const std::string kClasses[] = {
    "Vector3",
    "Instance",
    "Part",
};

struct ProtoToLuau
{
    struct Function
    {
        int loops = 0;
        bool vararg = false;
    };

    std::string source;
    std::vector<Function> functions;
    bool types = false;

    ProtoToLuau()
    {
        Function top = {};
        top.vararg = true;
        functions.push_back(top);
    }

    void ident(const luau::Name& name)
    {
        if (name.has_builtin())
        {
            size_t index = size_t(name.builtin()) % std::size(kNames);
            source += kNames[index];
        }
        else if (name.has_custom())
        {
            source += 'n';
            source += std::to_string(name.custom() & 0xff);
        }
        else
        {
            source += '_';
        }
    }

    void ident(const luau::Typename& name)
    {
        source += 't';
        source += std::to_string(name.index() & 0xff);
    }

    template<typename T>
    void genericidents(const T& node)
    {
        if (node.generics_size() || node.genericpacks_size())
        {
            source += '<';
            bool first = true;

            for (size_t i = 0; i < node.generics_size(); ++i)
            {
                if (!first)
                    source += ',';
                first = false;
                ident(node.generics(i));
            }

            for (size_t i = 0; i < node.genericpacks_size(); ++i)
            {
                if (!first)
                    source += ',';
                first = false;
                ident(node.genericpacks(i));
                source += "...";
            }

            source += '>';
        }
    }

    void print(const luau::Expr& expr)
    {
        if (expr.has_group())
            print(expr.group());
        else if (expr.has_nil())
            print(expr.nil());
        else if (expr.has_bool_())
            print(expr.bool_());
        else if (expr.has_number())
            print(expr.number());
        else if (expr.has_string())
            print(expr.string());
        else if (expr.has_local())
            print(expr.local());
        else if (expr.has_global())
            print(expr.global());
        else if (expr.has_varargs())
            print(expr.varargs());
        else if (expr.has_call())
            print(expr.call());
        else if (expr.has_index_name())
            print(expr.index_name());
        else if (expr.has_index_expr())
            print(expr.index_expr());
        else if (expr.has_function())
            print(expr.function());
        else if (expr.has_table())
            print(expr.table());
        else if (expr.has_unary())
            print(expr.unary());
        else if (expr.has_binary())
            print(expr.binary());
        else if (expr.has_ifelse())
            print(expr.ifelse());
        else if (expr.has_interpstring())
            print(expr.interpstring());
        else
            source += "_";
    }

    void print(const luau::ExprPrefix& expr)
    {
        if (expr.has_group())
            print(expr.group());
        else if (expr.has_local())
            print(expr.local());
        else if (expr.has_global())
            print(expr.global());
        else if (expr.has_call())
            print(expr.call());
        else if (expr.has_index_name())
            print(expr.index_name());
        else if (expr.has_index_expr())
            print(expr.index_expr());
        else
            source += "_";
    }

    void print(const luau::ExprGroup& expr)
    {
        source += '(';
        print(expr.expr());
        source += ')';
    }

    void print(const luau::ExprConstantNil& expr)
    {
        source += "nil";
    }

    void print(const luau::ExprConstantBool& expr)
    {
        source += expr.val() ? "true" : "false";
    }

    void print(const luau::ExprConstantNumber& expr)
    {
        source += std::to_string(expr.val());
    }

    void print(const luau::ExprConstantString& expr)
    {
        source += '"';
        for (char ch : expr.val())
            if (isalpha(ch))
                source += ch;
        source += '"';
    }

    void print(const luau::Local& var)
    {
        source += 'l';
        source += std::to_string(var.name() & 0xff);
    }

    void print(const luau::ExprLocal& expr)
    {
        print(expr.var());
    }

    void print(const luau::ExprGlobal& expr)
    {
        ident(expr.name());
    }

    void print(const luau::ExprVarargs& expr)
    {
        if (functions.back().vararg)
            source += "...";
        else
            source += "_";
    }

    void print(const luau::ExprCall& expr)
    {
        if (expr.func().has_index_name())
            print(expr.func().index_name(), expr.self());
        else
            print(expr.func());
        source += '(';
        for (int i = 0; i < expr.args_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(expr.args(i));
        }
        source += ')';
    }

    void print(const luau::ExprIndexName& expr, bool self = false)
    {
        print(expr.expr());
        source += self ? ':' : '.';
        ident(expr.index());
    }

    void print(const luau::ExprIndexExpr& expr)
    {
        print(expr.expr());
        source += '[';
        print(expr.index());
        source += ']';
    }

    void function(const luau::ExprFunction& expr)
    {
        genericidents(expr);
        source += "(";
        for (int i = 0; i < expr.args_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(expr.args(i));

            if (types && i < expr.types_size())
            {
                source += ':';
                print(expr.types(i));
            }
        }
        if (expr.vararg())
        {
            if (expr.args_size())
                source += ',';
            source += "...";
        }
        source += ')';
        if (types && expr.rettypes_size())
        {
            source += ':';
            if (expr.rettypes_size() > 1)
                source += '(';
            for (size_t i = 0; i < expr.rettypes_size(); ++i)
            {
                if (i != 0)
                    source += ',';
                print(expr.rettypes(i));
            }
            if (expr.rettypes_size() > 1)
                source += ')';
        }
        source += '\n';

        Function func = {};
        func.vararg = expr.vararg();
        functions.push_back(func);

        print(expr.body());

        functions.pop_back();

        source += "end";
    }

    void print(const luau::ExprFunction& expr)
    {
        source += "function";
        function(expr);
    }

    void print(const luau::ExprTable& expr)
    {
        source += '{';
        for (int i = 0; i < expr.items_size(); ++i)
        {
            if (expr.items(i).has_key_name())
            {
                ident(expr.items(i).key_name());
                source += '=';
            }
            else if (expr.items(i).has_key_expr())
            {
                source += "[";
                print(expr.items(i).key_expr());
                source += "]=";
            }

            print(expr.items(i).value());
            source += ',';
        }
        source += '}';
    }

    void print(const luau::ExprUnary& expr)
    {
        if (expr.op() == luau::ExprUnary::Not)
            source += "not ";
        else if (expr.op() == luau::ExprUnary::Minus)
            source += "- ";
        else if (expr.op() == luau::ExprUnary::Len)
            source += "# ";

        print(expr.expr());
    }

    void print(const luau::ExprBinary& expr)
    {
        print(expr.left());

        if (expr.op() == luau::ExprBinary::Add)
            source += " + ";
        else if (expr.op() == luau::ExprBinary::Sub)
            source += " - ";
        else if (expr.op() == luau::ExprBinary::Mul)
            source += " * ";
        else if (expr.op() == luau::ExprBinary::Div)
            source += " / ";
        else if (expr.op() == luau::ExprBinary::FloorDiv)
            source += " // ";
        else if (expr.op() == luau::ExprBinary::Mod)
            source += " % ";
        else if (expr.op() == luau::ExprBinary::Pow)
            source += " ^ ";
        else if (expr.op() == luau::ExprBinary::Concat)
            source += " .. ";
        else if (expr.op() == luau::ExprBinary::CompareNe)
            source += " ~= ";
        else if (expr.op() == luau::ExprBinary::CompareEq)
            source += " == ";
        else if (expr.op() == luau::ExprBinary::CompareLt)
            source += " < ";
        else if (expr.op() == luau::ExprBinary::CompareLe)
            source += " <= ";
        else if (expr.op() == luau::ExprBinary::CompareGt)
            source += " > ";
        else if (expr.op() == luau::ExprBinary::CompareGe)
            source += " >= ";
        else if (expr.op() == luau::ExprBinary::And)
            source += " and ";
        else if (expr.op() == luau::ExprBinary::Or)
            source += " or ";

        print(expr.right());
    }

    void print(const luau::ExprIfElse& expr)
    {
        source += "if ";
        print(expr.cond());
        source += " then ";
        print(expr.then());

        if (expr.has_else_())
        {
            source += " else ";
            print(expr.else_());
        }
        else if (expr.has_elseif())
        {
            source += " else";
            print(expr.elseif());
        }
    }

    void print(const luau::ExprInterpString& expr)
    {
        source += "`";

        for (int i = 0; i < expr.parts_size(); ++i)
        {
            if (expr.parts(i).has_string())
            {
                // String literal is added surrounded with "", but that's ok
                print(expr.parts(i));
            }
            else
            {
                source += "{";
                print(expr.parts(i));
                source += "}";
            }
        }

        source += "`";
    }

    void print(const luau::LValue& expr)
    {
        if (expr.has_local())
            print(expr.local());
        else if (expr.has_global())
            print(expr.global());
        else if (expr.has_index_name())
            print(expr.index_name());
        else if (expr.has_index_expr())
            print(expr.index_expr());
        else
            source += "_";
    }

    void print(const luau::Stat& stat)
    {
        if (stat.has_block())
            print(stat.block());
        else if (stat.has_if_())
            print(stat.if_());
        else if (stat.has_while_())
            print(stat.while_());
        else if (stat.has_repeat())
            print(stat.repeat());
        else if (stat.has_break_())
            print(stat.break_());
        else if (stat.has_continue_())
            print(stat.continue_());
        else if (stat.has_return_())
            print(stat.return_());
        else if (stat.has_call())
            print(stat.call());
        else if (stat.has_local())
            print(stat.local());
        else if (stat.has_for_())
            print(stat.for_());
        else if (stat.has_for_in())
            print(stat.for_in());
        else if (stat.has_assign())
            print(stat.assign());
        else if (stat.has_compound_assign())
            print(stat.compound_assign());
        else if (stat.has_function())
            print(stat.function());
        else if (stat.has_local_function())
            print(stat.local_function());
        else if (stat.has_type_alias())
            print(stat.type_alias());
        else if (stat.has_require_into_local())
            print(stat.require_into_local());
        else
            source += "do end\n";
    }

    void print(const luau::StatBlock& stat)
    {
        for (int i = 0; i < stat.body_size(); ++i)
        {
            if (stat.body(i).has_block())
            {
                source += "do\n";
                print(stat.body(i));
                source += "end\n";
            }
            else
            {
                print(stat.body(i));

                // parser will reject code with break/continue/return being non-trailing statements in a block
                if (stat.body(i).has_break_() || stat.body(i).has_continue_() || stat.body(i).has_return_())
                    break;
            }
        }
    }

    void print(const luau::StatIf& stat)
    {
        source += "if ";
        print(stat.cond());
        source += " then\n";
        print(stat.then());

        if (stat.has_else_())
        {
            source += "else\n";
            print(stat.else_());
            source += "end\n";
        }
        else if (stat.has_elseif())
        {
            source += "else";
            print(stat.elseif());
        }
        else
        {
            source += "end\n";
        }
    }

    void print(const luau::StatWhile& stat)
    {
        source += "while ";
        print(stat.cond());
        source += " do\n";

        functions.back().loops++;
        print(stat.body());
        functions.back().loops--;

        source += "end\n";
    }

    void print(const luau::StatRepeat& stat)
    {
        source += "repeat\n";

        functions.back().loops++;
        print(stat.body());
        functions.back().loops--;

        source += "until ";
        print(stat.cond());
        source += "\n";
    }

    void print(const luau::StatBreak& stat)
    {
        if (functions.back().loops)
            source += "break\n";
        else
            source += "do end\n";
    }

    void print(const luau::StatContinue& stat)
    {
        if (functions.back().loops)
            source += "continue\n";
        else
            source += "do end\n";
    }

    void print(const luau::StatReturn& stat)
    {
        source += "return ";
        for (int i = 0; i < stat.list_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.list(i));
        }
        source += "\n";
    }

    void print(const luau::StatCall& stat)
    {
        print(stat.expr());
        source += '\n';
    }

    void print(const luau::StatLocal& stat)
    {
        source += "local ";

        if (stat.vars_size() == 0)
            source += '_';

        for (int i = 0; i < stat.vars_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.vars(i));

            if (types && i < stat.types_size())
            {
                source += ':';
                print(stat.types(i));
            }
        }

        if (stat.values_size() != 0)
            source += " = ";

        for (int i = 0; i < stat.values_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.values(i));
        }
        source += '\n';
    }

    void print(const luau::StatFor& stat)
    {
        source += "for ";
        print(stat.var());
        source += '=';
        print(stat.from());
        source += ',';
        print(stat.to());
        if (stat.has_step())
        {
            source += ',';
            print(stat.step());
        }
        source += " do\n";

        functions.back().loops++;
        print(stat.body());
        functions.back().loops--;

        source += "end\n";
    }

    void print(const luau::StatForIn& stat)
    {
        source += "for ";

        if (stat.vars_size() == 0)
            source += '_';

        for (int i = 0; i < stat.vars_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.vars(i));
        }

        source += " in ";

        if (stat.values_size() == 0)
            source += "...";

        for (int i = 0; i < stat.values_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.values(i));
        }

        source += " do\n";

        functions.back().loops++;
        print(stat.body());
        functions.back().loops--;

        source += "end\n";
    }

    void print(const luau::StatAssign& stat)
    {
        if (stat.vars_size() == 0)
            source += '_';

        for (int i = 0; i < stat.vars_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.vars(i));
        }

        source += " = ";

        if (stat.values_size() == 0)
            source += "nil";

        for (int i = 0; i < stat.values_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(stat.values(i));
        }
        source += '\n';
    }

    void print(const luau::StatCompoundAssign& stat)
    {
        print(stat.var());

        if (stat.op() == luau::StatCompoundAssign::Add)
            source += " += ";
        else if (stat.op() == luau::StatCompoundAssign::Sub)
            source += " -= ";
        else if (stat.op() == luau::StatCompoundAssign::Mul)
            source += " *= ";
        else if (stat.op() == luau::StatCompoundAssign::Div)
            source += " /= ";
        else if (stat.op() == luau::StatCompoundAssign::Mod)
            source += " %= ";
        else if (stat.op() == luau::StatCompoundAssign::Pow)
            source += " ^= ";
        else if (stat.op() == luau::StatCompoundAssign::Concat)
            source += " ..= ";

        print(stat.value());
        source += '\n';
    }

    void print(const luau::StatFunction& stat)
    {
        source += "function ";
        if (stat.var().has_index_name())
            print(stat.var().index_name(), stat.self());
        else if (stat.var().has_index_expr())
            source += '_'; // function foo[bar]() is invalid syntax
        else
            print(stat.var());
        function(stat.func());
        source += '\n';
    }

    void print(const luau::StatLocalFunction& stat)
    {
        source += "local function ";
        print(stat.var());
        function(stat.func());
        source += '\n';
    }

    void print(const luau::StatTypeAlias& stat)
    {
        if (stat.export_())
            source += "export ";

        source += "type ";
        ident(stat.name());
        genericidents(stat);
        source += " = ";
        print(stat.type());
        source += '\n';
    }

    void print(const luau::StatRequireIntoLocalHelper& stat)
    {
        source += "local ";
        print(stat.var());
        source += " = require(module" + std::to_string(stat.modulenum() % 2) + ")\n";
    }

    void print(const luau::Type& type)
    {
        if (type.has_primitive())
            print(type.primitive());
        else if (type.has_literal())
            print(type.literal());
        else if (type.has_table())
            print(type.table());
        else if (type.has_function())
            print(type.function());
        else if (type.has_typeof())
            print(type.typeof());
        else if (type.has_union_())
            print(type.union_());
        else if (type.has_intersection())
            print(type.intersection());
        else if (type.has_class_())
            print(type.class_());
        else if (type.has_ref())
            print(type.ref());
        else if (type.has_boolean())
            print(type.boolean());
        else if (type.has_string())
            print(type.string());
        else
            source += "any";
    }

    void print(const luau::TypePrimitive& type)
    {
        size_t index = size_t(type.kind()) % std::size(kTypes);
        source += kTypes[index];
    }

    void print(const luau::TypeLiteral& type)
    {
        ident(type.name());

        if (type.generics_size() || type.genericpacks_size())
        {
            source += '<';
            bool first = true;

            for (size_t i = 0; i < type.generics_size(); ++i)
            {
                if (!first)
                    source += ',';
                first = false;
                print(type.generics(i));
            }

            for (size_t i = 0; i < type.genericpacks_size(); ++i)
            {
                if (!first)
                    source += ',';
                first = false;
                ident(type.genericpacks(i));
                source += "...";
            }

            source += '>';
        }
    }

    void print(const luau::TypeTable& type)
    {
        source += '{';
        for (size_t i = 0; i < type.items_size(); ++i)
        {
            ident(type.items(i).key());
            source += ':';
            print(type.items(i).type());
            source += ',';
        }
        if (type.has_indexer())
        {
            source += '[';
            print(type.indexer().key());
            source += "]:";
            print(type.indexer().value());
        }
        source += '}';
    }

    void print(const luau::TypeFunction& type)
    {
        genericidents(type);
        source += '(';
        for (size_t i = 0; i < type.args_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(type.args(i));
        }
        source += ")->";
        if (type.rets_size() != 1)
            source += '(';
        for (size_t i = 0; i < type.rets_size(); ++i)
        {
            if (i != 0)
                source += ',';
            print(type.rets(i));
        }
        if (type.rets_size() != 1)
            source += ')';
    }

    void print(const luau::TypeTypeof& type)
    {
        source += "typeof(";
        print(type.expr());
        source += ')';
    }

    void print(const luau::TypeUnion& type)
    {
        source += '(';
        print(type.left());
        source += ")|(";
        print(type.right());
        source += ')';
    }

    void print(const luau::TypeIntersection& type)
    {
        source += '(';
        print(type.left());
        source += ")&(";
        print(type.right());
        source += ')';
    }

    void print(const luau::TypeClass& type)
    {
        size_t index = size_t(type.kind()) % std::size(kClasses);
        source += kClasses[index];
    }

    void print(const luau::TypeRef& type)
    {
        print(type.prefix());
        source += '.';
        ident(type.index());
    }

    void print(const luau::TypeBoolean& type)
    {
        source += type.val() ? "true" : "false";
    }

    void print(const luau::TypeString& type)
    {
        source += '"';
        for (char ch : type.val())
            if (isgraph(ch))
                source += ch;
        source += '"';
    }
};

std::vector<std::string> protoprint(const luau::ModuleSet& stat, bool types)
{
    std::vector<std::string> result;

    if (stat.has_module())
    {
        ProtoToLuau printer;
        printer.types = types;
        printer.print(stat.module());
        result.push_back(printer.source);
    }

    ProtoToLuau printer;
    printer.types = types;
    printer.print(stat.program());
    result.push_back(printer.source);

    return result;
}