diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index e887c532..0d967a6c 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -1,6 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Id.h" + +#include #include #include #include @@ -8,11 +11,39 @@ namespace Luau::EqSat { -#define LUAU_EQSAT_ATOM(name, t) LUAU_EQSAT_ATOM_CUSTOM(name, #name, t) -#define LUAU_EQSAT_ATOM_CUSTOM(name, custom, t) \ +#define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ { \ - static constexpr const char* tag = custom; \ + static constexpr const char* tag = #name; \ + } + +#define LUAU_EQSAT_FIELD(name) \ + struct name : public ::Luau::EqSat::Field \ + { \ + } + +#define LUAU_EQSAT_UNARY_NODE(name, field) \ + struct name : public ::Luau::EqSat::Node \ + { \ + static constexpr const char* tag = #name; \ + using Base::Node; \ + } + +#define LUAU_EQSAT_BINARY_NODE(name, field1, field2) \ + struct name : public ::Luau::EqSat::Node \ + { \ + static constexpr const char* tag = #name; \ + using Base::Node; \ + } + +#define DERIVE_EQ(name, field) \ + bool operator==(const name& rhs) const \ + { \ + return field == rhs.field; \ + } \ + bool operator!=(const name& rhs) const \ + { \ + return !(*this == rhs); \ } template @@ -20,15 +51,7 @@ struct Atom { T value; - bool operator==(const Atom& rhs) const - { - return value == rhs.value; - } - - bool operator!=(const Atom& rhs) const - { - return !(*this == rhs); - } + DERIVE_EQ(Atom, value); struct Hash { @@ -39,6 +62,84 @@ struct Atom }; }; +/// Empty base class just for static_asserts. +struct FieldBase +{ +}; + +template +struct Field : FieldBase +{ + Id id; + + Field(Id id) + : id(id) + { + } + + DERIVE_EQ(Field, id); + + struct Hash + { + size_t operator()(const Field& field) const + { + return std::hash{}(field.id); + } + }; +}; + +template +class Node +{ + static_assert(std::conjunction...>::value); + + std::array array; + + template + static constexpr int getIndex() + { + using TT = std::decay_t; + + constexpr int N = sizeof...(Fields); + constexpr bool is[N] = {std::is_same_v...}; + + for (int i = 0; i < N; ++i) + if (is[i]) + return i; + + return -1; + } + +public: + using Base = Node; + + template + Node(Args&&... args) + : array{std::forward(args)...} + { + } + + template + Id field() const + { + static_assert(std::is_base_of::value); + static_assert(getIndex() >= 0); + return array[getIndex()]; + } + + DERIVE_EQ(Node, array); + + struct Hash + { + size_t operator()(const Node& node) const + { + return 0; + } + }; +}; + +#undef DERIVE_EQ + // `Language` is very similar to `Luau::Variant` with enough differences warranting a different type altogether. // // Firstly, where `Luau::Variant` uses an `int` to decide which type the variant currently holds, we use diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index e5f5d772..da90ff6b 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include +#include "Luau/Id.h" #include "Luau/Language.h" #include @@ -10,9 +11,14 @@ LUAU_EQSAT_ATOM(I32, int); LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Str, std::string); +LUAU_EQSAT_FIELD(Left); +LUAU_EQSAT_FIELD(Right); + +LUAU_EQSAT_BINARY_NODE(Add, Left, Right); + using namespace Luau; -using Value = EqSat::Language; +using Value = EqSat::Language; TEST_SUITE_BEGIN("EqSatLanguage"); @@ -22,6 +28,12 @@ TEST_CASE("atom_equality") CHECK(I32{0} != I32{1}); } +TEST_CASE("node_equality") +{ + CHECK(Add{EqSat::Id{0}, EqSat::Id{0}} == Add{EqSat::Id{0}, EqSat::Id{0}}); + CHECK(Add{EqSat::Id{1}, EqSat::Id{0}} != Add{EqSat::Id{0}, EqSat::Id{0}}); +} + TEST_CASE("language_get") { Value v{I32{5}}; @@ -71,10 +83,12 @@ TEST_CASE("language_equality") Value v2{I32{0}}; Value v3{I32{1}}; Value v4{Bool{true}}; + Value v5{Add{EqSat::Id{0}, EqSat::Id{1}}}; CHECK(v1 == v2); CHECK(v2 != v3); CHECK(v3 != v4); + CHECK(v4 != v5); } TEST_CASE("language_is_mappable") @@ -84,14 +98,33 @@ TEST_CASE("language_is_mappable") Value v1{I32{5}}; Value v2{I32{5}}; Value v3{Bool{true}}; + Value v4{Add{EqSat::Id{0}, EqSat::Id{1}}}; map[v1] = 1; map[v2] = 2; map[v3] = 42; + map[v4] = 37; CHECK(map[v1] == 2); CHECK(map[v2] == 2); CHECK(map[v3] == 42); + CHECK(map[v4] == 37); +} + +TEST_CASE("node_field") +{ + EqSat::Id left{0}; + EqSat::Id right{1}; + + Add add{left, right}; + + EqSat::Id left2 = add.field(); + EqSat::Id right2 = add.field(); + + CHECK(left == left2); + CHECK(left != right2); + CHECK(right == right2); + CHECK(right != left2); } TEST_SUITE_END();