Implement a basic version of Analysis. Will expand it later.

This commit is contained in:
Alexander McCord 2024-05-28 08:31:25 -07:00
parent cd735ebad3
commit 4122f3fbcb
2 changed files with 31 additions and 8 deletions

View file

@ -13,6 +13,20 @@
namespace Luau::EqSat namespace Luau::EqSat
{ {
template<typename L, typename N>
class EGraph;
template<typename L, typename N>
struct Analysis final
{
N analysis;
typename N::Data make(const EGraph<L, N>& egraph, const L& enode) const
{
return analysis.make(egraph, enode);
}
};
/// Each e-class is a set of e-nodes representing equivalent terms from a given language, /// Each e-class is a set of e-nodes representing equivalent terms from a given language,
/// and an e-node is a function symbol paired with a list of children e-classes. /// and an e-node is a function symbol paired with a list of children e-classes.
template<typename L, typename D> template<typename L, typename D>
@ -28,6 +42,8 @@ struct EClass final
template<typename L, typename N> template<typename L, typename N>
class EGraph final class EGraph final
{ {
Analysis<L, N> analysis;
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
UnionFind unionfind; UnionFind unionfind;
@ -55,16 +71,13 @@ private:
classes.insert_or_assign(id, EClass<L, typename N::Data>{ classes.insert_or_assign(id, EClass<L, typename N::Data>{
id, id,
{enode}, {enode},
{}, // TODO: analysis make analysis.make(*this, enode),
{}, {},
}); });
return id; return id;
} }
public: public:
// TODO: static_assert L <: Language
// TODO: static_assert N <: Analysis<L>
Id find(Id id) const Id find(Id id) const
{ {
return unionfind.find(id); return unionfind.find(id);

View file

@ -26,12 +26,22 @@ using namespace Luau;
using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>; using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>;
using EGraph = EqSat::EGraph<PropositionalLogic, struct ConstantFold>;
struct ConstantFold struct ConstantFold
{ {
using Data = std::optional<bool>; using Data = std::optional<bool>;
};
using EGraph = EqSat::EGraph<PropositionalLogic, ConstantFold>; Data make(const EGraph& egraph, const PropositionalLogic& enode) const
{
if (enode.get<Var>())
return std::nullopt;
else if (auto b = enode.get<Bool>())
return b->value;
return std::nullopt;
}
};
TEST_SUITE_BEGIN("EqSatPropositionalLogic"); TEST_SUITE_BEGIN("EqSatPropositionalLogic");
@ -54,8 +64,8 @@ TEST_CASE("egraph_data")
EqSat::Id id1 = egraph.add(Bool{true}); EqSat::Id id1 = egraph.add(Bool{true});
EqSat::Id id2 = egraph.add(Bool{false}); EqSat::Id id2 = egraph.add(Bool{false});
CHECK(egraph[id1].data == std::nullopt); // TODO: true CHECK(egraph[id1].data == true);
CHECK(egraph[id2].data == std::nullopt); // TODO: false CHECK(egraph[id2].data == false);
} }
TEST_SUITE_END(); TEST_SUITE_END();