// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Id.h" #include "Luau/LanguageHash.h" #include "Luau/Slice.h" #include "Luau/Variant.h" #include #include #include #include #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = #name; \ using Atom::Atom; \ } #define LUAU_EQSAT_NODE_ARRAY(name, ops) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_VECTOR(name) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_FIELD(name) \ struct name : public ::Luau::EqSat::Field \ { \ } #define LUAU_EQSAT_NODE_FIELDS(name, ...) \ struct name : public ::Luau::EqSat::NodeFields \ { \ static constexpr const char* tag = #name; \ using NodeFields::NodeFields; \ } namespace Luau::EqSat { template struct Atom { Atom(const T& value) : _value(value) { } const T& value() const { return _value; } public: Slice operands() { return {}; } Slice operands() const { return {}; } bool operator==(const Atom& rhs) const { return _value == rhs._value; } bool operator!=(const Atom& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const Atom& value) const { return languageHash(value._value); } }; private: T _value; }; template struct NodeVector { template NodeVector(Args&&... args) : vector{std::forward(args)...} { } Id operator[](size_t i) const { return vector[i]; } public: Slice operands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeVector& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeVector& value) const { return languageHash(value.vector); } }; private: T vector; }; /// Empty base class just for static_asserts. struct FieldBase { FieldBase() = delete; FieldBase(FieldBase&&) = delete; FieldBase& operator=(FieldBase&&) = delete; FieldBase(const FieldBase&) = delete; FieldBase& operator=(const FieldBase&) = delete; }; template struct Field : FieldBase { }; template struct NodeFields { static_assert(std::conjunction...>::value); template static constexpr int getIndex() { constexpr int N = sizeof...(Fields); constexpr bool is[N] = {std::is_same_v, Fields>...}; for (int i = 0; i < N; ++i) if (is[i]) return i; return -1; } public: template NodeFields(Args&&... args) : array{std::forward(args)...} { } Slice operands() { return Slice{array}; } Slice operands() const { return Slice{array.data(), array.size()}; } template Id field() const { static_assert(std::disjunction_v, Fields>...>); return array[getIndex()]; } bool operator==(const NodeFields& rhs) const { return array == rhs.array; } bool operator!=(const NodeFields& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeFields& value) const { return languageHash(value.array); } }; private: std::array array; }; template struct Language final { template using WithinDomain = std::disjunction, Ts>...>; template Language(T&& t, std::enable_if_t::value>* = 0) noexcept : v(std::forward(t)) { } Language(const Language&) noexcept = default; Language& operator=(const Language&) noexcept = default; Language(Language&&) noexcept = default; Language& operator=(Language&&) noexcept = default; int index() const noexcept { return v.index(); } /// You should never call this function with the intention of mutating the `Id`. /// Reading is ok, but you should also never assume that these `Id`s are stable. Slice operands() noexcept { return visit([](auto&& v) -> Slice { return v.operands(); }, v); } Slice operands() const noexcept { return visit([](auto&& v) -> Slice { return v.operands(); }, v); } template T* get() noexcept { static_assert(WithinDomain::value); return v.template get_if(); } template const T* get() const noexcept { static_assert(WithinDomain::value); return v.template get_if(); } bool operator==(const Language& rhs) const noexcept { return v == rhs.v; } bool operator!=(const Language& rhs) const noexcept { return !(*this == rhs); } public: struct Hash { size_t operator()(const Language& language) const { size_t seed = std::hash{}(language.index()); hashCombine(seed, visit([](auto&& v) { return typename std::decay_t::Hash{}(v); }, language.v)); return seed; } }; private: Variant v; }; } // namespace Luau::EqSat