From bf628396344d480ff81c1b1dfff4037f65b1d886 Mon Sep 17 00:00:00 2001 From: Alexander McCord Date: Sun, 2 Jun 2024 17:30:15 -0700 Subject: [PATCH] Rework the API entirely. --- EqSat/include/Luau/Language.h | 111 +++++++++++++++++++++++------ tests/EqSat.language.test.cpp | 13 ++-- tests/EqSat.propositional.test.cpp | 10 +-- 3 files changed, 99 insertions(+), 35 deletions(-) diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index d83138a2..e32d76cc 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -28,10 +28,32 @@ std::size_t languageHash(const T& lang) return LanguageHash{}(lang); } +// We have four different kinds of declarations: +// +// Atom, the root data type that holds the value in question. +// NodeArray, a fixed sized sequence of `Id`s. +// NodeVector, a dynamically sized sequence of `Id`s. +// NodeFields, a fixed sized sequence of `Id`s accessed by field names rather than subscripts. + #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = #name; \ + using Atom::Atom; \ + } + +#define LUAU_EQSAT_NODE_ARRAY(name, ops) \ + struct name : public ::Luau::EqSat::NodeVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeVector::NodeVector; \ + } + +#define LUAU_EQSAT_NODE_VECTOR(name) \ + struct name : public ::Luau::EqSat::NodeVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeVector::NodeVector; \ } #define LUAU_EQSAT_FIELD(name) \ @@ -39,25 +61,27 @@ std::size_t languageHash(const T& lang) { \ } -#define LUAU_EQSAT_UNARY_NODE(name, field) \ - struct name : public ::Luau::EqSat::Node \ +#define LUAU_EQSAT_NODE_FIELDS(name, ...) \ + struct name : public ::Luau::EqSat::NodeFields \ { \ static constexpr const char* tag = #name; \ - using Base::Node; \ - } - -#define LUAU_EQSAT_BINARY_NODE(name, field1, field2) \ - struct name : public ::Luau::EqSat::Node \ - { \ - static constexpr const char* tag = #name; \ - using Base::Node; \ + using NodeFields::NodeFields; \ } template struct Atom { - T value; + Atom(const T& value) + : _value(value) + { + } + const T& value() const + { + return _value; + } + +public: Slice operands() { return {}; @@ -65,7 +89,7 @@ struct Atom bool operator==(const Atom& rhs) const { - return value == rhs.value; + return _value == rhs._value; } bool operator!=(const Atom& rhs) const @@ -77,9 +101,54 @@ struct Atom { size_t operator()(const Atom& value) const { - return languageHash(value.value); + return languageHash(value._value); } }; + +private: + T _value; +}; + +template +struct NodeVector +{ + template + NodeVector(Args&&... args) + : vector{std::forward(args)...} + { + } + + const Id& operator[](size_t i) const + { + return vector[i]; + } + +public: + Slice operands() + { + return {vector.data(), vector.size()}; + } + + bool operator==(const NodeVector& rhs) const + { + return vector == rhs.vector; + } + + bool operator!=(const NodeVector& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const NodeVector& value) const + { + return languageHash(value.vector); + } + }; + +private: + T vector; }; /// Empty base class just for static_asserts. @@ -93,7 +162,7 @@ struct Field : FieldBase }; template -class Node +class NodeFields { static_assert(std::conjunction...>::value); @@ -102,10 +171,8 @@ class Node template static constexpr int getIndex() { - using TT = std::decay_t; - constexpr int N = sizeof...(Fields); - constexpr bool is[N] = {std::is_same_v...}; + constexpr bool is[N] = {std::is_same_v, Fields>...}; for (int i = 0; i < N; ++i) if (is[i]) @@ -115,10 +182,8 @@ class Node } public: - using Base = Node; - template - Node(Args&&... args) + NodeFields(Args&&... args) : array{std::forward(args)...} { } @@ -135,19 +200,19 @@ public: return array[getIndex()]; } - bool operator==(const Node& rhs) const + bool operator==(const NodeFields& rhs) const { return array == rhs.array; } - bool operator!=(const Node& rhs) const + bool operator!=(const NodeFields& rhs) const { return !(*this == rhs); } struct Hash { - size_t operator()(const Node& value) const + size_t operator()(const NodeFields& value) const { return languageHash(value.array); } diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index 9d041364..282d4ad2 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -13,8 +13,7 @@ LUAU_EQSAT_ATOM(Str, std::string); LUAU_EQSAT_FIELD(Left); LUAU_EQSAT_FIELD(Right); - -LUAU_EQSAT_BINARY_NODE(Add, Left, Right); +LUAU_EQSAT_NODE_FIELDS(Add, Left, Right); using namespace Luau; @@ -40,7 +39,7 @@ TEST_CASE("language_get") auto i = v.get(); REQUIRE(i); - CHECK(i->value); + CHECK(i->value()); CHECK(!v.get()); } @@ -54,7 +53,7 @@ TEST_CASE("language_copy_ctor") auto i2 = v2.get(); REQUIRE(i1); REQUIRE(i2); - CHECK(i1->value == i2->value); + CHECK(i1->value() == i2->value()); } TEST_CASE("language_move_ctor") @@ -63,18 +62,18 @@ TEST_CASE("language_move_ctor") { auto s1 = v1.get(); REQUIRE(s1); - CHECK(s1->value == "hello"); + CHECK(s1->value() == "hello"); } Value v2 = std::move(v1); auto s1 = v1.get(); REQUIRE(s1); - CHECK(s1->value == ""); // this also tests the dtor. + CHECK(s1->value() == ""); // this also tests the dtor. auto s2 = v2.get(); REQUIRE(s2); - CHECK(s2->value == "hello"); + CHECK(s2->value() == "hello"); } TEST_CASE("language_equality") diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp index 47a35cac..4a4ffa7f 100644 --- a/tests/EqSat.propositional.test.cpp +++ b/tests/EqSat.propositional.test.cpp @@ -11,16 +11,16 @@ LUAU_EQSAT_ATOM(Var, std::string); LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_FIELD(Negated); -LUAU_EQSAT_UNARY_NODE(Not, Negated); +LUAU_EQSAT_NODE_FIELDS(Not, Negated); LUAU_EQSAT_FIELD(Left); LUAU_EQSAT_FIELD(Right); -LUAU_EQSAT_BINARY_NODE(And, Left, Right); -LUAU_EQSAT_BINARY_NODE(Or, Left, Right); +LUAU_EQSAT_NODE_FIELDS(And, Left, Right); +LUAU_EQSAT_NODE_FIELDS(Or, Left, Right); LUAU_EQSAT_FIELD(Antecedent); LUAU_EQSAT_FIELD(Consequent); -LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent); +LUAU_EQSAT_NODE_FIELDS(Implies, Antecedent, Consequent); using namespace Luau; @@ -39,7 +39,7 @@ struct ConstantFold Data make(const EGraph& egraph, const Bool& b) const { - return b.value; + return b.value(); } Data make(const EGraph& egraph, const Not& n) const