Expand Language with variant stuff.

This commit is contained in:
Alexander McCord 2024-05-26 15:21:08 -07:00
parent d6f553a794
commit 38ccd662f2
2 changed files with 173 additions and 19 deletions

View file

@ -8,9 +8,9 @@
namespace Luau::EqSat
{
#define LUAU_EQSAT_ATOM(name, value) LUAU_EQSAT_ATOM_CUSTOM(name, #name, value)
#define LUAU_EQSAT_ATOM_CUSTOM(name, custom, value) \
struct name : public ::Luau::EqSat::Atom<name, value> \
#define LUAU_EQSAT_ATOM(name, t) LUAU_EQSAT_ATOM_CUSTOM(name, #name, t)
#define LUAU_EQSAT_ATOM_CUSTOM(name, custom, t) \
struct name : public ::Luau::EqSat::Atom<name, t> \
{ \
static constexpr const char* tag = custom; \
}
@ -19,6 +19,16 @@ template<typename B, typename T>
struct Atom
{
T value;
bool operator==(const Atom& rhs) const
{
return value == rhs.value;
}
bool operator!=(const Atom& rhs) const
{
return !(*this == rhs);
}
};
// `Language` is very similar to `Luau::Variant` with enough differences warranting a different type altogether.
@ -38,25 +48,116 @@ class Language
const char* tag;
char buffer[std::max({sizeof(Ts)...})];
template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
using FnCopy = void (*)(void*, const void*);
using FnMove = void (*)(void*, void*);
using FnDtor = void (*)(void*);
using FnPred = bool (*)(const void*, const void*);
template<typename T>
static void fnCopy(void* dst, const void* src)
{
new (dst) T(*static_cast<const T*>(src));
}
template<typename T>
static void fnMove(void* dst, void* src)
{
// static_cast<T&&> is equivalent to std::move() but faster in Debug
new (dst) T(static_cast<T&&>(*static_cast<T*>(src)));
}
template<typename T>
static void fnDtor(void* dst)
{
static_cast<T*>(dst)->~T();
}
template<typename T>
static bool fnPred(const void* lhs, const void* rhs)
{
return *static_cast<const T*>(lhs) == *static_cast<const T*>(rhs);
}
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};
static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred<Ts>...};
static constexpr int getIndexFromTag(const char* tag)
{
constexpr int N = sizeof...(Ts);
constexpr const char* is[N] = {Ts::tag...};
for (int i = 0; i < N; ++i)
if (is[i] == tag)
return i;
return -1;
}
public:
template<typename T>
Language(T&& t)
Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0)
{
using TT = std::decay_t<T>;
static_assert(std::disjunction_v<std::is_same<TT, Ts>...>);
tag = T::tag;
new (&buffer) TT(std::forward<T>(t));
new (&buffer) std::decay_t<T>(std::forward<T>(t));
}
Language(const Language& other)
{
tag = other.tag;
tableCopy[getIndexFromTag(tag)](&buffer, &other.buffer);
}
Language(Language&& other)
{
tag = other.tag;
tableMove[getIndexFromTag(tag)](&buffer, &other.buffer);
}
~Language()
{
tableDtor[getIndexFromTag(tag)](&buffer);
}
Language& operator=(const Language& other)
{
Language copy{other};
*this = static_cast<Language&&>(copy);
return *this;
}
Language& operator=(Language&& other)
{
if (this != &other)
{
tableDtor[getIndexFromTag(tag)](&buffer);
tag = other.tag;
tableMove[getIndexFromTag(tag)](&buffer, &other.buffer); // nothrow
}
return *this;
}
template<typename T>
const T* get() const
{
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Ts>...>);
static_assert(WithinDomain<T>::value);
return tag == T::tag ? reinterpret_cast<const T*>(&buffer) : nullptr;
}
bool operator==(const Language& rhs) const
{
return tag == rhs.tag && tablePred[getIndexFromTag(tag)](&buffer, &rhs.buffer);
}
bool operator!=(const Language& rhs) const
{
return !(*this == rhs);
}
public:
struct Hash
{

View file

@ -3,24 +3,77 @@
#include "Luau/Language.h"
LUAU_EQSAT_ATOM(Atom1, bool);
LUAU_EQSAT_ATOM(Atom2, bool);
#include <string>
LUAU_EQSAT_ATOM(I32, int);
LUAU_EQSAT_ATOM(Bool, bool);
LUAU_EQSAT_ATOM(Str, std::string);
using namespace Luau;
using Mini = EqSat::Language<Atom1, Atom2>;
using Value = EqSat::Language<I32, Bool, Str>;
TEST_SUITE_BEGIN("EqSatLanguage");
TEST_CASE("language_get_works")
TEST_CASE("atom_equality")
{
Mini m{Atom1{true}};
CHECK(I32{0} == I32{0});
CHECK(I32{0} != I32{1});
}
const Atom1* atom = m.get<Atom1>();
REQUIRE(atom);
CHECK(atom->value);
TEST_CASE("language_get")
{
Value v{I32{5}};
CHECK(!m.get<Atom2>());
auto i = v.get<I32>();
REQUIRE(i);
CHECK(i->value);
CHECK(!v.get<Bool>());
}
TEST_CASE("language_copy_ctor")
{
Value v1{I32{5}};
Value v2 = v1;
auto i1 = v1.get<I32>();
auto i2 = v2.get<I32>();
REQUIRE(i1);
REQUIRE(i2);
CHECK(i1->value == i2->value);
}
TEST_CASE("language_move_ctor")
{
Value v1{Str{"hello"}};
{
auto s1 = v1.get<Str>();
REQUIRE(s1);
CHECK(s1->value == "hello");
}
Value v2 = std::move(v1);
auto s1 = v1.get<Str>();
REQUIRE(s1);
CHECK(s1->value == ""); // this also tests the dtor.
auto s2 = v2.get<Str>();
REQUIRE(s2);
CHECK(s2->value == "hello");
}
TEST_CASE("language_equality")
{
Value v1{I32{0}};
Value v2{I32{0}};
Value v3{I32{1}};
Value v4{Bool{true}};
CHECK(v1 == v2);
CHECK(v2 != v3);
CHECK(v3 != v4);
}
TEST_SUITE_END();