mirror of
https://github.com/luau-lang/luau.git
synced 2025-04-18 10:53:45 +01:00
Implement egraph merging of two ids.
This commit is contained in:
parent
15167ad497
commit
4183c31d23
2 changed files with 105 additions and 13 deletions
|
@ -4,10 +4,10 @@
|
||||||
#include "Luau/Common.h"
|
#include "Luau/Common.h"
|
||||||
#include "Luau/Id.h"
|
#include "Luau/Id.h"
|
||||||
#include "Luau/UnionFind.h"
|
#include "Luau/UnionFind.h"
|
||||||
|
#include "Luau/VecDeque.h"
|
||||||
|
|
||||||
#include <optional>
|
#include <optional>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace Luau::EqSat
|
namespace Luau::EqSat
|
||||||
|
@ -21,10 +21,17 @@ struct Analysis final
|
||||||
{
|
{
|
||||||
N analysis;
|
N analysis;
|
||||||
|
|
||||||
typename N::Data make(const EGraph<L, N>& egraph, const L& enode) const
|
using D = typename N::Data;
|
||||||
|
|
||||||
|
D make(const EGraph<L, N>& egraph, const L& enode) const
|
||||||
{
|
{
|
||||||
return analysis.make(egraph, enode);
|
return analysis.make(egraph, enode);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void join(D& a, const D& b)
|
||||||
|
{
|
||||||
|
return analysis.join(a, b);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Each e-class is a set of e-nodes representing equivalent terms from a given language,
|
/// Each e-class is a set of e-nodes representing equivalent terms from a given language,
|
||||||
|
@ -56,8 +63,7 @@ class EGraph final
|
||||||
std::unordered_map<L, Id, typename L::Hash> hashcons;
|
std::unordered_map<L, Id, typename L::Hash> hashcons;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template<typename T>
|
void canonicalize(L& enode)
|
||||||
void canonicalize(T&& enode)
|
|
||||||
{
|
{
|
||||||
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
||||||
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
||||||
|
@ -65,27 +71,54 @@ private:
|
||||||
id = find(id);
|
id = find(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool isCanonical(const L& enode) const
|
||||||
|
{
|
||||||
|
bool canonical = true;
|
||||||
|
for (Id id : enode.operands())
|
||||||
|
canonical &= (id == find(id));
|
||||||
|
return canonical;
|
||||||
|
}
|
||||||
|
|
||||||
Id makeEClass(const L& enode)
|
Id makeEClass(const L& enode)
|
||||||
{
|
{
|
||||||
|
LUAU_ASSERT(isCanonical(enode));
|
||||||
|
|
||||||
Id id = unionfind.makeSet();
|
Id id = unionfind.makeSet();
|
||||||
|
|
||||||
classes.insert_or_assign(id, EClass<L, typename N::Data>{
|
classes.insert_or_assign(id, EClass<L, typename N::Data>{
|
||||||
id,
|
id,
|
||||||
{enode},
|
{enode},
|
||||||
analysis.make(*this, enode),
|
analysis.make(*this, enode),
|
||||||
{},
|
{},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
for (Id operand : enode.operands())
|
||||||
|
get(operand).parents.push_back({enode, id});
|
||||||
|
|
||||||
|
hashcons.insert_or_assign(enode, id);
|
||||||
|
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Looks up for an eclass from a given non-canonicalized `id`.
|
||||||
|
// For a canonicalized eclass, use `get(find(id))` or `egraph[id]`.
|
||||||
|
EClass<L, typename N::Data>& get(Id id)
|
||||||
|
{
|
||||||
|
auto it = classes.find(id);
|
||||||
|
LUAU_ASSERT(it != classes.end());
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Id find(Id id) const
|
Id find(Id id) const
|
||||||
{
|
{
|
||||||
return unionfind.find(id);
|
return unionfind.find(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
std::optional<Id> lookup(const L& enode) const
|
||||||
std::optional<Id> lookup(T&& enode) const
|
|
||||||
{
|
{
|
||||||
|
LUAU_ASSERT(isCanonical(enode));
|
||||||
|
|
||||||
if (auto it = hashcons.find(enode); it != hashcons.end())
|
if (auto it = hashcons.find(enode); it != hashcons.end())
|
||||||
return it->second;
|
return it->second;
|
||||||
|
|
||||||
|
@ -100,19 +133,34 @@ public:
|
||||||
return *id;
|
return *id;
|
||||||
|
|
||||||
Id id = makeEClass(enode);
|
Id id = makeEClass(enode);
|
||||||
for (Id operand : enode.operands())
|
|
||||||
(*this)[operand].parents.push_back({enode, id});
|
|
||||||
|
|
||||||
hashcons.insert_or_assign(enode, id);
|
|
||||||
// TODO clean = false
|
// TODO clean = false
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void merge(Id id1, Id id2)
|
||||||
|
{
|
||||||
|
id1 = find(id1);
|
||||||
|
id2 = find(id2);
|
||||||
|
if (id1 == id2)
|
||||||
|
return;
|
||||||
|
|
||||||
|
unionfind.merge(id1, id2);
|
||||||
|
|
||||||
|
EClass<L, typename N::Data>& eclass1 = get(id1);
|
||||||
|
EClass<L, typename N::Data> eclass2 = get(id2);
|
||||||
|
classes.erase(id2);
|
||||||
|
|
||||||
|
analysis.join(eclass1.data, eclass2.data);
|
||||||
|
}
|
||||||
|
|
||||||
EClass<L, typename N::Data>& operator[](Id id)
|
EClass<L, typename N::Data>& operator[](Id id)
|
||||||
{
|
{
|
||||||
auto it = classes.find(find(id));
|
return get(find(id));
|
||||||
LUAU_ASSERT(it != classes.end());
|
}
|
||||||
return it->second;
|
|
||||||
|
const EClass<L, typename N::Data>& operator[](Id id) const
|
||||||
|
{
|
||||||
|
return const_cast<EGraph*>(this)->get(find(id));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -38,9 +38,41 @@ struct ConstantFold
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
else if (auto b = enode.get<Bool>())
|
else if (auto b = enode.get<Bool>())
|
||||||
return b->value;
|
return b->value;
|
||||||
|
else if (auto n = enode.get<Not>())
|
||||||
|
{
|
||||||
|
if (auto data = egraph[n->field<Negated>()].data)
|
||||||
|
return !*data;
|
||||||
|
}
|
||||||
|
else if (auto a = enode.get<And>())
|
||||||
|
{
|
||||||
|
Data left = egraph[a->field<Left>()].data;
|
||||||
|
Data right = egraph[a->field<Right>()].data;
|
||||||
|
if (left && right)
|
||||||
|
return *left && *right;
|
||||||
|
}
|
||||||
|
else if (auto o = enode.get<Or>())
|
||||||
|
{
|
||||||
|
Data left = egraph[o->field<Left>()].data;
|
||||||
|
Data right = egraph[o->field<Right>()].data;
|
||||||
|
if (left && right)
|
||||||
|
return *left && *right;
|
||||||
|
}
|
||||||
|
else if (auto i = enode.get<Implies>())
|
||||||
|
{
|
||||||
|
Data antecedent = egraph[i->field<Antecedent>()].data;
|
||||||
|
Data consequent = egraph[i->field<Consequent>()].data;
|
||||||
|
if (antecedent && consequent)
|
||||||
|
return !*antecedent || *consequent;
|
||||||
|
}
|
||||||
|
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void join(Data& a, const Data& b)
|
||||||
|
{
|
||||||
|
if (!a && b)
|
||||||
|
a = b;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
|
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
|
||||||
|
@ -68,4 +100,16 @@ TEST_CASE("egraph_data")
|
||||||
CHECK(egraph[id2].data == false);
|
CHECK(egraph[id2].data == false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("egraph_merge")
|
||||||
|
{
|
||||||
|
EGraph egraph;
|
||||||
|
|
||||||
|
EqSat::Id id1 = egraph.add(Var{"a"});
|
||||||
|
EqSat::Id id2 = egraph.add(Bool{true});
|
||||||
|
egraph.merge(id1, id2);
|
||||||
|
|
||||||
|
CHECK(egraph[id1].data == true);
|
||||||
|
CHECK(egraph[id2].data == true);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_SUITE_END();
|
TEST_SUITE_END();
|
||||||
|
|
Loading…
Add table
Reference in a new issue