From 26f1d18c81b6def8cba21fd377a44502f567f406 Mon Sep 17 00:00:00 2001 From: Alexander McCord Date: Sat, 1 Jun 2024 16:33:51 -0700 Subject: [PATCH] Implement rebuild and show it works. --- EqSat/include/Luau/EGraph.h | 48 +++++++++++++++++-- EqSat/include/Luau/Language.h | 3 +- tests/EqSat.propositional.test.cpp | 77 ++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 4 deletions(-) diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index eb68ee0e..43e6c69e 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -62,6 +62,8 @@ class EGraph final /// The hashcons 𝐻 is a map from e-nodes to e-class ids. std::unordered_map hashcons; + VecDeque> worklist; + private: void canonicalize(L& enode) { @@ -95,6 +97,7 @@ private: for (Id operand : enode.operands()) get(operand).parents.push_back({enode, id}); + worklist.push_back({enode, id}); hashcons.insert_or_assign(enode, id); return id; @@ -104,9 +107,29 @@ private: // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. EClass& get(Id id) { - auto it = classes.find(id); - LUAU_ASSERT(it != classes.end()); - return it->second; + return classes.at(id); + } + + void repair(EClass& eclass) + { + // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. + std::unordered_map map; + for (auto& [enode, id] : eclass.parents) + { + // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. + hashcons.erase(enode); + canonicalize(enode); + hashcons.insert_or_assign(enode, find(id)); + + if (auto it = map.find(enode); it != map.end()) + merge(id, it->second); + + map.insert_or_assign(enode, find(id)); + } + + eclass.parents.clear(); + for (auto& [enode, id] : map) + eclass.parents.push_back({std::move(enode), id}); } public: @@ -150,9 +173,28 @@ public: EClass eclass2 = get(id2); classes.erase(id2); + worklist.reserve(worklist.size() + eclass2.parents.size()); + for (auto [enode, id] : eclass2.parents) + worklist.push_back({std::move(enode), id}); + analysis.join(eclass1.data, eclass2.data); } + void rebuild() + { + while (!worklist.empty()) + { + auto [enode, id] = worklist.back(); + worklist.pop_back(); + repair(get(find(id))); + } + } + + size_t size() const + { + return classes.size(); + } + EClass& operator[](Id id) { return get(find(id)); diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index 36eaa89d..d83138a2 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -23,7 +23,8 @@ struct LanguageHash }; template -std::size_t languageHash(const T& lang) { +std::size_t languageHash(const T& lang) +{ return LanguageHash{}(lang); } diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp index e79a0e3e..4358282c 100644 --- a/tests/EqSat.propositional.test.cpp +++ b/tests/EqSat.propositional.test.cpp @@ -112,4 +112,81 @@ TEST_CASE("egraph_merge") CHECK(egraph[id2].data == true); } +TEST_CASE("const_fold_true_and_true") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{true}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == true); +} + +TEST_CASE("const_fold_true_and_false") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{false}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == false); +} + +TEST_CASE("const_fold_false_and_false") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{false}); + EqSat::Id id2 = egraph.add(Bool{false}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == false); +} + +TEST_CASE("implications") +{ + EGraph egraph; + + EqSat::Id t = egraph.add(Bool{true}); + EqSat::Id f = egraph.add(Bool{false}); + + EqSat::Id a = egraph.add(Implies{t, t}); // true + EqSat::Id b = egraph.add(Implies{t, f}); // false + EqSat::Id c = egraph.add(Implies{f, t}); // true + EqSat::Id d = egraph.add(Implies{f, f}); // true + + CHECK(egraph[a].data == true); + CHECK(egraph[b].data == false); + CHECK(egraph[c].data == true); + CHECK(egraph[d].data == true); +} + +TEST_CASE("merge_x_and_y") +{ + EGraph egraph; + + EqSat::Id x = egraph.add(Var{"x"}); + EqSat::Id y = egraph.add(Var{"y"}); + + EqSat::Id a = egraph.add(Var{"a"}); + EqSat::Id ax = egraph.add(And{a, x}); + EqSat::Id ay = egraph.add(And{a, y}); + + egraph.merge(x, y); // [x y] [ax] [ay] [a] + CHECK_EQ(egraph.size(), 4); + CHECK_EQ(egraph.find(x), egraph.find(y)); + CHECK_NE(egraph.find(ax), egraph.find(ay)); + CHECK_NE(egraph.find(a), egraph.find(x)); + CHECK_NE(egraph.find(a), egraph.find(y)); + + egraph.rebuild(); // [x y] [ax ay] [a] + CHECK_EQ(egraph.size(), 3); + CHECK_EQ(egraph.find(x), egraph.find(y)); + CHECK_EQ(egraph.find(ax), egraph.find(ay)); + CHECK_NE(egraph.find(a), egraph.find(x)); + CHECK_NE(egraph.find(a), egraph.find(y)); +} + TEST_SUITE_END();