// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Id.h" #include "Luau/Slice.h" #include #include #include #include #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = #name; \ using Atom::Atom; \ } #define LUAU_EQSAT_NODE_ARRAY(name, ops) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_VECTOR(name) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_FIELD(name) \ struct name : public ::Luau::EqSat::Field \ { \ } #define LUAU_EQSAT_NODE_FIELDS(name, ...) \ struct name : public ::Luau::EqSat::NodeFields \ { \ static constexpr const char* tag = #name; \ using NodeFields::NodeFields; \ } namespace Luau::EqSat { template struct LanguageHash { size_t operator()(const T&) const { // See available specializations at the bottom of this file. static_assert(false, "missing languageHash specialization"); } }; template std::size_t languageHash(const T& lang) { return LanguageHash{}(lang); } template struct Atom { Atom(const T& value) : _value(value) { } const T& value() const { return _value; } public: Slice operands() { return {}; } Slice operands() const { return {}; } bool operator==(const Atom& rhs) const { return _value == rhs._value; } bool operator!=(const Atom& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const Atom& value) const { return languageHash(value._value); } }; private: T _value; }; template struct NodeVector { template NodeVector(Args&&... args) : vector{std::forward(args)...} { } const Id& operator[](size_t i) const { return vector[i]; } public: Slice operands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeVector& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeVector& value) const { return languageHash(value.vector); } }; private: T vector; }; /// Empty base class just for static_asserts. struct FieldBase { FieldBase() = delete; FieldBase(FieldBase&&) = delete; FieldBase& operator=(FieldBase&&) = delete; FieldBase(const FieldBase&) = delete; FieldBase& operator=(const FieldBase&) = delete; }; template struct Field : FieldBase { }; template class NodeFields { static_assert(std::conjunction...>::value); std::array array; template static constexpr int getIndex() { constexpr int N = sizeof...(Fields); constexpr bool is[N] = {std::is_same_v, Fields>...}; for (int i = 0; i < N; ++i) if (is[i]) return i; return -1; } public: template NodeFields(Args&&... args) : array{std::forward(args)...} { } Slice operands() { return Slice{array}; } Slice operands() const { return Slice{array}; } template const Id& field() const { static_assert(std::disjunction_v, Fields>...>); return array[getIndex()]; } bool operator==(const NodeFields& rhs) const { return array == rhs.array; } bool operator!=(const NodeFields& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeFields& value) const { return languageHash(value.array); } }; }; // `Language` is very similar to `Luau::Variant` with enough differences warranting a different type altogether. // // Firstly, where `Luau::Variant` uses an `int` to decide which type the variant currently holds, we use // a `const char*` instead. We use the pointer address for tag checking, and the string buffer for stringification. // // Secondly, we need `Language` to have additional methods such as: // - `children()` to get child operands, // - `operator==` to decide equality, and // - `hash()` function. // // And finally, each `T` in `Ts` have additional requirements which `Luau::Variant` doesn't need. template class Language final { const char* tag; char buffer[std::max({sizeof(Ts)...})]; private: template using WithinDomain = std::disjunction, Ts>...>; using FnCopy = void (*)(void*, const void*); using FnMove = void (*)(void*, void*); using FnDtor = void (*)(void*); using FnPred = bool (*)(const void*, const void*); using FnHash = size_t (*)(const void*); using FnOper = Slice (*)(void*); template static void fnCopy(void* dst, const void* src) noexcept { new (dst) T(*static_cast(src)); } template static void fnMove(void* dst, void* src) noexcept { new (dst) T(static_cast(*static_cast(src))); } template static void fnDtor(void* dst) noexcept { static_cast(dst)->~T(); } template static bool fnPred(const void* lhs, const void* rhs) noexcept { return *static_cast(lhs) == *static_cast(rhs); } template static size_t fnHash(const void* buffer) noexcept { return typename T::Hash{}(*static_cast(buffer)); } template static Slice fnOper(void* buffer) noexcept { return static_cast(buffer)->operands(); } static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred...}; static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash...}; static constexpr FnOper tableOper[sizeof...(Ts)] = {&fnOper...}; static constexpr int getIndexFromTag(const char* tag) noexcept { constexpr int N = sizeof...(Ts); constexpr const char* is[N] = {Ts::tag...}; for (int i = 0; i < N; ++i) if (is[i] == tag) return i; return -1; } public: template Language(T&& t, std::enable_if_t::value>* = 0) noexcept { tag = std::decay_t::tag; new (&buffer) std::decay_t(std::forward(t)); } Language(const Language& other) noexcept { tag = other.tag; tableCopy[getIndexFromTag(tag)](&buffer, &other.buffer); } Language(Language&& other) noexcept { tag = other.tag; tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); } ~Language() noexcept { tableDtor[getIndexFromTag(tag)](&buffer); } Language& operator=(const Language& other) noexcept { Language copy{other}; *this = static_cast(copy); return *this; } Language& operator=(Language&& other) noexcept { if (this != &other) { tableDtor[getIndexFromTag(tag)](&buffer); tag = other.tag; tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); // nothrow } return *this; } int index() const noexcept { return getIndexFromTag(tag); } /// You should never call this function with the intention of mutating the `Id`. /// Reading is ok, but you should also never assume that these `Id`s are stable. Slice operands() noexcept { return tableOper[getIndexFromTag(tag)](&buffer); } Slice operands() const noexcept { return const_cast(this)->operands(); } template const T* get() const noexcept { static_assert(WithinDomain::value); return tag == T::tag ? reinterpret_cast(&buffer) : nullptr; } bool operator==(const Language& rhs) const noexcept { return tag == rhs.tag && tablePred[getIndexFromTag(tag)](&buffer, &rhs.buffer); } bool operator!=(const Language& rhs) const noexcept { return !(*this == rhs); } public: struct Hash { size_t operator()(const Language& language) const { size_t seed = std::hash{}(language.tag); hashCombine(seed, tableHash[getIndexFromTag(language.tag)](&language.buffer)); return seed; } }; }; inline void hashCombine(size_t& seed, size_t hash) { // Golden Ratio constant used for better hash scattering // See https://softwareengineering.stackexchange.com/a/402543 seed ^= hash + 0x9e3779b9 + (seed << 6) + (seed >> 2); } template struct LanguageHash{}(std::declval()))>> { size_t operator()(const T& t) const { return std::hash{}(t); } }; template struct LanguageHash> { size_t operator()(const std::array& array) const { size_t seed = 0; for (const T& t : array) hashCombine(seed, languageHash(t)); return seed; } }; template struct LanguageHash> { size_t operator()(const std::vector& vector) const { size_t seed = 0; for (const T& t : vector) hashCombine(seed, languageHash(t)); return seed; } }; } // namespace Luau::EqSat