From c8a1651c89478aaf3d13cb77ce050583a4ed49c0 Mon Sep 17 00:00:00 2001 From: Alexander McCord Date: Mon, 27 May 2024 17:05:02 -0700 Subject: [PATCH] Implement Language::operands() which means EGraph is able to canonicalize. --- EqSat/include/Luau/EGraph.h | 31 +++++++++--- EqSat/include/Luau/Language.h | 33 ++++++++++++- EqSat/include/Luau/Slice.h | 77 ++++++++++++++++++++++++++++++ EqSat/src/UnionFind.cpp | 1 + Sources.cmake | 2 + tests/EqSat.language.test.cpp | 15 ++++++ tests/EqSat.propositional.test.cpp | 38 ++++++++++++++- tests/EqSat.slice.test.cpp | 58 ++++++++++++++++++++++ 8 files changed, 246 insertions(+), 9 deletions(-) create mode 100644 EqSat/include/Luau/Slice.h create mode 100644 tests/EqSat.slice.test.cpp diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 21932f9d..d2099f3c 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -4,6 +4,7 @@ #include "Luau/Id.h" #include "Luau/UnionFind.h" +#include #include #include #include @@ -39,13 +40,29 @@ struct EGraph final return unionfind.find(id); } - // Per the egg paper, definition 2.2 (Canonicalization): - // - // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where - // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). - // - // Doing so requires sketching out `Luau::EqSat::Language` which - // I want to do at a later time for the time being. Will revisit. + // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where + // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). + std::optional lookup(L enode) const + { + for (Id& id : enode.operands()) + id = find(id); + + if (auto it = hashcons.find(enode); it != hashcons.end()) + return it->second; + + return std::nullopt; + } + + // TODO: `add`. For now, we call it shoveItIn so it's obvious it's just for testing. 
+ Id shoveItIn(L enode) + { + if (auto id = lookup(enode)) + return *id; + + Id id{hashcons.size()}; + hashcons.insert_or_assign(enode, id); + return id; + } private: /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index cdd60648..126b67bc 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Id.h" +#include "Luau/Slice.h" #include #include @@ -63,6 +64,11 @@ struct Atom { T value; + Slice operands() + { + return {}; + } + bool operator==(const Atom& rhs) const { return value == rhs.value; @@ -123,8 +129,13 @@ public: { } + Slice operands() + { + return Slice{array}; + } + template - Id field() const + const Id& field() const { static_assert(std::disjunction_v, Fields>...>); return array[getIndex()]; @@ -175,6 +186,7 @@ private: using FnDtor = void (*)(void*); using FnPred = bool (*)(const void*, const void*); using FnHash = size_t (*)(const void*); + using FnOper = Slice (*)(void*); template static void fnCopy(void* dst, const void* src) @@ -206,11 +218,18 @@ private: return typename T::Hash{}(*static_cast(buffer)); } + template + static Slice fnOper(void* buffer) + { + return static_cast(buffer)->operands(); + } + static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred...}; static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash...}; + static constexpr FnOper tableOper[sizeof...(Ts)] = {&fnOper...}; static constexpr int getIndexFromTag(const char* tag) { @@ -267,6 +286,18 @@ public: return *this; } + int index() const + { + return getIndexFromTag(tag); + } + + /// You should never call this function with the intention of mutating the `Id`. 
+ /// Reading is ok, but you should also never assume that these `Id`s are stable. + Slice operands() + { + return tableOper[getIndexFromTag(tag)](&buffer); + } + template const T* get() const { diff --git a/EqSat/include/Luau/Slice.h b/EqSat/include/Luau/Slice.h new file mode 100644 index 00000000..a43642b2 --- /dev/null +++ b/EqSat/include/Luau/Slice.h @@ -0,0 +1,77 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau::EqSat +{ + +template +class Slice +{ + T* _data; + size_t _size; + +public: + Slice() + : _data(nullptr) + , _size(0) + { + } + + /// Use this constructor if you have a dynamically sized vector. + /// The slice is valid for as long as the backing vector has not moved + /// elsewhere in memory. + /// + /// In general, a slice should never be used from vectors except for + /// any vectors whose size is statically unknown, but remains fixed + /// upon the construction of such a slice over a vector. + Slice(T* first, size_t last) + : _data(first) + , _size(last) + { + } + + template + explicit Slice(std::array& array) + : _data(array.data()) + , _size(array.size()) + { + } + + T* data() const + { + return _data; + } + + size_t size() const + { + return _size; + } + + bool empty() const + { + return _size == 0; + } + + T& operator[](size_t i) const + { + LUAU_ASSERT(i < _size); + return _data[i]; + } + +public: + T* begin() const + { + return _data; + } + + T* end() const + { + return _data + _size; + } +}; + +} diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 9b991422..04d9ba74 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -17,6 +17,7 @@ Id UnionFind::find(Id id) const { LUAU_ASSERT(size_t(id) < parents.size()); + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)]) id = parents[size_t(id)]; diff --git a/Sources.cmake b/Sources.cmake index 89da27ca..902002f2 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -300,6 +300,7 @@ target_sources(Luau.EqSat PRIVATE EqSat/include/Luau/EGraph.h EqSat/include/Luau/Id.h EqSat/include/Luau/Language.h + EqSat/include/Luau/Slice.h EqSat/include/Luau/UnionFind.h EqSat/src/Id.cpp @@ -431,6 +432,7 @@ if(TARGET Luau.UnitTest) tests/Differ.test.cpp tests/EqSat.language.test.cpp tests/EqSat.propositional.test.cpp + tests/EqSat.slice.test.cpp tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index da90ff6b..9d041364 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -127,4 +127,19 @@ TEST_CASE("node_field") CHECK(right != left2); } +TEST_CASE("language_operands") +{ + Value v1{I32{0}}; + CHECK(v1.operands().empty()); + + Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}}; + const Add* add = v2.get(); + REQUIRE(add); + + EqSat::Slice actual = v2.operands(); + CHECK(actual.size() == 2); + CHECK(actual[0] == add->field()); + CHECK(actual[1] == add->field()); +} + TEST_SUITE_END(); diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp index 0cbb2e25..81248a7c 100644 --- a/tests/EqSat.propositional.test.cpp +++ b/tests/EqSat.propositional.test.cpp @@ -1,8 +1,44 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include -// Var, Bool, And, Or, Not, Implies +#include "Luau/EGraph.h" +#include "Luau/Language.h" + +#include + +LUAU_EQSAT_ATOM(Var, std::string); +LUAU_EQSAT_ATOM(Bool, bool); + +LUAU_EQSAT_FIELD(Negated); +LUAU_EQSAT_UNARY_NODE(Not, Negated); + +LUAU_EQSAT_FIELD(Left); +LUAU_EQSAT_FIELD(Right); +LUAU_EQSAT_BINARY_NODE(And, Left, Right); +LUAU_EQSAT_BINARY_NODE(Or, Left, Right); + +LUAU_EQSAT_FIELD(Antecedent); +LUAU_EQSAT_FIELD(Consequent); 
+LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent); + +using namespace Luau; + +using PropositionalLogic = EqSat::Language; + +struct ConstantFold +{ + using Data = std::optional; +}; + +using EGraph = EqSat::EGraph; TEST_SUITE_BEGIN("EqSatPropositionalLogic"); +TEST_CASE("egraph_hashconsing") +{ + EGraph egraph; + + CHECK(egraph.shoveItIn(Bool{true}) == egraph.shoveItIn(Bool{true})); +} + TEST_SUITE_END(); diff --git a/tests/EqSat.slice.test.cpp b/tests/EqSat.slice.test.cpp new file mode 100644 index 00000000..a0852e83 --- /dev/null +++ b/tests/EqSat.slice.test.cpp @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/Slice.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("EqSatLanguage"); + +TEST_CASE("slice_is_a_view_over_array") +{ + std::array a{1, 2, 3, 4, 5, 6, 7, 8}; + + EqSat::Slice slice{a}; + + CHECK(slice.data() == a.data()); + CHECK(slice.size() == a.size()); + + for (size_t i = 0; i < a.size(); ++i) + { + CHECK(slice[i] == a[i]); + CHECK(&slice[i] == &a[i]); + } +} + +TEST_CASE("slice_is_a_view_over_vector") +{ + std::vector vector{1, 2, 3, 4, 5, 6, 7, 8}; + + EqSat::Slice slice{vector.data(), vector.size()}; + + CHECK(slice.data() == vector.data()); + CHECK(slice.size() == vector.size()); + + for (size_t i = 0; i < vector.size(); ++i) + { + CHECK(slice[i] == vector[i]); + CHECK(&slice[i] == &vector[i]); + } +} + +TEST_CASE("mutate_via_slice") +{ + std::array a{1, 2}; + CHECK(a[0] == 1); + CHECK(a[1] == 2); + + EqSat::Slice slice{a}; + slice[0] = 42; + slice[1] = 37; + + CHECK(a[0] == 42); + CHECK(a[1] == 37); +} + +TEST_SUITE_END();