// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Id.h" #include "Luau/LanguageHash.h" #include "Luau/Slice.h" #include "Luau/Variant.h" #include #include #include #include #include #include #define LUAU_EQSAT_UNIT(name) \ struct name : ::Luau::EqSat::Unit \ { \ static constexpr const char* tag = #name; \ using Unit::Unit; \ } #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = #name; \ using Atom::Atom; \ } #define LUAU_EQSAT_NODE_ARRAY(name, ops) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_VECTOR(name) \ struct name : public ::Luau::EqSat::NodeVector> \ { \ static constexpr const char* tag = #name; \ using NodeVector::NodeVector; \ } #define LUAU_EQSAT_NODE_SET(name) \ struct name : public ::Luau::EqSat::NodeSet> \ { \ static constexpr const char* tag = #name; \ using NodeSet::NodeSet; \ } #define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \ struct name : public ::Luau::EqSat::NodeAtomAndVector> \ { \ static constexpr const char* tag = #name; \ using NodeAtomAndVector::NodeAtomAndVector; \ } namespace Luau::EqSat { template struct Unit { Slice mutableOperands() { return {}; } Slice operands() const { return {}; } bool operator==(const Unit& rhs) const { return true; } bool operator!=(const Unit& rhs) const { return false; } struct Hash { size_t operator()(const Unit& value) const { // chosen by fair dice roll. // guaranteed to be random. return 4; } }; }; template struct Atom { Atom(const T& value) : _value(value) { } const T& value() const { return _value; } public: Slice mutableOperands() { return {}; } Slice operands() const { return {}; } bool operator==(const Atom& rhs) const { return _value == rhs._value; } bool operator!=(const Atom& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const Atom& value) const { return languageHash(value._value); } }; private: T _value; }; template struct NodeAtomAndVector { template NodeAtomAndVector(const X& value, Args&&... args) : _value(value) , vector{std::forward(args)...} { } Id operator[](size_t i) const { return vector[i]; } public: const X& value() const { return _value; } Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeAtomAndVector& rhs) const { return _value == rhs._value && vector == rhs.vector; } bool operator!=(const NodeAtomAndVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeAtomAndVector& value) const { size_t result = languageHash(value._value); hashCombine(result, languageHash(value.vector)); return result; } }; private: X _value; T vector; }; template struct NodeVector { template NodeVector(Args&&... args) : vector{std::forward(args)...} { } Id operator[](size_t i) const { return vector[i]; } public: Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeVector& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeVector& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeVector& value) const { return languageHash(value.vector); } }; private: T vector; }; template struct NodeSet { template friend void canonicalize(NodeSet& node, Find&& find); template NodeSet(Args&&... args) : vector{std::forward(args)...} { std::sort(begin(vector), end(vector)); auto it = std::unique(begin(vector), end(vector)); vector.erase(it, end(vector)); } Id operator[](size_t i) const { return vector[i]; } public: Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } Slice operands() const { return Slice{vector.data(), vector.size()}; } bool operator==(const NodeSet& rhs) const { return vector == rhs.vector; } bool operator!=(const NodeSet& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const NodeSet& value) const { return languageHash(value.vector); } }; protected: T vector; }; template struct Language final { using VariantTy = Luau::Variant; template using WithinDomain = std::disjunction, Ts>...>; template friend void canonicalize(Language& enode, Find&& find); template Language(T&& t, std::enable_if_t::value>* = 0) noexcept : v(std::forward(t)) { } int index() const noexcept { return v.index(); } /// This should only be used in canonicalization! /// Always prefer operands() Slice mutableOperands() noexcept { return visit( [](auto&& v) -> Slice { return v.mutableOperands(); }, v ); } Slice operands() const noexcept { return visit( [](auto&& v) -> Slice { return v.operands(); }, v ); } template T* get() noexcept { static_assert(WithinDomain::value); return v.template get_if(); } template const T* get() const noexcept { static_assert(WithinDomain::value); return v.template get_if(); } bool operator==(const Language& rhs) const noexcept { return v == rhs.v; } bool operator!=(const Language& rhs) const noexcept { return !(*this == rhs); } public: struct Hash { size_t operator()(const Language& language) const { size_t seed = std::hash{}(language.index()); hashCombine( seed, visit( [](auto&& v) { return typename std::decay_t::Hash{}(v); }, language.v ) ); return seed; } }; private: VariantTy v; }; template void canonicalize(Node& node, Find&& find) { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). for (Id& id : node.mutableOperands()) id = find(id); } // Canonicalizing the Ids in a NodeSet may result in the set decreasing in size. template void canonicalize(NodeSet& node, Find&& find) { for (Id& id : node.vector) id = find(id); std::sort(begin(node.vector), end(node.vector)); auto endIt = std::unique(begin(node.vector), end(node.vector)); node.vector.erase(endIt, end(node.vector)); } template void canonicalize(Language& enode, Find&& find) { visit( [&](auto&& v) { Luau::EqSat::canonicalize(v, find); }, enode.v ); } } // namespace Luau::EqSat