Implement Language::operands(), which lets EGraph canonicalize e-nodes.

This commit is contained in:
Alexander McCord 2024-05-27 17:05:02 -07:00
parent d9835952b7
commit c8a1651c89
8 changed files with 246 additions and 9 deletions

View file

@ -4,6 +4,7 @@
#include "Luau/Id.h"
#include "Luau/UnionFind.h"
#include <optional>
#include <unordered_map>
#include <utility>
#include <vector>
@ -39,13 +40,29 @@ struct EGraph final
return unionfind.find(id);
}
// Per the egg paper, definition 2.2 (Canonicalization):
//
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
//
// Doing so requires sketching out `Luau::EqSat::Language` which
// I want to do at a later time for the time being. Will revisit.
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
/// Looks `enode` up in the hashcons after canonicalizing it.
///
/// `enode` is taken by value, so rewriting its operands here never touches the
/// caller's e-node. Returns the id stored for the canonical form, or
/// `std::nullopt` when the hashcons has no entry for it.
std::optional<Id> lookup(L enode) const
{
    // canonicalize(f(a1, a2, ...)) = f(find(a1), find(a2), ...)
    for (Id& operand : enode.operands())
        operand = find(operand);

    const auto entry = hashcons.find(enode);
    if (entry == hashcons.end())
        return std::nullopt;

    return entry->second;
}
// TODO: `add`. For now, we call it shoveItIn so it's obvious it's just for testing.
/// Returns the id the hashcons already holds for `enode`, or stores `enode`
/// under a new id and returns that.
///
/// The e-node is canonicalized (each operand replaced by `find(operand)`)
/// before it is looked up and stored, so the key kept in the hashcons has the
/// same canonical form that `lookup` searches for. The new id is the number of
/// entries the hashcons held before the insertion.
///
/// NOTE(review): the new id is not registered with `unionfind` here, and
/// `UnionFind::find` asserts that the id is in range — confirm before feeding
/// ids returned from this function back in as operands.
Id shoveItIn(L enode)
{
    // `lookup` only canonicalizes its own by-value copy; canonicalize ours as
    // well so that the stored key is canonical rather than whatever was passed.
    for (Id& operand : enode.operands())
        operand = find(operand);

    if (auto id = lookup(enode))
        return *id;

    // `lookup` just reported the key as absent, so this always inserts.
    Id id{hashcons.size()};
    hashcons.insert_or_assign(enode, id);
    return id;
}
private:
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.

View file

