// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include "Luau/Common.h"
#include "Luau/Id.h"
#include "Luau/Language.h"
#include "Luau/UnionFind.h"

#include <algorithm>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace Luau::EqSat
|
|
|
|
|
{
|
|
|
|
|
|
|
|
|
|
template<typename L, typename N>
|
|
|
|
|
struct EGraph;
|
|
|
|
|
|
|
|
|
|
template<typename L, typename N>
|
|
|
|
|
struct Analysis final
|
|
|
|
|
{
|
|
|
|
|
N analysis;
|
|
|
|
|
|
|
|
|
|
using D = typename N::Data;
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
Analysis() = default;
|
|
|
|
|
|
|
|
|
|
Analysis(N a)
|
|
|
|
|
: analysis(std::move(a))
|
|
|
|
|
{
|
|
|
|
|
}
|
|
|
|
|
|
2024-07-19 18:21:40 +01:00
|
|
|
|
template<typename T>
|
|
|
|
|
static D fnMake(const N& analysis, const EGraph<L, N>& egraph, const L& enode)
|
|
|
|
|
{
|
|
|
|
|
return analysis.make(egraph, *enode.template get<T>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename... Ts>
|
|
|
|
|
D make(const EGraph<L, N>& egraph, const Language<Ts...>& enode) const
|
|
|
|
|
{
|
|
|
|
|
using FnMake = D (*)(const N&, const EGraph<L, N>&, const L&);
|
|
|
|
|
static constexpr FnMake tableMake[sizeof...(Ts)] = {&fnMake<Ts>...};
|
|
|
|
|
|
|
|
|
|
return tableMake[enode.index()](analysis, egraph, enode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void join(D& a, const D& b) const
|
|
|
|
|
{
|
|
|
|
|
return analysis.join(a, b);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// Each e-class is a set of e-nodes representing equivalent terms from a given language,
|
|
|
|
|
/// and an e-node is a function symbol paired with a list of children e-classes.
|
|
|
|
|
template<typename L, typename D>
|
|
|
|
|
struct EClass final
|
|
|
|
|
{
|
|
|
|
|
Id id;
|
|
|
|
|
std::vector<L> nodes;
|
|
|
|
|
D data;
|
|
|
|
|
std::vector<std::pair<L, Id>> parents;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// See <https://arxiv.org/pdf/2004.03082>.
|
|
|
|
|
template<typename L, typename N>
|
|
|
|
|
struct EGraph final
|
|
|
|
|
{
|
2024-11-08 18:43:24 +00:00
|
|
|
|
using EClassT = EClass<L, typename N::Data>;
|
|
|
|
|
|
|
|
|
|
EGraph() = default;
|
|
|
|
|
|
|
|
|
|
explicit EGraph(N analysis)
|
|
|
|
|
: analysis(std::move(analysis))
|
|
|
|
|
{
|
|
|
|
|
}
|
|
|
|
|
|
2024-07-19 18:21:40 +01:00
|
|
|
|
Id find(Id id) const
|
|
|
|
|
{
|
|
|
|
|
return unionfind.find(id);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::optional<Id> lookup(const L& enode) const
|
|
|
|
|
{
|
|
|
|
|
LUAU_ASSERT(isCanonical(enode));
|
|
|
|
|
|
|
|
|
|
if (auto it = hashcons.find(enode); it != hashcons.end())
|
|
|
|
|
return it->second;
|
|
|
|
|
|
|
|
|
|
return std::nullopt;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Id add(L enode)
|
|
|
|
|
{
|
|
|
|
|
canonicalize(enode);
|
|
|
|
|
|
|
|
|
|
if (auto id = lookup(enode))
|
|
|
|
|
return *id;
|
|
|
|
|
|
|
|
|
|
Id id = makeEClass(enode);
|
|
|
|
|
return id;
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
// Returns true if the two IDs were not previously merged.
|
|
|
|
|
bool merge(Id id1, Id id2)
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
|
|
|
|
id1 = find(id1);
|
|
|
|
|
id2 = find(id2);
|
|
|
|
|
if (id1 == id2)
|
2024-11-08 18:43:24 +00:00
|
|
|
|
return false;
|
|
|
|
|
|
|
|
|
|
const Id mergedId = unionfind.merge(id1, id2);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
// Ensure that id1 is the Id that we keep, and id2 is the id that we drop.
|
|
|
|
|
if (mergedId == id2)
|
|
|
|
|
std::swap(id1, id2);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
EClassT& eclass1 = get(id1);
|
|
|
|
|
EClassT eclass2 = std::move(get(id2));
|
2024-07-19 18:21:40 +01:00
|
|
|
|
classes.erase(id2);
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end());
|
|
|
|
|
eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end());
|
|
|
|
|
|
|
|
|
|
std::sort(
|
|
|
|
|
eclass1.nodes.begin(),
|
|
|
|
|
eclass1.nodes.end(),
|
|
|
|
|
[](const L& left, const L& right)
|
|
|
|
|
{
|
|
|
|
|
return left.index() < right.index();
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
worklist.reserve(worklist.size() + eclass1.parents.size());
|
|
|
|
|
for (const auto& [eclass, id] : eclass1.parents)
|
|
|
|
|
worklist.push_back(id);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
|
|
|
|
|
analysis.join(eclass1.data, eclass2.data);
|
2024-11-08 18:43:24 +00:00
|
|
|
|
|
|
|
|
|
return true;
|
2024-07-19 18:21:40 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void rebuild()
|
|
|
|
|
{
|
2024-11-08 18:43:24 +00:00
|
|
|
|
std::unordered_set<Id> seen;
|
|
|
|
|
|
2024-07-19 18:21:40 +01:00
|
|
|
|
while (!worklist.empty())
|
|
|
|
|
{
|
2024-11-08 18:43:24 +00:00
|
|
|
|
Id id = worklist.back();
|
2024-07-19 18:21:40 +01:00
|
|
|
|
worklist.pop_back();
|
2024-11-08 18:43:24 +00:00
|
|
|
|
|
|
|
|
|
const bool isFresh = seen.insert(id).second;
|
|
|
|
|
if (!isFresh)
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
|
|
repair(find(id));
|
2024-07-19 18:21:40 +01:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t size() const
|
|
|
|
|
{
|
|
|
|
|
return classes.size();
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
EClassT& operator[](Id id)
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
|
|
|
|
return get(find(id));
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
const EClassT& operator[](Id id) const
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
|
|
|
|
return const_cast<EGraph*>(this)->get(find(id));
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
const std::unordered_map<Id, EClassT>& getAllClasses() const
|
|
|
|
|
{
|
|
|
|
|
return classes;
|
|
|
|
|
}
|
|
|
|
|
|
2024-07-19 18:21:40 +01:00
|
|
|
|
private:
|
|
|
|
|
Analysis<L, N> analysis;
|
|
|
|
|
|
|
|
|
|
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
|
|
|
|
|
UnionFind unionfind;
|
|
|
|
|
|
|
|
|
|
/// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same
|
|
|
|
|
/// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the
|
|
|
|
|
/// e-class 𝑀[find(𝑎)].
|
2024-11-08 18:43:24 +00:00
|
|
|
|
std::unordered_map<Id, EClassT> classes;
|
2024-07-19 18:21:40 +01:00
|
|
|
|
|
|
|
|
|
/// The hashcons 𝐻 is a map from e-nodes to e-class ids.
|
|
|
|
|
std::unordered_map<L, Id, typename L::Hash> hashcons;
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
std::vector<Id> worklist;
|
2024-07-19 18:21:40 +01:00
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void canonicalize(L& enode)
|
|
|
|
|
{
|
|
|
|
|
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
|
|
|
|
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
2025-01-10 17:13:13 +00:00
|
|
|
|
Luau::EqSat::canonicalize(
|
|
|
|
|
enode,
|
|
|
|
|
[&](Id id)
|
|
|
|
|
{
|
|
|
|
|
return find(id);
|
|
|
|
|
}
|
|
|
|
|
);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool isCanonical(const L& enode) const
|
|
|
|
|
{
|
|
|
|
|
bool canonical = true;
|
|
|
|
|
for (Id id : enode.operands())
|
|
|
|
|
canonical &= (id == find(id));
|
|
|
|
|
return canonical;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Id makeEClass(const L& enode)
|
|
|
|
|
{
|
|
|
|
|
LUAU_ASSERT(isCanonical(enode));
|
|
|
|
|
|
|
|
|
|
Id id = unionfind.makeSet();
|
|
|
|
|
|
2024-08-02 00:25:12 +01:00
|
|
|
|
classes.insert_or_assign(
|
2024-07-19 18:21:40 +01:00
|
|
|
|
id,
|
2024-11-08 18:43:24 +00:00
|
|
|
|
EClassT{
|
2024-08-02 00:25:12 +01:00
|
|
|
|
id,
|
|
|
|
|
{enode},
|
|
|
|
|
analysis.make(*this, enode),
|
|
|
|
|
{},
|
|
|
|
|
}
|
|
|
|
|
);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
|
|
|
|
|
for (Id operand : enode.operands())
|
|
|
|
|
get(operand).parents.push_back({enode, id});
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
worklist.emplace_back(id);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
hashcons.insert_or_assign(enode, id);
|
|
|
|
|
|
|
|
|
|
return id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Looks up for an eclass from a given non-canonicalized `id`.
|
|
|
|
|
// For a canonicalized eclass, use `get(find(id))` or `egraph[id]`.
|
2024-11-08 18:43:24 +00:00
|
|
|
|
EClassT& get(Id id)
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
2024-11-08 18:43:24 +00:00
|
|
|
|
LUAU_ASSERT(classes.count(id));
|
2024-07-19 18:21:40 +01:00
|
|
|
|
return classes.at(id);
|
|
|
|
|
}
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
void repair(Id id)
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
|
|
|
|
// In the egg paper, the `repair` function makes use of two loops over the `eclass.parents`
|
|
|
|
|
// by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id.
|
|
|
|
|
// And then in another loop that follows, deduplicate it.
|
|
|
|
|
//
|
|
|
|
|
// Here, we unify the two loops. I think it's equivalent?
|
|
|
|
|
|
|
|
|
|
// After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent.
|
2024-11-08 18:43:24 +00:00
|
|
|
|
std::unordered_map<L, Id, typename L::Hash> newParents;
|
|
|
|
|
|
|
|
|
|
// The eclass can be deallocated if it is merged into another eclass, so
|
|
|
|
|
// we take what we need from it and avoid retaining a pointer.
|
|
|
|
|
std::vector<std::pair<L, Id>> parents = get(id).parents;
|
|
|
|
|
for (auto& pair : parents)
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
2024-11-08 18:43:24 +00:00
|
|
|
|
L& enode = pair.first;
|
|
|
|
|
Id id = pair.second;
|
|
|
|
|
|
2024-07-19 18:21:40 +01:00
|
|
|
|
// By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id.
|
|
|
|
|
hashcons.erase(enode);
|
|
|
|
|
canonicalize(enode);
|
|
|
|
|
hashcons.insert_or_assign(enode, find(id));
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
if (auto it = newParents.find(enode); it != newParents.end())
|
2024-07-19 18:21:40 +01:00
|
|
|
|
merge(id, it->second);
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
newParents.insert_or_assign(enode, find(id));
|
2024-07-19 18:21:40 +01:00
|
|
|
|
}
|
|
|
|
|
|
2024-11-08 18:43:24 +00:00
|
|
|
|
// We reacquire the pointer because the prior loop potentially merges
|
|
|
|
|
// the eclass into another, which might move it around in memory.
|
|
|
|
|
EClassT* eclass = &get(find(id));
|
|
|
|
|
|
|
|
|
|
eclass->parents.clear();
|
|
|
|
|
|
|
|
|
|
for (const auto& [node, id] : newParents)
|
|
|
|
|
eclass->parents.emplace_back(std::move(node), std::move(id));
|
|
|
|
|
|
|
|
|
|
std::unordered_set<L, typename L::Hash> newNodes;
|
|
|
|
|
for (L node : eclass->nodes)
|
2024-07-19 18:21:40 +01:00
|
|
|
|
{
|
2024-11-08 18:43:24 +00:00
|
|
|
|
canonicalize(node);
|
|
|
|
|
newNodes.insert(std::move(node));
|
2024-07-19 18:21:40 +01:00
|
|
|
|
}
|
2024-11-08 18:43:24 +00:00
|
|
|
|
|
|
|
|
|
eclass->nodes.assign(newNodes.begin(), newNodes.end());
|
|
|
|
|
|
|
|
|
|
// FIXME: Extract into sortByTag()
|
|
|
|
|
std::sort(
|
|
|
|
|
eclass->nodes.begin(),
|
|
|
|
|
eclass->nodes.end(),
|
|
|
|
|
[](const L& left, const L& right)
|
|
|
|
|
{
|
|
|
|
|
return left.index() < right.index();
|
|
|
|
|
}
|
|
|
|
|
);
|
2024-07-19 18:21:40 +01:00
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace Luau::EqSat
|