From 6e996d1899e12869f8283a1e672fbbd0dbe6dbcc Mon Sep 17 00:00:00 2001 From: Alexander McCord Date: Mon, 27 May 2024 22:21:04 -0700 Subject: [PATCH] Start to implement some stuff according to pg 9. --- EqSat/include/Luau/EGraph.h | 72 ++++++++++++++++++++---------- tests/EqSat.propositional.test.cpp | 11 +++++ 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index efe21aa3..6edcc6a7 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Common.h" #include "Luau/Id.h" #include "Luau/UnionFind.h" @@ -23,15 +24,44 @@ struct EClass final std::vector> parents; }; -/// In Definition 2.1, an EGraph is composed with a tuple (U, M, H) where -/// - U: [`EGraph::unionfind`] -/// - M: [`EGraph::classes`] -/// - H: [`EGraph::hashcons`] -/// /// See . template -struct EGraph final +class EGraph final { + /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. + UnionFind unionfind; + + /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same + /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the + /// e-class 𝑀[find(𝑎)]. + std::unordered_map> classes; + + /// The hashcons 𝐻 is a map from e-nodes to e-class ids. + std::unordered_map hashcons; + +private: + template + void canonicalize(T&& enode) + { + // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where + // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). + for (Id& id : enode.operands()) + id = find(id); + } + + Id makeEClass(const L& enode) + { + Id id = unionfind.makeSet(); + classes.insert_or_assign(id, EClass{ + id, + {enode}, + {}, // TODO: analysis make + {}, + }); + return id; + } + +public: // TODO: static_assert L <: Language // TODO: static_assert N <: Analysis @@ -43,11 +73,6 @@ struct EGraph final template std::optional lookup(T&& enode) const { - // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where - // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). - for (Id& id : enode->operands()) - id = find(id); - if (auto it = hashcons.find(enode); it != hashcons.end()) return it->second; @@ -56,25 +81,26 @@ struct EGraph final Id add(L enode) { + canonicalize(enode); + if (auto id = lookup(enode)) return *id; - Id id{hashcons.size()}; + Id id = makeEClass(enode); + for (Id operand : enode.operands()) + (*this)[operand].parents.push_back({enode, id}); + hashcons.insert_or_assign(enode, id); + // TODO clean = false return id; } -private: - /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. - UnionFind unionfind; - - /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same - /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the - /// e-class 𝑀[find(𝑎)]. - std::unordered_map> classes; - - /// The hashcons 𝐻 is a map from e-nodes to e-class ids. - std::unordered_map hashcons; + EClass& operator[](Id id) + { + auto it = classes.find(find(id)); + LUAU_ASSERT(it != classes.end()); + return it->second; + } }; } // namespace Luau::EqSat diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp index 723fd104..940918a2 100644 --- a/tests/EqSat.propositional.test.cpp +++ b/tests/EqSat.propositional.test.cpp @@ -47,4 +47,15 @@ TEST_CASE("egraph_hashconsing") CHECK(id2 != id3); } +TEST_CASE("egraph_data") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{false}); + + CHECK(egraph[id1].data == std::nullopt); // TODO: true + CHECK(egraph[id2].data == std::nullopt); // TODO: false +} + TEST_SUITE_END();