@ -2,6 +2,7 @@
#pragma once
#include "Luau/Id.h"
#include "Luau/Slice.h"
#include <array>
#include <algorithm>
@ -63,6 +64,11 @@ struct Atom
{
T value;
/// An atom stores only `value` and no ids, so its operand slice is always
/// the default-constructed (empty) slice.
Slice<Id> operands()
{
    return Slice<Id>{};
}
bool operator==(const Atom& rhs) const
{
return value == rhs.value;
@ -123,8 +129,13 @@ public:
{
}
/// Returns a view over this node's ids, backed directly by `array`.
/// The slice refers to the node's own storage; it does not copy the ids.
Slice<Id> operands()
{
    return Slice<Id>{array};
}
template<typename T>
Id field() const
const Id& field() const
{
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);
return array[getIndex<T>()];
@ -175,6 +186,7 @@ private:
// Function-pointer signatures for operations performed on a type-erased
// object that is only known by its address.
using FnDtor = void (*)(void*);
using FnPred = bool (*)(const void*, const void*);
using FnHash = size_t (*)(const void*);
// Takes a mutable pointer (unlike FnPred/FnHash) and returns the operand
// slice of the object stored there.
using FnOper = Slice<Id> (*)(void*);
template<typename T>
static void fnCopy(void* dst, const void* src)
@ -206,11 +218,18 @@ private:
return typename T::Hash{}(*static_cast<const T*>(buffer));
}
/// Type-erased entry point for `T::operands()`: treats `buffer` as the
/// address of a `T` and forwards the call to it.
template<typename T>
static Slice<Id> fnOper(void* buffer)
{
    T* self = static_cast<T*>(buffer);
    return self->operands();
}
// Dispatch tables with one entry per alternative in `Ts...`, in pack order.
// Each entry is the corresponding fn* template instantiated for that alternative.
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};
static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred<Ts>...};
static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash<Ts>...};
// Indexed by getIndexFromTag(tag) in operands().
static constexpr FnOper tableOper[sizeof...(Ts)] = {&fnOper<Ts>...};
static constexpr int getIndexFromTag(const char* tag)
{
@ -267,6 +286,18 @@ public:
return *this;
}
/// Returns the index `getIndexFromTag` computes for the current `tag`; this
/// is the same value `operands()` uses to pick its dispatch-table entry.
int index() const
{
    const int position = getIndexFromTag(tag);
    return position;
}
/// Returns the operand ids of whichever alternative is currently stored.
///
/// Do not call this with the intention of mutating the `Id`s it exposes.
/// Reading them is fine, but do not assume the `Id`s are stable either.
Slice<Id> operands()
{
    // Pick the trampoline for the active alternative, then hand it our storage.
    const int active = getIndexFromTag(tag);
    return tableOper[active](&buffer);
}
template<typename T>
const T* get() const
{

View file

@ -0,0 +1,77 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Common.h"
#include <array>
namespace Luau::EqSat
{
/// A non-owning view over a contiguous run of `T`.
///
/// A slice holds only a pointer and an element count; it never copies or
/// releases the elements it refers to. Constness of the slice does not extend
/// to the elements: a `const Slice<T>` still hands out `T&` and `T*`.
template<typename T>
class Slice
{
    T* _data = nullptr;
    size_t _size = 0;

public:
    /// Constructs an empty slice: null data pointer, zero elements.
    Slice() = default;

    /// Views `last` elements starting at `first`. Note that `last` is an
    /// element count, not an end pointer or index.
    ///
    /// Use this constructor if you have a dynamically sized vector.
    /// The slice is valid only for as long as the backing vector has not
    /// moved elsewhere in memory.
    ///
    /// In general, a slice should not be taken over a vector unless the
    /// vector's size, though statically unknown, stays fixed once the slice
    /// has been constructed over it.
    Slice(T* first, size_t last)
        : _data{first}
        , _size{last}
    {
    }

    /// Views every element of a fixed-size `std::array`.
    template<size_t I>
    explicit Slice(std::array<T, I>& array)
        : Slice(array.data(), array.size())
    {
    }

    /// Pointer to the first viewed element (null for a default-constructed slice).
    T* data() const
    {
        return _data;
    }

    /// Number of elements in the view.
    size_t size() const
    {
        return _size;
    }

    /// True when the view contains no elements.
    bool empty() const
    {
        return size() == 0;
    }

    /// Element access; `i` must be less than `size()`, which is asserted.
    T& operator[](size_t i) const
    {
        LUAU_ASSERT(i < _size);
        return *(_data + i);
    }

    // Iteration support, so a slice can be used in a range-based for loop.
    T* begin() const
    {
        return data();
    }

    T* end() const
    {
        return begin() + size();
    }
};
}

View file

@ -17,6 +17,7 @@ Id UnionFind::find(Id id) const
{
LUAU_ASSERT(size_t(id) < parents.size());
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)])
id = parents[size_t(id)];

View file

@ -300,6 +300,7 @@ target_sources(Luau.EqSat PRIVATE
EqSat/include/Luau/EGraph.h
EqSat/include/Luau/Id.h
EqSat/include/Luau/Language.h
EqSat/include/Luau/Slice.h
EqSat/include/Luau/UnionFind.h
EqSat/src/Id.cpp
@ -431,6 +432,7 @@ if(TARGET Luau.UnitTest)
tests/Differ.test.cpp
tests/EqSat.language.test.cpp
tests/EqSat.propositional.test.cpp
tests/EqSat.slice.test.cpp
tests/Error.test.cpp
tests/Fixture.cpp
tests/Fixture.h

View file

@ -127,4 +127,19 @@ TEST_CASE("node_field")
CHECK(right != left2);
}
// Checks that operands() on a language value exposes the ids of the alternative it holds.
TEST_CASE("language_operands")
{
// A value holding an I32 is expected to report no operands.
Value v1{I32{0}};
CHECK(v1.operands().empty());
// A value holding an Add is expected to report its two ids, in Left, Right order.
Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}};
const Add* add = v2.get<Add>();
REQUIRE(add);
EqSat::Slice<EqSat::Id> actual = v2.operands();
CHECK(actual.size() == 2);
// Each slice element must compare equal to the matching named field of the node.
CHECK(actual[0] == add->field<Left>());
CHECK(actual[1] == add->field<Right>());
}
TEST_SUITE_END();

