diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index bd4bd9d9..ba5434f5 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -8,9 +8,9 @@ namespace Luau::EqSat { -#define LUAU_EQSAT_ATOM(name, value) LUAU_EQSAT_ATOM_CUSTOM(name, #name, value) -#define LUAU_EQSAT_ATOM_CUSTOM(name, custom, value) \ - struct name : public ::Luau::EqSat::Atom \ +#define LUAU_EQSAT_ATOM(name, t) LUAU_EQSAT_ATOM_CUSTOM(name, #name, t) +#define LUAU_EQSAT_ATOM_CUSTOM(name, custom, t) \ + struct name : public ::Luau::EqSat::Atom \ { \ static constexpr const char* tag = custom; \ } @@ -19,6 +19,16 @@ template struct Atom { T value; + + bool operator==(const Atom& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const Atom& rhs) const + { + return !(*this == rhs); + } }; // `Language` is very similar to `Luau::Variant` with enough differences warranting a different type altogether. @@ -38,25 +48,116 @@ class Language const char* tag; char buffer[std::max({sizeof(Ts)...})]; + template + using WithinDomain = std::disjunction, Ts>...>; + + using FnCopy = void (*)(void*, const void*); + using FnMove = void (*)(void*, void*); + using FnDtor = void (*)(void*); + using FnPred = bool (*)(const void*, const void*); + + template + static void fnCopy(void* dst, const void* src) + { + new (dst) T(*static_cast(src)); + } + + template + static void fnMove(void* dst, void* src) + { + // static_cast is equivalent to std::move() but faster in Debug + new (dst) T(static_cast(*static_cast(src))); + } + + template + static void fnDtor(void* dst) + { + static_cast(dst)->~T(); + } + + template + static bool fnPred(const void* lhs, const void* rhs) + { + return *static_cast(lhs) == *static_cast(rhs); + } + + static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; + static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; + static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; + static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred...}; + + static constexpr int getIndexFromTag(const char* tag) + { + constexpr int N = sizeof...(Ts); + constexpr const char* is[N] = {Ts::tag...}; + + for (int i = 0; i < N; ++i) + if (is[i] == tag) + return i; + + return -1; + } + public: template - Language(T&& t) + Language(T&& t, std::enable_if_t::value>* = 0) { - using TT = std::decay_t; - static_assert(std::disjunction_v...>); - tag = T::tag; - new (&buffer) TT(std::forward(t)); + new (&buffer) std::decay_t(std::forward(t)); + } + + Language(const Language& other) + { + tag = other.tag; + tableCopy[getIndexFromTag(tag)](&buffer, &other.buffer); + } + + Language(Language&& other) + { + tag = other.tag; + tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); + } + + ~Language() + { + tableDtor[getIndexFromTag(tag)](&buffer); + } + + Language& operator=(const Language& other) + { + Language copy{other}; + *this = static_cast(copy); + return *this; + } + + Language& operator=(Language&& other) + { + if (this != &other) + { + tableDtor[getIndexFromTag(tag)](&buffer); + tag = other.tag; + tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); // nothrow + } + return *this; } template const T* get() const { - static_assert(std::disjunction_v, Ts>...>); - + static_assert(WithinDomain::value); return tag == T::tag ? reinterpret_cast(&buffer) : nullptr; } + bool operator==(const Language& rhs) const + { + return tag == rhs.tag && tablePred[getIndexFromTag(tag)](&buffer, &rhs.buffer); + } + + bool operator!=(const Language& rhs) const + { + return !(*this == rhs); + } + public: struct Hash { diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index 6977d50b..a539188f 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -3,24 +3,77 @@ #include "Luau/Language.h" -LUAU_EQSAT_ATOM(Atom1, bool); -LUAU_EQSAT_ATOM(Atom2, bool); +#include + +LUAU_EQSAT_ATOM(I32, int); +LUAU_EQSAT_ATOM(Bool, bool); +LUAU_EQSAT_ATOM(Str, std::string); using namespace Luau; -using Mini = EqSat::Language; +using Value = EqSat::Language; TEST_SUITE_BEGIN("EqSatLanguage"); -TEST_CASE("language_get_works") +TEST_CASE("atom_equality") { - Mini m{Atom1{true}}; + CHECK(I32{0} == I32{0}); + CHECK(I32{0} != I32{1}); +} - const Atom1* atom = m.get(); - REQUIRE(atom); - CHECK(atom->value); +TEST_CASE("language_get") +{ + Value v{I32{5}}; - CHECK(!m.get()); + auto i = v.get(); + REQUIRE(i); + CHECK(i->value); + + CHECK(!v.get()); +} + +TEST_CASE("language_copy_ctor") +{ + Value v1{I32{5}}; + Value v2 = v1; + + auto i1 = v1.get(); + auto i2 = v2.get(); + REQUIRE(i1); + REQUIRE(i2); + CHECK(i1->value == i2->value); +} + +TEST_CASE("language_move_ctor") +{ + Value v1{Str{"hello"}}; + { + auto s1 = v1.get(); + REQUIRE(s1); + CHECK(s1->value == "hello"); + } + + Value v2 = std::move(v1); + + auto s1 = v1.get(); + REQUIRE(s1); + CHECK(s1->value == ""); // this also tests the dtor. + + auto s2 = v2.get(); + REQUIRE(s2); + CHECK(s2->value == "hello"); +} + +TEST_CASE("language_equality") +{ + Value v1{I32{0}}; + Value v2{I32{0}}; + Value v3{I32{1}}; + Value v4{Bool{true}}; + + CHECK(v1 == v2); + CHECK(v2 != v3); + CHECK(v3 != v4); } TEST_SUITE_END();