Add nodes that can point to child enodes.

This commit is contained in:
Alexander McCord 2024-05-26 19:47:27 -07:00
parent 0d882367f7
commit 2a2de1cea2
2 changed files with 147 additions and 13 deletions

View file

@ -1,6 +1,9 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Id.h"
#include <array>
#include <algorithm>
#include <type_traits>
#include <utility>
@ -8,11 +11,39 @@
namespace Luau::EqSat
{
#define LUAU_EQSAT_ATOM(name, t) LUAU_EQSAT_ATOM_CUSTOM(name, #name, t)
#define LUAU_EQSAT_ATOM_CUSTOM(name, custom, t) \
#define LUAU_EQSAT_ATOM(name, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \
{ \
static constexpr const char* tag = custom; \
static constexpr const char* tag = #name; \
}
#define LUAU_EQSAT_FIELD(name) \
struct name : public ::Luau::EqSat::Field<name> \
{ \
}
#define LUAU_EQSAT_UNARY_NODE(name, field) \
struct name : public ::Luau::EqSat::Node<name, field> \
{ \
static constexpr const char* tag = #name; \
using Base::Node; \
}
#define LUAU_EQSAT_BINARY_NODE(name, field1, field2) \
struct name : public ::Luau::EqSat::Node<name, field1, field2> \
{ \
static constexpr const char* tag = #name; \
using Base::Node; \
}
#define DERIVE_EQ(name, field) \
bool operator==(const name& rhs) const \
{ \
return field == rhs.field; \
} \
bool operator!=(const name& rhs) const \
{ \
return !(*this == rhs); \
}
template<typename B, typename T>
@ -20,15 +51,7 @@ struct Atom
{
T value;
bool operator==(const Atom& rhs) const
{
return value == rhs.value;
}
bool operator!=(const Atom& rhs) const
{
return !(*this == rhs);
}
DERIVE_EQ(Atom, value);
struct Hash
{
@ -39,6 +62,84 @@ struct Atom
};
};
/// Empty base class just for static_asserts.
struct FieldBase
{
};
template<typename T>
struct Field : FieldBase
{
Id id;
Field(Id id)
: id(id)
{
}
DERIVE_EQ(Field, id);
struct Hash
{
size_t operator()(const Field& field) const
{
return std::hash<Id>{}(field.id);
}
};
};
template<typename B, typename... Fields>
class Node
{
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
std::array<Id, sizeof...(Fields)> array;
template<typename T>
static constexpr int getIndex()
{
using TT = std::decay_t<T>;
constexpr int N = sizeof...(Fields);
constexpr bool is[N] = {std::is_same_v<TT, Fields>...};
for (int i = 0; i < N; ++i)
if (is[i])
return i;
return -1;
}
public:
using Base = Node;
template<typename... Args>
Node(Args&&... args)
: array{std::forward<Args>(args)...}
{
}
template<typename T>
Id field() const
{
static_assert(std::is_base_of<FieldBase, T>::value);
static_assert(getIndex<T>() >= 0);
return array[getIndex<T>()];
}
DERIVE_EQ(Node, array);
struct Hash
{
size_t operator()(const Node& node) const
{
return 0;
}
};
};
#undef DERIVE_EQ
// `Language` is very similar to `Luau::Variant` with enough differences warranting a different type altogether.
//
// Firstly, where `Luau::Variant` uses an `int` to decide which type the variant currently holds, we use

View file

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include <doctest.h>
#include "Luau/Id.h"
#include "Luau/Language.h"
#include <string>
@ -10,9 +11,14 @@ LUAU_EQSAT_ATOM(I32, int);
LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_ATOM(Str, std::string);
LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_BINARY_NODE(Add, Left, Right);
using namespace Luau;
using Value = EqSat::Language<I32, Bool, Str>;
using Value = EqSat::Language<I32, Bool, Str, Add>;
TEST_SUITE_BEGIN("EqSatLanguage");
@ -22,6 +28,12 @@ TEST_CASE("atom_equality")
CHECK(I32{0} != I32{1});
}
TEST_CASE("node_equality")
{
CHECK(Add{EqSat::Id{0}, EqSat::Id{0}} == Add{EqSat::Id{0}, EqSat::Id{0}});
CHECK(Add{EqSat::Id{1}, EqSat::Id{0}} != Add{EqSat::Id{0}, EqSat::Id{0}});
}
TEST_CASE("language_get")
{
Value v{I32{5}};
@ -71,10 +83,12 @@ TEST_CASE("language_equality")
Value v2{I32{0}};
Value v3{I32{1}};
Value v4{Bool{true}};
Value v5{Add{EqSat::Id{0}, EqSat::Id{1}}};
CHECK(v1 == v2);
CHECK(v2 != v3);
CHECK(v3 != v4);
CHECK(v4 != v5);
}
TEST_CASE("language_is_mappable")
@ -84,14 +98,33 @@ TEST_CASE("language_is_mappable")
Value v1{I32{5}};
Value v2{I32{5}};
Value v3{Bool{true}};
Value v4{Add{EqSat::Id{0}, EqSat::Id{1}}};
map[v1] = 1;
map[v2] = 2;
map[v3] = 42;
map[v4] = 37;
CHECK(map[v1] == 2);
CHECK(map[v2] == 2);
CHECK(map[v3] == 42);
CHECK(map[v4] == 37);
}
TEST_CASE("node_field")
{
EqSat::Id left{0};
EqSat::Id right{1};
Add add{left, right};
EqSat::Id left2 = add.field<Left>();
EqSat::Id right2 = add.field<Right>();
CHECK(left == left2);
CHECK(left != right2);
CHECK(right == right2);
CHECK(right != left2);
}
TEST_SUITE_END();