// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/Common.h" #include "Luau/Id.h" #include "Luau/Language.h" #include "Luau/UnionFind.h" #include <optional> #include <unordered_map> #include <vector> namespace Luau::EqSat { template<typename L, typename N> struct EGraph; template<typename L, typename N> struct Analysis final { N analysis; using D = typename N::Data; Analysis() = default; Analysis(N a) : analysis(std::move(a)) { } template<typename T> static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode) { return analysis.make(egraph, *enode.template get<T>()); } template<typename... Ts> D make(const EGraph<L, N>& egraph, const Language<Ts...>& enode) const { using FnMake = D (*)(const N&, const EGraph<L, N>&, const L&); static constexpr FnMake tableMake[sizeof...(Ts)] = {&fnMake<Ts>...}; return tableMake[enode.index()](analysis, egraph, enode); } void join(D& a, const D& b) const { return analysis.join(a, b); } }; template<typename L> struct Node { L node; bool boring = false; struct Hash { size_t operator()(const Node& node) const { return typename L::Hash{}(node.node); } }; }; template<typename L> struct NodeIterator { private: using iterator = std::vector<Node<L>>; iterator iter; public: L& operator*() { return iter->node; } const L& operator*() const { return iter->node; } iterator& operator++() { ++iter; return *this; } iterator operator++(int) { iterator copy = *this; ++*this; return copy; } bool operator==(const iterator& rhs) const { return iter == rhs.iter; } bool operator!=(const iterator& rhs) const { return iter != rhs.iter; } }; /// Each e-class is a set of e-nodes representing equivalent terms from a given language, /// and an e-node is a function symbol paired with a list of children e-classes. template<typename L, typename D> struct EClass final { Id id; std::vector<Node<L>> nodes; D data; std::vector<std::pair<L, Id>> parents; }; /// See <https://arxiv.org/pdf/2004.03082>. template<typename L, typename N> struct EGraph final { using EClassT = EClass<L, typename N::Data>; EGraph() = default; explicit EGraph(N analysis) : analysis(std::move(analysis)) { } Id find(Id id) const { return unionfind.find(id); } std::optional<Id> lookup(const L& enode) const { LUAU_ASSERT(isCanonical(enode)); if (auto it = hashcons.find(enode); it != hashcons.end()) return it->second; return std::nullopt; } Id add(L enode) { canonicalize(enode); if (auto id = lookup(enode)) return *id; Id id = makeEClass(enode); return id; } // Returns true if the two IDs were not previously merged. bool merge(Id id1, Id id2) { id1 = find(id1); id2 = find(id2); if (id1 == id2) return false; const Id mergedId = unionfind.merge(id1, id2); // Ensure that id1 is the Id that we keep, and id2 is the id that we drop. if (mergedId == id2) std::swap(id1, id2); EClassT& eclass1 = get(id1); EClassT eclass2 = std::move(get(id2)); classes.erase(id2); eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end()); eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end()); std::sort( eclass1.nodes.begin(), eclass1.nodes.end(), [](const Node<L>& left, const Node<L>& right) { return left.node.index() < right.node.index(); } ); worklist.reserve(worklist.size() + eclass1.parents.size()); for (const auto& [eclass, id] : eclass1.parents) worklist.push_back(id); analysis.join(eclass1.data, eclass2.data); return true; } void rebuild() { std::unordered_set<Id> seen; while (!worklist.empty()) { Id id = worklist.back(); worklist.pop_back(); const bool isFresh = seen.insert(id).second; if (!isFresh) continue; repair(find(id)); } } size_t size() const { return classes.size(); } EClassT& operator[](Id id) { return get(find(id)); } const EClassT& operator[](Id id) const { return const_cast<EGraph*>(this)->get(find(id)); } const std::unordered_map<Id, EClassT>& getAllClasses() const { return classes; } void markBoring(Id id, size_t index) { get(id).nodes[index].boring = true; } private: Analysis<L, N> analysis; /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. UnionFind unionfind; /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the /// e-class 𝑀[find(𝑎)]. std::unordered_map<Id, EClassT> classes; /// The hashcons 𝐻 is a map from e-nodes to e-class ids. std::unordered_map<L, Id, typename L::Hash> hashcons; std::vector<Id> worklist; private: void canonicalize(L& enode) { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). Luau::EqSat::canonicalize( enode, [&](Id id) { return find(id); } ); } bool isCanonical(const L& enode) const { bool canonical = true; for (Id id : enode.operands()) canonical &= (id == find(id)); return canonical; } Id makeEClass(const L& enode) { LUAU_ASSERT(isCanonical(enode)); Id id = unionfind.makeSet(); classes.insert_or_assign( id, EClassT{ id, {Node<L>{enode, false}}, analysis.make(*this, enode), {}, } ); for (Id operand : enode.operands()) get(operand).parents.push_back({enode, id}); worklist.emplace_back(id); hashcons.insert_or_assign(enode, id); return id; } // Looks up for an eclass from a given non-canonicalized `id`. // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. EClassT& get(Id id) { LUAU_ASSERT(classes.count(id)); return classes.at(id); } void repair(Id id) { // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. // And then in another loop that follows, deduplicate it. // // Here, we unify the two loops. I think it's equivalent? // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. std::unordered_map<L, Id, typename L::Hash> newParents; // The eclass can be deallocated if it is merged into another eclass, so // we take what we need from it and avoid retaining a pointer. std::vector<std::pair<L, Id>> parents = get(id).parents; for (auto& pair : parents) { L& parentNode = pair.first; Id parentId = pair.second; // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. hashcons.erase(parentNode); canonicalize(parentNode); hashcons.insert_or_assign(parentNode, find(parentId)); if (auto it = newParents.find(parentNode); it != newParents.end()) merge(parentId, it->second); newParents.insert_or_assign(parentNode, find(parentId)); } // We reacquire the pointer because the prior loop potentially merges // the eclass into another, which might move it around in memory. EClassT* eclass = &get(find(id)); eclass->parents.clear(); for (const auto& [node, id] : newParents) eclass->parents.emplace_back(std::move(node), std::move(id)); std::unordered_map<L, bool, typename L::Hash> newNodes; for (Node<L> node : eclass->nodes) { canonicalize(node.node); bool& b = newNodes[std::move(node.node)]; b = b || node.boring; } eclass->nodes.clear(); while (!newNodes.empty()) { auto n = newNodes.extract(newNodes.begin()); eclass->nodes.push_back(Node<L>{n.key(), n.mapped()}); } // FIXME: Extract into sortByTag() std::sort( eclass->nodes.begin(), eclass->nodes.end(), [](const Node<L>& left, const Node<L>& right) { return left.node.index() < right.node.index(); } ); } }; } // namespace Luau::EqSat