Rework the API entirely.

This commit is contained in:
Alexander McCord 2024-06-02 17:30:15 -07:00
parent ee62df2ffb
commit bf62839634
3 changed files with 99 additions and 35 deletions

View file

@ -28,10 +28,32 @@ std::size_t languageHash(const T& lang)
return LanguageHash<T>{}(lang); return LanguageHash<T>{}(lang);
} }
// We have four different kinds of declarations:
//
// Atom, the root data type that holds the value in question.
// NodeArray, a fixed sized sequence of `Id`s.
// NodeVector, a dynamically sized sequence of `Id`s.
// NodeFields, a fixed sized sequence of `Id`s accessed by field names rather than subscripts.
#define LUAU_EQSAT_ATOM(name, t) \ #define LUAU_EQSAT_ATOM(name, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \ struct name : public ::Luau::EqSat::Atom<name, t> \
{ \ { \
static constexpr const char* tag = #name; \ static constexpr const char* tag = #name; \
using Atom::Atom; \
}
#define LUAU_EQSAT_NODE_ARRAY(name, ops) \
struct name : public ::Luau::EqSat::NodeVector<name, std::array<::Luau::EqSat::Id, ops>> \
{ \
static constexpr const char* tag = #name; \
using NodeVector::NodeVector; \
}
#define LUAU_EQSAT_NODE_VECTOR(name) \
struct name : public ::Luau::EqSat::NodeVector<name, std::vector<::Luau::EqSat::Id>> \
{ \
static constexpr const char* tag = #name; \
using NodeVector::NodeVector; \
} }
#define LUAU_EQSAT_FIELD(name) \ #define LUAU_EQSAT_FIELD(name) \
@ -39,25 +61,27 @@ std::size_t languageHash(const T& lang)
{ \ { \
} }
#define LUAU_EQSAT_UNARY_NODE(name, field) \ #define LUAU_EQSAT_NODE_FIELDS(name, ...) \
struct name : public ::Luau::EqSat::Node<name, field> \ struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
{ \ { \
static constexpr const char* tag = #name; \ static constexpr const char* tag = #name; \
using Base::Node; \ using NodeFields::NodeFields; \
}
#define LUAU_EQSAT_BINARY_NODE(name, field1, field2) \
struct name : public ::Luau::EqSat::Node<name, field1, field2> \
{ \
static constexpr const char* tag = #name; \
using Base::Node; \
} }
template<typename Phantom, typename T> template<typename Phantom, typename T>
struct Atom struct Atom
{ {
T value; Atom(const T& value)
: _value(value)
{
}
const T& value() const
{
return _value;
}
public:
Slice<Id> operands() Slice<Id> operands()
{ {
return {}; return {};
@ -65,7 +89,7 @@ struct Atom
bool operator==(const Atom& rhs) const bool operator==(const Atom& rhs) const
{ {
return value == rhs.value; return _value == rhs._value;
} }
bool operator!=(const Atom& rhs) const bool operator!=(const Atom& rhs) const
@ -77,9 +101,54 @@ struct Atom
{ {
size_t operator()(const Atom& value) const size_t operator()(const Atom& value) const
{ {
return languageHash(value.value); return languageHash(value._value);
} }
}; };
private:
T _value;
};
template<typename Phantom, typename T>
struct NodeVector
{
template<typename... Args>
NodeVector(Args&&... args)
: vector{std::forward<Args>(args)...}
{
}
const Id& operator[](size_t i) const
{
return vector[i];
}
public:
Slice<Id> operands()
{
return {vector.data(), vector.size()};
}
bool operator==(const NodeVector& rhs) const
{
return vector == rhs.vector;
}
bool operator!=(const NodeVector& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const NodeVector& value) const
{
return languageHash(value.vector);
}
};
private:
T vector;
}; };
/// Empty base class just for static_asserts. /// Empty base class just for static_asserts.
@ -93,7 +162,7 @@ struct Field : FieldBase
}; };
template<typename Phantom, typename... Fields> template<typename Phantom, typename... Fields>
class Node class NodeFields
{ {
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value); static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
@ -102,10 +171,8 @@ class Node
template<typename T> template<typename T>
static constexpr int getIndex() static constexpr int getIndex()
{ {
using TT = std::decay_t<T>;
constexpr int N = sizeof...(Fields); constexpr int N = sizeof...(Fields);
constexpr bool is[N] = {std::is_same_v<TT, Fields>...}; constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...};
for (int i = 0; i < N; ++i) for (int i = 0; i < N; ++i)
if (is[i]) if (is[i])
@ -115,10 +182,8 @@ class Node
} }
public: public:
using Base = Node;
template<typename... Args> template<typename... Args>
Node(Args&&... args) NodeFields(Args&&... args)
: array{std::forward<Args>(args)...} : array{std::forward<Args>(args)...}
{ {
} }
@ -135,19 +200,19 @@ public:
return array[getIndex<T>()]; return array[getIndex<T>()];
} }
bool operator==(const Node& rhs) const bool operator==(const NodeFields& rhs) const
{ {
return array == rhs.array; return array == rhs.array;
} }
bool operator!=(const Node& rhs) const bool operator!=(const NodeFields& rhs) const
{ {
return !(*this == rhs); return !(*this == rhs);
} }
struct Hash struct Hash
{ {
size_t operator()(const Node& value) const size_t operator()(const NodeFields& value) const
{ {
return languageHash(value.array); return languageHash(value.array);
} }

View file

@ -13,8 +13,7 @@ LUAU_EQSAT_ATOM(Str, std::string);
LUAU_EQSAT_FIELD(Left); LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right); LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_NODE_FIELDS(Add, Left, Right);
LUAU_EQSAT_BINARY_NODE(Add, Left, Right);
using namespace Luau; using namespace Luau;
@ -40,7 +39,7 @@ TEST_CASE("language_get")
auto i = v.get<I32>(); auto i = v.get<I32>();
REQUIRE(i); REQUIRE(i);
CHECK(i->value); CHECK(i->value());
CHECK(!v.get<Bool>()); CHECK(!v.get<Bool>());
} }
@ -54,7 +53,7 @@ TEST_CASE("language_copy_ctor")
auto i2 = v2.get<I32>(); auto i2 = v2.get<I32>();
REQUIRE(i1); REQUIRE(i1);
REQUIRE(i2); REQUIRE(i2);
CHECK(i1->value == i2->value); CHECK(i1->value() == i2->value());
} }
TEST_CASE("language_move_ctor") TEST_CASE("language_move_ctor")
@ -63,18 +62,18 @@ TEST_CASE("language_move_ctor")
{ {
auto s1 = v1.get<Str>(); auto s1 = v1.get<Str>();
REQUIRE(s1); REQUIRE(s1);
CHECK(s1->value == "hello"); CHECK(s1->value() == "hello");
} }
Value v2 = std::move(v1); Value v2 = std::move(v1);
auto s1 = v1.get<Str>(); auto s1 = v1.get<Str>();
REQUIRE(s1); REQUIRE(s1);
CHECK(s1->value == ""); // this also tests the dtor. CHECK(s1->value() == ""); // this also tests the dtor.
auto s2 = v2.get<Str>(); auto s2 = v2.get<Str>();
REQUIRE(s2); REQUIRE(s2);
CHECK(s2->value == "hello"); CHECK(s2->value() == "hello");
} }
TEST_CASE("language_equality") TEST_CASE("language_equality")

