// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Id.h" #include "Luau/LanguageHash.h" #include "Luau/Slice.h" #include "Luau/Variant.h" #include #include #include #include #include #include #define LUAU_EQSAT_UNIT(name) \ struct name : ::Luau::EqSat::Unit \ { \ static constexpr const char* tag = #name; \ using Unit::Unit; \ } #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = #name; \ using Atom::Atom; \ } #define LUAU_EQSAT_NODE_ARRAY(name, ops) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_VECTOR(name) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_SET(name) \ struct name : public ::Luau::EqSat::NodeSet> \ { \ static constexpr const char* tag = #name; \ using NodeSet::NodeSet; \ } #define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \ struct name : public ::Luau::EqSat::NodeAtomAndVector> \ { \ static constexpr const char* tag = #name; \ using NodeAtomAndVector::NodeAtomAndVector; \ } namespace Luau::EqSat { template struct Unit { Slice mutableOperands() { return {}; } Slice operands() const { return {}; } bool operator==(const Unit& rhs) const { return true; } bool operator!=(const Unit& rhs) const { return false; } struct Hash { size_t operator()(const Unit& value) const { // chosen by fair dice roll. // guaranteed to be random. return 4; } }; }; template struct Atom { Atom(const T& value) : _value(value) { } const T& value() const { return _value; } public: Slice mutableOperands() { return {}; } Slice operands() const { return {}; } bool operator==(const Atom& rhs) const { return _value == rhs._value; } bool operator!=(const Atom& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const Atom& value) const { return languageHash(value._value); } }; private: T _value; }; template struct NodeAtomAndVector { template NodeAtomAndVector(const X& value, Args&&... args) : _value(value) , vector{std::forward(args)...} { } Id operator[](size_t i) const { return vector[i]; } public: const X& value() const { return _value; } Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeAtomAndVector& rhs) const { return _value == rhs._value && vector == rhs.vector; } bool operator!=(const NodeAtomAndVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeAtomAndVector& value) const { size_t result = languageHash(value._value); hashCombine(result, languageHash(value.vector)); return result; } }; private: X _value; T vector; }; template struct NodeVector { template NodeVector(Args&&... args) : vector{std::forward(args)...} { } Id operator[](size_t i) const { return vector[i]; } public: Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeVector& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeVector& value) const { return languageHash(value.vector); } }; private: T vector; }; template struct NodeSet { template NodeSet(Args&&... args) : vector{std::forward(args)...} { std::sort(begin(vector), end(vector)); auto it = std::unique(begin(vector), end(vector)); vector.erase(it, end(vector)); } Id operator[](size_t i) const { return vector[i]; } public: Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeSet& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeSet& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeSet& value) const { return languageHash(value.vector); } }; protected: T vector; }; template struct Language final { using VariantTy = Luau::Variant; template using WithinDomain = std::disjunction, Ts>...>; template Language(T&& t, std::enable_if_t::value>* = 0) noexcept : v(std::forward(t)) { } int index() const noexcept { return v.index(); } /// This should only be used in canonicalization! /// Always prefer operands() Slice mutableOperands() noexcept { return visit( [](auto&& v) -> Slice { return v.mutableOperands(); }, v ); } Slice operands() const noexcept { return visit( [](auto&& v) -> Slice { return v.operands(); }, v ); } template T* get() noexcept { static_assert(WithinDomain::value); return v.template get_if(); } template const T* get() const noexcept { static_assert(WithinDomain::value); return v.template get_if(); } bool operator==(const Language& rhs) const noexcept { return v == rhs.v; } bool operator!=(const Language& rhs) const noexcept { return !(*this == rhs); } public: struct Hash { size_t operator()(const Language& language) const { size_t seed = std::hash{}(language.index()); hashCombine( seed, visit( [](auto&& v) { return typename std::decay_t::Hash{}(v); }, language.v ) ); return seed; } }; private: VariantTy v; }; } // namespace Luau::EqSat