From 4183c31d23bc79c4cd34237e730df08f37364943 Mon Sep 17 00:00:00 2001 From: Alexander McCord Date: Sat, 1 Jun 2024 15:24:35 -0700 Subject: [PATCH] Implement egraph merging of two ids. --- EqSat/include/Luau/EGraph.h | 74 ++++++++++++++++++++++++------ tests/EqSat.propositional.test.cpp | 44 ++++++++++++++++++ 2 files changed, 105 insertions(+), 13 deletions(-) diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 903a0110..eb68ee0e 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -4,10 +4,10 @@ #include "Luau/Common.h" #include "Luau/Id.h" #include "Luau/UnionFind.h" +#include "Luau/VecDeque.h" #include #include -#include #include namespace Luau::EqSat @@ -21,10 +21,17 @@ struct Analysis final { N analysis; - typename N::Data make(const EGraph& egraph, const L& enode) const + using D = typename N::Data; + + D make(const EGraph& egraph, const L& enode) const { return analysis.make(egraph, enode); } + + void join(D& a, const D& b) + { + return analysis.join(a, b); + } }; /// Each e-class is a set of e-nodes representing equivalent terms from a given language, @@ -56,8 +63,7 @@ class EGraph final std::unordered_map hashcons; private: - template - void canonicalize(T&& enode) + void canonicalize(L& enode) { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). @@ -65,27 +71,54 @@ private: id = find(id); } + bool isCanonical(const L& enode) const + { + bool canonical = true; + for (Id id : enode.operands()) + canonical &= (id == find(id)); + return canonical; + } + Id makeEClass(const L& enode) { + LUAU_ASSERT(isCanonical(enode)); + Id id = unionfind.makeSet(); + classes.insert_or_assign(id, EClass{ id, {enode}, analysis.make(*this, enode), {}, }); + + for (Id operand : enode.operands()) + get(operand).parents.push_back({enode, id}); + + hashcons.insert_or_assign(enode, id); + return id; } + // Looks up for an eclass from a given non-canonicalized `id`. + // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. + EClass& get(Id id) + { + auto it = classes.find(id); + LUAU_ASSERT(it != classes.end()); + return it->second; + } + public: Id find(Id id) const { return unionfind.find(id); } - template - std::optional lookup(T&& enode) const + std::optional lookup(const L& enode) const { + LUAU_ASSERT(isCanonical(enode)); + if (auto it = hashcons.find(enode); it != hashcons.end()) return it->second; @@ -100,19 +133,34 @@ public: return *id; Id id = makeEClass(enode); - for (Id operand : enode.operands()) - (*this)[operand].parents.push_back({enode, id}); - - hashcons.insert_or_assign(enode, id); // TODO clean = false return id; } + void merge(Id id1, Id id2) + { + id1 = find(id1); + id2 = find(id2); + if (id1 == id2) + return; + + unionfind.merge(id1, id2); + + EClass& eclass1 = get(id1); + EClass eclass2 = get(id2); + classes.erase(id2); + + analysis.join(eclass1.data, eclass2.data); + } + EClass& operator[](Id id) { - auto it = classes.find(find(id)); - LUAU_ASSERT(it != classes.end()); - return it->second; + return get(find(id)); + } + + const EClass& operator[](Id id) const + { + return const_cast(this)->get(find(id)); } }; diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp index 4988822a..e79a0e3e 100644 --- a/tests/EqSat.propositional.test.cpp +++ b/tests/EqSat.propositional.test.cpp @@ -38,9 +38,41 @@ struct ConstantFold return std::nullopt; else if (auto b = enode.get()) return b->value; + else if (auto n = enode.get()) + { + if (auto data = egraph[n->field()].data) + return !*data; + } + else if (auto a = enode.get()) + { + Data left = egraph[a->field()].data; + Data right = egraph[a->field()].data; + if (left && right) + return *left && *right; + } + else if (auto o = enode.get()) + { + Data left = egraph[o->field()].data; + Data right = egraph[o->field()].data; + if (left && right) + return *left && *right; + } + else if (auto i = enode.get()) + { + Data antecedent = egraph[i->field()].data; + Data consequent = egraph[i->field()].data; + if (antecedent && consequent) + return !*antecedent || *consequent; + } return std::nullopt; } + + void join(Data& a, const Data& b) + { + if (!a && b) + a = b; + } }; TEST_SUITE_BEGIN("EqSatPropositionalLogic"); @@ -68,4 +100,16 @@ TEST_CASE("egraph_data") CHECK(egraph[id2].data == false); } +TEST_CASE("egraph_merge") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Var{"a"}); + EqSat::Id id2 = egraph.add(Bool{true}); + egraph.merge(id1, id2); + + CHECK(egraph[id1].data == true); + CHECK(egraph[id2].data == true); +} + TEST_SUITE_END();