View file

@ -11,16 +11,16 @@ LUAU_EQSAT_ATOM(Var, std::string);
LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_FIELD(Negated); LUAU_EQSAT_FIELD(Negated);
LUAU_EQSAT_UNARY_NODE(Not, Negated); LUAU_EQSAT_NODE_FIELDS(Not, Negated);
LUAU_EQSAT_FIELD(Left); LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right); LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_BINARY_NODE(And, Left, Right); LUAU_EQSAT_NODE_FIELDS(And, Left, Right);
LUAU_EQSAT_BINARY_NODE(Or, Left, Right); LUAU_EQSAT_NODE_FIELDS(Or, Left, Right);
LUAU_EQSAT_FIELD(Antecedent); LUAU_EQSAT_FIELD(Antecedent);
LUAU_EQSAT_FIELD(Consequent); LUAU_EQSAT_FIELD(Consequent);
LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent); LUAU_EQSAT_NODE_FIELDS(Implies, Antecedent, Consequent);
using namespace Luau; using namespace Luau;
@ -39,7 +39,7 @@ struct ConstantFold
Data make(const EGraph& egraph, const Bool& b) const Data make(const EGraph& egraph, const Bool& b) const
{ {
return b.value; return b.value();
} }
Data make(const EGraph& egraph, const Not& n) const Data make(const EGraph& egraph, const Not& n) const