// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Id.h" #include #include #include #include namespace Luau::EqSat { #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = #name; \ } #define LUAU_EQSAT_FIELD(name) \ struct name : public ::Luau::EqSat::Field \ { \ } #define LUAU_EQSAT_UNARY_NODE(name, field) \ struct name : public ::Luau::EqSat::Node \ { \ static constexpr const char* tag = #name; \ using Base::Node; \ } #define LUAU_EQSAT_BINARY_NODE(name, field1, field2) \ struct name : public ::Luau::EqSat::Node \ { \ static constexpr const char* tag = #name; \ using Base::Node; \ } #define DERIVE_EQ(name, field) \ bool operator==(const name& rhs) const \ { \ return field == rhs.field; \ } \ bool operator!=(const name& rhs) const \ { \ return !(*this == rhs); \ } template struct Atom { T value; DERIVE_EQ(Atom, value); struct Hash { size_t operator()(const Atom& atom) const { return std::hash{}(atom.value); } }; }; /// Empty base class just for static_asserts. struct FieldBase { }; template struct Field : FieldBase { Id id; Field(Id id) : id(id) { } DERIVE_EQ(Field, id); struct Hash { size_t operator()(const Field& field) const { return std::hash{}(field.id); } }; }; template class Node { static_assert(std::conjunction...>::value); std::array array; template static constexpr int getIndex() { using TT = std::decay_t; constexpr int N = sizeof...(Fields); constexpr bool is[N] = {std::is_same_v...}; for (int i = 0; i < N; ++i) if (is[i]) return i; return -1; } public: using Base = Node; template Node(Args&&... args) : array{std::forward(args)...} { } template Id field() const { static_assert(std::disjunction_v, Fields>...>); return array[getIndex()]; } DERIVE_EQ(Node, array); struct Hash { size_t operator()(const Node& node) const { return 0; } }; }; #undef DERIVE_EQ // `Language` is very similar to `Luau::Variant` with enough differences warranting a different type altogether. // // Firstly, where `Luau::Variant` uses an `int` to decide which type the variant currently holds, we use // a `const char*` instead. We use the pointer address for tag checking, and the string buffer for stringification. // // Secondly, we need `Language` to have additional methods such as: // - `children()` to get child operands, // - `operator==` to decide equality, and // - `hash()` function. // // And finally, each `T` in `Ts` have additional requirements which `Luau::Variant` doesn't need. template class Language { const char* tag; char buffer[std::max({sizeof(Ts)...})]; private: template using WithinDomain = std::disjunction, Ts>...>; using FnCopy = void (*)(void*, const void*); using FnMove = void (*)(void*, void*); using FnDtor = void (*)(void*); using FnPred = bool (*)(const void*, const void*); using FnHash = size_t (*)(const void*); template static void fnCopy(void* dst, const void* src) { new (dst) T(*static_cast(src)); } template static void fnMove(void* dst, void* src) { new (dst) T(static_cast(*static_cast(src))); } template static void fnDtor(void* dst) { static_cast(dst)->~T(); } template static bool fnPred(const void* lhs, const void* rhs) { return *static_cast(lhs) == *static_cast(rhs); } template static size_t fnHash(const void* buffer) { return typename T::Hash{}(*static_cast(buffer)); } static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred...}; static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash...}; static constexpr int getIndexFromTag(const char* tag) { constexpr int N = sizeof...(Ts); constexpr const char* is[N] = {Ts::tag...}; for (int i = 0; i < N; ++i) if (is[i] == tag) return i; return -1; } public: template Language(T&& t, std::enable_if_t::value>* = 0) { tag = T::tag; new (&buffer) std::decay_t(std::forward(t)); } Language(const Language& other) { tag = other.tag; tableCopy[getIndexFromTag(tag)](&buffer, &other.buffer); } Language(Language&& other) { tag = other.tag; tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); } ~Language() { tableDtor[getIndexFromTag(tag)](&buffer); } Language& operator=(const Language& other) { Language copy{other}; *this = static_cast(copy); return *this; } Language& operator=(Language&& other) { if (this != &other) { tableDtor[getIndexFromTag(tag)](&buffer); tag = other.tag; tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); // nothrow } return *this; } template const T* get() const { static_assert(WithinDomain::value); return tag == T::tag ? reinterpret_cast(&buffer) : nullptr; } bool operator==(const Language& rhs) const { return tag == rhs.tag && tablePred[getIndexFromTag(tag)](&buffer, &rhs.buffer); } bool operator!=(const Language& rhs) const { return !(*this == rhs); } public: struct Hash { size_t operator()(const Language& language) const { size_t hash = std::hash{}(language.tag); hash ^= tableHash[getIndexFromTag(language.tag)](&language.buffer); return hash; } }; }; } // namespace Luau::EqSat