View file

@ -1,8 +1,44 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include <doctest.h>
// Var, Bool, And, Or, Not, Implies
#include "Luau/EGraph.h"
#include "Luau/Language.h"
#include <optional>
// Vocabulary of the propositional-logic language used by the tests in this file.
// NOTE(review): the LUAU_EQSAT_* macros are defined elsewhere (presumably Luau/Language.h);
// the descriptions here are inferred from their names and arguments — confirm against the definitions.
// Leaves: a named variable and a boolean constant.
LUAU_EQSAT_ATOM(Var, std::string);
LUAU_EQSAT_ATOM(Bool, bool);
// Negation, with a single field.
LUAU_EQSAT_FIELD(Negated);
LUAU_EQSAT_UNARY_NODE(Not, Negated);
// Conjunction and disjunction share the Left/Right field names.
LUAU_EQSAT_FIELD(Left);
LUAU_EQSAT_FIELD(Right);
LUAU_EQSAT_BINARY_NODE(And, Left, Right);
LUAU_EQSAT_BINARY_NODE(Or, Left, Right);
// Implication, with its own field names.
LUAU_EQSAT_FIELD(Antecedent);
LUAU_EQSAT_FIELD(Consequent);
LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent);
using namespace Luau;
// The language: every kind of e-node declared above.
using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>;
// Second template argument of the EGraph below; it only declares a `Data` type.
// NOTE(review): presumably the per-e-class analysis payload — confirm against EGraph.
struct ConstantFold
{
using Data = std::optional<bool>;
};
// E-graph over the propositional-logic language, parameterized with ConstantFold.
using EGraph = EqSat::EGraph<PropositionalLogic, ConstantFold>;
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
// Inserting the same e-node twice must yield the same id: the second shoveItIn
// call is expected to find the entry stored by the first instead of adding another.
TEST_CASE("egraph_hashconsing")
{
EGraph egraph;
CHECK(egraph.shoveItIn(Bool{true}) == egraph.shoveItIn(Bool{true}));
}
TEST_SUITE_END();

View file

@ -0,0 +1,58 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include <doctest.h>
#include "Luau/Slice.h"
#include <vector>
using namespace Luau;
// NOTE(review): the cases in this file exercise Slice, yet the suite is named "EqSatLanguage" —
// confirm this is intentional and not carried over from the language tests.
TEST_SUITE_BEGIN("EqSatLanguage");
// A slice built from a std::array must alias the array's storage, not copy it.
TEST_CASE("slice_is_a_view_over_array")
{
std::array<int, 8> a{1, 2, 3, 4, 5, 6, 7, 8};
EqSat::Slice<int> slice{a};
CHECK(slice.data() == a.data());
CHECK(slice.size() == a.size());
for (size_t i = 0; i < a.size(); ++i)
{
// Same value and, more importantly, the same address as the array element.
CHECK(slice[i] == a[i]);
CHECK(&slice[i] == &a[i]);
}
}
// A slice built from a vector's data pointer and size must alias the vector's storage.
TEST_CASE("slice_is_a_view_over_vector")
{
std::vector<int> vector{1, 2, 3, 4, 5, 6, 7, 8};
// The vector is not resized after this point, so its storage stays in place.
EqSat::Slice<int> slice{vector.data(), vector.size()};
CHECK(slice.data() == vector.data());
CHECK(slice.size() == vector.size());
for (size_t i = 0; i < vector.size(); ++i)
{
// Same value and the same address as the vector element.
CHECK(slice[i] == vector[i]);
CHECK(&slice[i] == &vector[i]);
}
}
// Writes made through a slice must be visible in the backing array.
TEST_CASE("mutate_via_slice")
{
std::array<int, 2> a{1, 2};
CHECK(a[0] == 1);
CHECK(a[1] == 2);
EqSat::Slice<int> slice{a};
// Assign through the slice, then read back through the array itself.
slice[0] = 42;
slice[1] = 37;
CHECK(a[0] == 42);
CHECK(a[1] == 37);
}
TEST_SUITE_END();