// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once

#include "Luau/Id.h"
#include "Luau/LanguageHash.h"
#include "Luau/Slice.h"
#include "Luau/Variant.h"

#include <algorithm>
#include <array>
#include <type_traits>
#include <unordered_set>
#include <utility>
#include <vector>

// Declares an e-node kind `name` with no payload and no operands (built on ::Luau::EqSat::Unit).
// `tag` holds the stringized kind name.
#define LUAU_EQSAT_UNIT(name) \
    struct name : public ::Luau::EqSat::Unit<name> \
    { \
        using Unit::Unit; \
        static constexpr const char* tag = #name; \
    }

// Declares an e-node kind `name` carrying a payload of type `valueType` and no operands
// (built on ::Luau::EqSat::Atom). `tag` holds the stringized kind name.
#define LUAU_EQSAT_ATOM(name, valueType) \
    struct name : public ::Luau::EqSat::Atom<name, valueType> \
    { \
        using Atom::Atom; \
        static constexpr const char* tag = #name; \
    }

// Declares an e-node kind `name` with exactly `arity` operands, stored in a std::array of Ids
// (built on ::Luau::EqSat::NodeVector). `tag` holds the stringized kind name.
#define LUAU_EQSAT_NODE_ARRAY(name, arity) \
    struct name : public ::Luau::EqSat::NodeVector<name, std::array<::Luau::EqSat::Id, arity>> \
    { \
        using NodeVector::NodeVector; \
        static constexpr const char* tag = #name; \
    }

// Declares an e-node kind `name` with an ordered, variable-length list of operand Ids
// (built on ::Luau::EqSat::NodeVector). `tag` holds the stringized kind name.
#define LUAU_EQSAT_NODE_VECTOR(name) \
    struct name : public ::Luau::EqSat::NodeVector<name, std::vector<::Luau::EqSat::Id>> \
    { \
        using NodeVector::NodeVector; \
        static constexpr const char* tag = #name; \
    }

// Declares an e-node kind `name` whose operand Ids form a sorted, duplicate-free set
// (built on ::Luau::EqSat::NodeSet). `tag` holds the stringized kind name.
#define LUAU_EQSAT_NODE_SET(name) \
    struct name : public ::Luau::EqSat::NodeSet<name, std::vector<::Luau::EqSat::Id>> \
    { \
        using NodeSet::NodeSet; \
        static constexpr const char* tag = #name; \
    }

// Declares an e-node kind `name` carrying a payload of type `valueType` plus an ordered list of
// operand Ids (built on ::Luau::EqSat::NodeAtomAndVector). `tag` holds the stringized kind name.
#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, valueType) \
    struct name : public ::Luau::EqSat::NodeAtomAndVector<name, valueType, std::vector<::Luau::EqSat::Id>> \
    { \
        using NodeAtomAndVector::NodeAtomAndVector; \
        static constexpr const char* tag = #name; \
    }

namespace Luau::EqSat
{

/// Base for e-node kinds that carry neither a payload nor operands.
/// All values of one Unit kind are interchangeable: they always compare equal and share a single hash.
template<typename Phantom>
struct Unit
{
    struct Hash
    {
        size_t operator()(const Unit&) const
        {
            // Every value of a unit kind is identical, so any fixed hash is valid; 4 is arbitrary.
            return 4;
        }
    };

    bool operator==(const Unit&) const
    {
        return true;
    }

    bool operator!=(const Unit& other) const
    {
        return !(*this == other);
    }

    /// Always empty: a unit node has no operands.
    Slice<const Id> operands() const
    {
        return {};
    }

    /// Always empty: a unit node has no operands to canonicalize.
    Slice<Id> mutableOperands()
    {
        return {};
    }
};

/// Base for leaf e-nodes that carry a payload of type T and have no operands.
/// Equality and hashing are determined entirely by the payload.
template<typename Phantom, typename T>
struct Atom
{
private:
    T _value;

public:
    /// Copies `payload` into the node. Intentionally implicit so a payload converts to its node kind.
    Atom(const T& payload)
        : _value(payload)
    {
    }

