// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Id.h" #include "Luau/LanguageHash.h" #include "Luau/Slice.h" #include "Luau/Variant.h" #include <array> #include <type_traits> #include <utility> #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom<name, t> \ { \ static constexpr const char* tag = #name; \ using Atom::Atom; \ } #define LUAU_EQSAT_NODE_ARRAY(name, ops) \ struct name : public ::Luau::EqSat::NodeVector<name, std::array<::Luau::EqSat::Id, ops>> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_VECTOR(name) \ struct name : public ::Luau::EqSat::NodeVector<name, std::vector<::Luau::EqSat::Id>> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_FIELD(name) \ struct name : public ::Luau::EqSat::Field<name> \ { \ } #define LUAU_EQSAT_NODE_FIELDS(name, ...) \ struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \ { \ static constexpr const char* tag = #name; \ using NodeFields::NodeFields; \ } namespace Luau::EqSat { template<typename Phantom, typename T> struct Atom { Atom(const T& value) : _value(value) { } const T& value() const { return _value; } public: Slice<Id> operands() { return {}; } Slice<const Id> operands() const { return {}; } bool operator==(const Atom& rhs) const { return _value == rhs._value; } bool operator!=(const Atom& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const Atom& value) const { return languageHash(value._value); } }; private: T _value; }; template<typename Phantom, typename T> struct NodeVector { template<typename... Args> NodeVector(Args&&... args) : vector{std::forward<Args>(args)...} { } Id operator[](size_t i) const { return vector[i]; } public: Slice<Id> operands() { return Slice{vector.data(), vector.size()}; } Slice<const Id> operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeVector& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeVector& value) const { return languageHash(value.vector); } }; private: T vector; }; /// Empty base class just for static_asserts. struct FieldBase { FieldBase() = delete; FieldBase(FieldBase&&) = delete; FieldBase& operator=(FieldBase&&) = delete; FieldBase(const FieldBase&) = delete; FieldBase& operator=(const FieldBase&) = delete; }; template<typename Phantom> struct Field : FieldBase { }; template<typename Phantom, typename... Fields> struct NodeFields { static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value); template<typename T> static constexpr int getIndex() { constexpr int N = sizeof...(Fields); constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...}; for (int i = 0; i < N; ++i) if (is[i]) return i; return -1; } public: template<typename... Args> NodeFields(Args&&... args) : array{std::forward<Args>(args)...} { } Slice<Id> operands() { return Slice{array}; } Slice<const Id> operands() const { return Slice{array.data(), array.size()}; } template<typename T> Id field() const { static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>); return array[getIndex<T>()]; } bool operator==(const NodeFields& rhs) const { return array == rhs.array; } bool operator!=(const NodeFields& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeFields& value) const { return languageHash(value.array); } }; private: std::array<Id, sizeof...(Fields)> array; }; template<typename... Ts> struct Language final { template<typename T> using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>; template<typename T> Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept : v(std::forward<T>(t)) { } int index() const noexcept { return v.index(); } /// You should never call this function with the intention of mutating the `Id`. /// Reading is ok, but you should also never assume that these `Id`s are stable. Slice<Id> operands() noexcept { return visit( [](auto&& v) -> Slice<Id> { return v.operands(); }, v ); } Slice<const Id> operands() const noexcept { return visit( [](auto&& v) -> Slice<const Id> { return v.operands(); }, v ); } template<typename T> T* get() noexcept { static_assert(WithinDomain<T>::value); return v.template get_if<T>(); } template<typename T> const T* get() const noexcept { static_assert(WithinDomain<T>::value); return v.template get_if<T>(); } bool operator==(const Language& rhs) const noexcept { return v == rhs.v; } bool operator!=(const Language& rhs) const noexcept { return !(*this == rhs); } public: struct Hash { size_t operator()(const Language& language) const { size_t seed = std::hash<int>{}(language.index()); hashCombine( seed, visit( [](auto&& v) { return typename std::decay_t<decltype(v)>::Hash{}(v); }, language.v ) ); return seed; } }; private: Variant<Ts...> v; }; } // namespace Luau::EqSat