Rework the API entirely.

This commit is contained in:
Alexander McCord 2024-06-02 17:30:15 -07:00
parent ee62df2ffb
commit bf62839634
3 changed files with 99 additions and 35 deletions

View file

@ -28,10 +28,32 @@ std::size_t languageHash(const T& lang)
return LanguageHash<T>{}(lang);
}
// We have four different kinds of declarations:
//
// Atom, the root data type that holds the value in question.
// NodeArray, a fixed sized sequence of `Id`s.
// NodeVector, a dynamically sized sequence of `Id`s.
// NodeFields, a fixed sized sequence of `Id`s accessed by field names rather than subscripts.
#define LUAU_EQSAT_ATOM(name, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \
{ \
static constexpr const char* tag = #name; \
using Atom::Atom; \
}
#define LUAU_EQSAT_NODE_ARRAY(name, ops) \
struct name : public ::Luau::EqSat::NodeVector<name, std::array<::Luau::EqSat::Id, ops>> \
{ \
static constexpr const char* tag = #name; \
using NodeVector::NodeVector; \
}
#define LUAU_EQSAT_NODE_VECTOR(name) \
struct name : public ::Luau::EqSat::NodeVector<name, std::vector<::Luau::EqSat::Id>> \
{ \
static constexpr const char* tag = #name; \
using NodeVector::NodeVector; \
}
#define LUAU_EQSAT_FIELD(name) \
@ -39,25 +61,27 @@ std::size_t languageHash(const T& lang)
{ \
}
#define LUAU_EQSAT_UNARY_NODE(name, field) \
struct name : public ::Luau::EqSat::Node<name, field> \
#define LUAU_EQSAT_NODE_FIELDS(name, ...) \
struct name : public ::Luau::EqSat::NodeFields<name, __VA_ARGS__> \
{ \
static constexpr const char* tag = #name; \
using Base::Node; \
}
#define LUAU_EQSAT_BINARY_NODE(name, field1, field2) \
struct name : public ::Luau::EqSat::Node<name, field1, field2> \
{ \
static constexpr const char* tag = #name; \
using Base::Node; \
using NodeFields::NodeFields; \
}
template<typename Phantom, typename T>
struct Atom
{
T value;
Atom(const T& value)
: _value(value)
{
}
const T& value() const
{
return _value;
}
public:
Slice<Id> operands()
{
return {};
@ -65,7 +89,7 @@ struct Atom
bool operator==(const Atom& rhs) const
{
return value == rhs.value;
return _value == rhs._value;
}
bool operator!=(const Atom& rhs) const
@ -77,9 +101,54 @@ struct Atom
{
size_t operator()(const Atom& value) const
{
return languageHash(value.value);
return languageHash(value._value);
}
};
private:
T _value;
};
template<typename Phantom, typename T>
struct NodeVector
{
template<typename... Args>
NodeVector(Args&&... args)
: vector{std::forward<Args>(args)...}
{
}
const Id& operator[](size_t i) const
{
return vector[i];
}
public:
Slice<Id> operands()
{
return {vector.data(), vector.size()};
}
bool operator==(const NodeVector& rhs) const
{
return vector == rhs.vector;
}
bool operator!=(const NodeVector& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const NodeVector& value) const
{
return languageHash(value.vector);
}
};
private:
T vector;
};
/// Empty base class just for static_asserts.
@ -93,7 +162,7 @@ struct Field : FieldBase
};
template<typename Phantom, typename... Fields>
class Node
class NodeFields
{
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
@ -102,10 +171,8 @@ class Node
template<typename T>
static constexpr int getIndex()
{
using TT = std::decay_t<T>;
constexpr int N = sizeof...(Fields);
constexpr bool is[N] = {std::is_same_v<TT, Fields>...};
constexpr bool is[N] = {std::is_same_v<std::decay_t<T>, Fields>...};
for (int i = 0; i < N; ++i)
if (is[i])
@ -115,10 +182,8 @@ class Node
}
public:
using Base = Node;
template<typename... Args>
Node(Args&&... args)
NodeFields(Args&&... args)
: array{std::forward<Args>(args)...}
{
}
@ -135,19 +200,19 @@ public:
return array[getIndex<T>()];
}
bool operator==(const Node& rhs) const
bool operator==(const NodeFields& rhs) const
{
return array == rhs.array;
}
bool operator!=(const Node& rhs) const
bool operator!=(const NodeFields& rhs) const
{
return !(*this == rhs);
}
struct Hash
{
size_t operator()(const Node& value) const
size_t operator()(const NodeFields& value) const
{
return languageHash(value.array);
}

View file

@ -13,8 +13,7 @@ LUAU_EQSAT_ATOM(Str, std::string);
LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_BINARY_NODE(Add, Left, Right);
LUAU_EQSAT_NODE_FIELDS(Add, Left, Right);
using namespace Luau;
@ -40,7 +39,7 @@ TEST_CASE("language_get")
auto i = v.get<I32>();
REQUIRE(i);
CHECK(i->value);
CHECK(i->value());
CHECK(!v.get<Bool>());
}
@ -54,7 +53,7 @@ TEST_CASE("language_copy_ctor")
auto i2 = v2.get<I32>();
REQUIRE(i1);
REQUIRE(i2);
CHECK(i1->value == i2->value);
CHECK(i1->value() == i2->value());
}
TEST_CASE("language_move_ctor")
@ -63,18 +62,18 @@ TEST_CASE("language_move_ctor")
{
auto s1 = v1.get<Str>();
REQUIRE(s1);
CHECK(s1->value == "hello");
CHECK(s1->value() == "hello");
}
Value v2 = std::move(v1);
auto s1 = v1.get<Str>();
REQUIRE(s1);
CHECK(s1->value == ""); // this also tests the dtor.
CHECK(s1->value() == ""); // this also tests the dtor.
auto s2 = v2.get<Str>();
REQUIRE(s2);
CHECK(s2->value == "hello");
CHECK(s2->value() == "hello");
}
TEST_CASE("language_equality")

View file

@ -11,16 +11,16 @@ LUAU_EQSAT_ATOM(Var, std::string);
LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_FIELD(Negated);
LUAU_EQSAT_UNARY_NODE(Not, Negated);
LUAU_EQSAT_NODE_FIELDS(Not, Negated);
LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_BINARY_NODE(And, Left, Right);
LUAU_EQSAT_BINARY_NODE(Or, Left, Right);
LUAU_EQSAT_NODE_FIELDS(And, Left, Right);
LUAU_EQSAT_NODE_FIELDS(Or, Left, Right);
LUAU_EQSAT_FIELD(Antecedent);
LUAU_EQSAT_FIELD(Consequent);
LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent);
LUAU_EQSAT_NODE_FIELDS(Implies, Antecedent, Consequent);
using namespace Luau;
@ -39,7 +39,7 @@ struct ConstantFold
Data make(const EGraph& egraph, const Bool& b) const
{
return b.value;
return b.value();
}
Data make(const EGraph& egraph, const Not& n) const