    /// The payload this node was constructed with.
    const T& value() const
    {
        return _value;
    }

    /// Always empty: an atom has no operands.
    Slice<const Id> operands() const
    {
        return {};
    }

    /// Always empty: an atom has no operands to canonicalize.
    Slice<Id> mutableOperands()
    {
        return {};
    }

    bool operator==(const Atom& rhs) const
    {
        return _value == rhs._value;
    }

    bool operator!=(const Atom& rhs) const
    {
        return !(_value == rhs._value);
    }

    /// Hashes the payload with languageHash.
    struct Hash
    {
        size_t operator()(const Atom& atom) const
        {
            return languageHash(atom._value);
        }
    };
};

/// Base for e-nodes that carry both a payload of type X and an operand container of type T.
/// T is used through data(), size(), operator[], operator== and languageHash (e.g. std::vector<Id>).
template<typename Phantom, typename X, typename T>
struct NodeAtomAndVector
{
    /// Copies `payload`; all remaining arguments brace-initialize the operand container.
    template<typename... Args>
    NodeAtomAndVector(const X& payload, Args&&... args)
        : _value(payload)
        , vector{std::forward<Args>(args)...}
    {
    }

    /// The payload this node was constructed with.
    const X& value() const
    {
        return _value;
    }

    /// The i-th operand; this wrapper performs no bounds check.
    Id operator[](size_t i) const
    {
        return vector[i];
    }

    Slice<const Id> operands() const
    {
        return Slice{vector.data(), vector.size()};
    }

    /// Writable view over the operands; meant for canonicalization.
    Slice<Id> mutableOperands()
    {
        return Slice{vector.data(), vector.size()};
    }

    /// Two nodes are equal only if both their payloads and their operand lists match.
    bool operator==(const NodeAtomAndVector& rhs) const
    {
        return _value == rhs._value && vector == rhs.vector;
    }

    bool operator!=(const NodeAtomAndVector& rhs) const
    {
        return !(*this == rhs);
    }

    /// Mixes the payload hash with the operand-container hash.
    struct Hash
    {
        size_t operator()(const NodeAtomAndVector& node) const
        {
            size_t seed = languageHash(node._value);
            hashCombine(seed, languageHash(node.vector));
            return seed;
        }
    };

private:
    X _value;
    T vector;
};

/// Base for e-nodes whose operands form an ordered sequence of Ids held in a container of type T
/// (std::vector<Id> or std::array<Id, N> via the declaration macros).
/// Equality and hashing depend on the operands in their stored order.
template<typename Phantom, typename T>
struct NodeVector
{
    /// Brace-initializes the operand container from `args`.
    template<typename... Args>
    NodeVector(Args&&... args)
        : vector{std::forward<Args>(args)...}
    {
    }

    /// The i-th operand; this wrapper performs no bounds check.
    Id operator[](size_t i) const
    {
        return vector[i];
    }

    Slice<const Id> operands() const
    {
        return Slice{vector.data(), vector.size()};
    }

    /// Writable view over the operands; meant for canonicalization.
    Slice<Id> mutableOperands()
    {
        return Slice{vector.data(), vector.size()};
    }

    bool operator==(const NodeVector& rhs) const
    {
        return vector == rhs.vector;
    }

    bool operator!=(const NodeVector& rhs) const
    {
        return !(vector == rhs.vector);
    }

    /// Hashes the whole operand container with languageHash.
    struct Hash
    {
        size_t operator()(const NodeVector& node) const
        {
            return languageHash(node.vector);
        }
    };

private:
    T vector;
};

/// Base for e-nodes whose operands form an unordered set of Ids.
/// The constructor keeps the members sorted and duplicate-free, so nodes built from the same Ids
/// in any order (or with repeats) compare and hash equal.
template<typename Phantom, typename T>
struct NodeSet
{
    // Canonicalization rewrites the members in place and must restore the invariant, so it needs `vector`.
    template<typename P_, typename T_, typename Find>
    friend void canonicalize(NodeSet<P_, T_>& node, Find&& find);

