diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 43e6c69e..6029b18f 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -3,6 +3,7 @@ #include "Luau/Common.h" #include "Luau/Id.h" +#include "Luau/Language.h" #include "Luau/UnionFind.h" #include "Luau/VecDeque.h" @@ -23,9 +24,19 @@ struct Analysis final using D = typename N::Data; - D make(const EGraph& egraph, const L& enode) const + template + static D fnMake(const N& analysis, const EGraph& egraph, const L& enode) { - return analysis.make(egraph, enode); + return analysis.make(egraph, *enode.template get()); + } + + template + D make(const EGraph& egraph, const Language& enode) const + { + using FnMake = D (*)(const N&, const EGraph&, const L&); + static constexpr FnMake tableMake[sizeof...(Ts)] = {&fnMake...}; + + return tableMake[enode.index()](analysis, egraph, enode); } void join(D& a, const D& b) diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp index 4358282c..5e019adc 100644 --- a/tests/EqSat.propositional.test.cpp +++ b/tests/EqSat.propositional.test.cpp @@ -32,38 +32,51 @@ struct ConstantFold { using Data = std::optional; - Data make(const EGraph& egraph, const PropositionalLogic& enode) const + Data make(const EGraph& egraph, const Var& var) const { - if (enode.get()) - return std::nullopt; - else if (auto b = enode.get()) - return b->value; - else if (auto n = enode.get()) - { - if (auto data = egraph[n->field()].data) - return !*data; - } - else if (auto a = enode.get()) - { - Data left = egraph[a->field()].data; - Data right = egraph[a->field()].data; - if (left && right) - return *left && *right; - } - else if (auto o = enode.get()) - { - Data left = egraph[o->field()].data; - Data right = egraph[o->field()].data; - if (left && right) - return *left && *right; - } - else if (auto i = enode.get()) - { - Data antecedent = egraph[i->field()].data; - Data consequent = egraph[i->field()].data; - if (antecedent && consequent) - return !*antecedent || *consequent; - } + return std::nullopt; + } + + Data make(const EGraph& egraph, const Bool& b) const + { + return b.value; + } + + Data make(const EGraph& egraph, const Not& n) const + { + Data data = egraph[n.field()].data; + if (data) + return !*data; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const And& a) const + { + Data left = egraph[a.field()].data; + Data right = egraph[a.field()].data; + if (left && right) + return *left && *right; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const Or& o) const + { + Data left = egraph[o.field()].data; + Data right = egraph[o.field()].data; + if (left && right) + return *left || *right; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const Implies& i) const + { + Data antecedent = egraph[i.field()].data; + Data consequent = egraph[i.field()].data; + if (antecedent && consequent) + return !*antecedent || *consequent; return std::nullopt; }