    /// Brace-initializes the member container from `args`, then sorts it and drops duplicate Ids.
    template<typename... Args>
    NodeSet(Args&&... args)
        : vector{std::forward<Args>(args)...}
    {
        std::sort(begin(vector), end(vector));
        vector.erase(std::unique(begin(vector), end(vector)), end(vector));
    }

    /// The i-th member in sorted order; this wrapper performs no bounds check.
    Id operator[](size_t i) const
    {
        return vector[i];
    }

    Slice<const Id> operands() const
    {
        return Slice{vector.data(), vector.size()};
    }

    /// Writable view over the members. Writing through it can break the sorted/unique invariant;
    /// the NodeSet overload of canonicalize re-establishes it.
    Slice<Id> mutableOperands()
    {
        return Slice{vector.data(), vector.size()};
    }

    bool operator==(const NodeSet& rhs) const
    {
        return vector == rhs.vector;
    }

    bool operator!=(const NodeSet& rhs) const
    {
        return !(vector == rhs.vector);
    }

    /// Hashes the (sorted) member container with languageHash.
    struct Hash
    {
        size_t operator()(const NodeSet& node) const
        {
            return languageHash(node.vector);
        }
    };

protected:
    T vector;
};

template<typename... Ts>
struct Language final
{
    using VariantTy = Luau::Variant<Ts...>;

    template<typename T>
    using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;

    template<typename Find, typename... Vs>
    friend void canonicalize(Language<Vs...>& enode, Find&& find);

    template<typename T>
    Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept
        : v(std::forward<T>(t))
    {
    }

    int index() const noexcept
    {
        return v.index();
    }

    /// This should only be used in canonicalization!
    /// Always prefer operands()
    Slice<Id> mutableOperands() noexcept
    {
        return visit(
            [](auto&& v) -> Slice<Id>
            {
                return v.mutableOperands();
            },
            v
        );
    }

    Slice<const Id> operands() const noexcept
    {
        return visit(
            [](auto&& v) -> Slice<const Id>
            {
                return v.operands();
            },
            v
        );
    }

    template<typename T>
    T* get() noexcept
    {
        static_assert(WithinDomain<T>::value);
        return v.template get_if<T>();
    }

    template<typename T>
    const T* get() const noexcept
    {
        static_assert(WithinDomain<T>::value);
        return v.template get_if<T>();
    }

    bool operator==(const Language& rhs) const noexcept
    {
        return v == rhs.v;
    }

    bool operator!=(const Language& rhs) const noexcept
    {
        return !(*this == rhs);
    }

public:
    struct Hash
    {
        size_t operator()(const Language& language) const
        {
            size_t seed = std::hash<int>{}(language.index());
            hashCombine(
                seed,
                visit(
                    [](auto&& v)
                    {
                        return typename std::decay_t<decltype(v)>::Hash{}(v);
                    },
                    language.v
                )
            );
            return seed;
        }
    };

private:
    VariantTy v;
};

template<typename Node, typename Find>
void canonicalize(Node& node, Find&& find)
{
    // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
    // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
    for (Id& id : node.mutableOperands())
        id = find(id);
}

// Canonicalizing the Ids in a NodeSet may result in the set decreasing in size.
template<typename Phantom, typename T, typename Find>
void canonicalize(NodeSet<Phantom, T>& node, Find&& find)
{
    for (Id& id : node.vector)
        id = find(id);

    std::sort(begin(node.vector), end(node.vector));
    auto endIt = std::unique(begin(node.vector), end(node.vector));
    node.vector.erase(endIt, end(node.vector));
}

template<typename Find, typename... Vs>
void canonicalize(Language<Vs...>& enode, Find&& find)
{
    visit(
        [&](auto&& v)
        {
            Luau::EqSat::canonicalize(v, find);
        },
        enode.v
    );
}

} // namespace Luau::EqSat