mirror of
https://github.com/luau-lang/luau.git
synced 2025-04-04 10:50:54 +01:00
Implement Language::operands() which means EGraph is able to canonicalize.
This commit is contained in:
parent
d9835952b7
commit
c8a1651c89
8 changed files with 246 additions and 9 deletions
|
@ -4,6 +4,7 @@
|
|||
#include "Luau/Id.h"
|
||||
#include "Luau/UnionFind.h"
|
||||
|
||||
#include <optional>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
@ -39,13 +40,29 @@ struct EGraph final
|
|||
return unionfind.find(id);
|
||||
}
|
||||
|
||||
// Per the egg paper, definition 2.2 (Canonicalization):
|
||||
//
|
||||
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
||||
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
||||
//
|
||||
// Doing so requires sketching out `Luau::EqSat::Language` which
|
||||
// I want to do at a later time for the time being. Will revisit.
|
||||
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
||||
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
||||
std::optional<Id> lookup(L enode) const
|
||||
{
|
||||
for (Id& id : enode.operands())
|
||||
id = find(id);
|
||||
|
||||
if (auto it = hashcons.find(enode); it != hashcons.end())
|
||||
return it->second;
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
// TODO: `add`. For now, we call it shoveItIn so it's obvious it's just for testing.
|
||||
Id shoveItIn(L enode)
|
||||
{
|
||||
if (auto id = lookup(enode))
|
||||
return *id;
|
||||
|
||||
Id id{hashcons.size()};
|
||||
hashcons.insert_or_assign(enode, id);
|
||||
return id;
|
||||
}
|
||||
|
||||
private:
|
||||
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
#pragma once
|
||||
|
||||
#include "Luau/Id.h"
|
||||
#include "Luau/Slice.h"
|
||||
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
|
@ -63,6 +64,11 @@ struct Atom
|
|||
{
|
||||
T value;
|
||||
|
||||
Slice<Id> operands()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
bool operator==(const Atom& rhs) const
|
||||
{
|
||||
return value == rhs.value;
|
||||
|
@ -123,8 +129,13 @@ public:
|
|||
{
|
||||
}
|
||||
|
||||
Slice<Id> operands()
|
||||
{
|
||||
return Slice{array};
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
Id field() const
|
||||
const Id& field() const
|
||||
{
|
||||
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);
|
||||
return array[getIndex<T>()];
|
||||
|
@ -175,6 +186,7 @@ private:
|
|||
using FnDtor = void (*)(void*);
|
||||
using FnPred = bool (*)(const void*, const void*);
|
||||
using FnHash = size_t (*)(const void*);
|
||||
using FnOper = Slice<Id> (*)(void*);
|
||||
|
||||
template<typename T>
|
||||
static void fnCopy(void* dst, const void* src)
|
||||
|
@ -206,11 +218,18 @@ private:
|
|||
return typename T::Hash{}(*static_cast<const T*>(buffer));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static Slice<Id> fnOper(void* buffer)
|
||||
{
|
||||
return static_cast<T*>(buffer)->operands();
|
||||
}
|
||||
|
||||
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
|
||||
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
|
||||
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};
|
||||
static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred<Ts>...};
|
||||
static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash<Ts>...};
|
||||
static constexpr FnOper tableOper[sizeof...(Ts)] = {&fnOper<Ts>...};
|
||||
|
||||
static constexpr int getIndexFromTag(const char* tag)
|
||||
{
|
||||
|
@ -267,6 +286,18 @@ public:
|
|||
return *this;
|
||||
}
|
||||
|
||||
int index() const
|
||||
{
|
||||
return getIndexFromTag(tag);
|
||||
}
|
||||
|
||||
/// You should never call this function with the intention of mutating the `Id`.
|
||||
/// Reading is ok, but you should also never assume that these `Id`s are stable.
|
||||
Slice<Id> operands()
|
||||
{
|
||||
return tableOper[getIndexFromTag(tag)](&buffer);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
const T* get() const
|
||||
{
|
||||
|
|
77
EqSat/include/Luau/Slice.h
Normal file
77
EqSat/include/Luau/Slice.h
Normal file
|
@ -0,0 +1,77 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#pragma once
|
||||
|
||||
#include "Luau/Common.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
namespace Luau::EqSat
|
||||
{
|
||||
|
||||
template<typename T>
|
||||
class Slice
|
||||
{
|
||||
T* _data;
|
||||
size_t _size;
|
||||
|
||||
public:
|
||||
Slice()
|
||||
: _data(nullptr)
|
||||
, _size(0)
|
||||
{
|
||||
}
|
||||
|
||||
/// Use this constructor if you have a dynamically sized vector.
|
||||
/// The slice is valid for as long as the backing vector has not moved
|
||||
/// elsewhere in memory.
|
||||
///
|
||||
/// In general, a slice should never be used from vectors except for
|
||||
/// any vectors whose size are statically unknown, but remains fixed
|
||||
/// upon the construction of such a slice over a vector.
|
||||
Slice(T* first, size_t last)
|
||||
: _data(first)
|
||||
, _size(last)
|
||||
{
|
||||
}
|
||||
|
||||
template<size_t I>
|
||||
explicit Slice(std::array<T, I>& array)
|
||||
: _data(array.data())
|
||||
, _size(array.size())
|
||||
{
|
||||
}
|
||||
|
||||
T* data() const
|
||||
{
|
||||
return _data;
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return _size;
|
||||
}
|
||||
|
||||
bool empty() const
|
||||
{
|
||||
return _size == 0;
|
||||
}
|
||||
|
||||
T& operator[](size_t i) const
|
||||
{
|
||||
LUAU_ASSERT(i < _size);
|
||||
return _data[i];
|
||||
}
|
||||
|
||||
public:
|
||||
T* begin() const
|
||||
{
|
||||
return _data;
|
||||
}
|
||||
|
||||
T* end() const
|
||||
{
|
||||
return _data + _size;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
|
@ -17,6 +17,7 @@ Id UnionFind::find(Id id) const
|
|||
{
|
||||
LUAU_ASSERT(size_t(id) < parents.size());
|
||||
|
||||
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
|
||||
while (id != parents[size_t(id)])
|
||||
id = parents[size_t(id)];
|
||||
|
||||
|
|
|
@ -300,6 +300,7 @@ target_sources(Luau.EqSat PRIVATE
|
|||
EqSat/include/Luau/EGraph.h
|
||||
EqSat/include/Luau/Id.h
|
||||
EqSat/include/Luau/Language.h
|
||||
EqSat/include/Luau/Slice.h
|
||||
EqSat/include/Luau/UnionFind.h
|
||||
|
||||
EqSat/src/Id.cpp
|
||||
|
@ -431,6 +432,7 @@ if(TARGET Luau.UnitTest)
|
|||
tests/Differ.test.cpp
|
||||
tests/EqSat.language.test.cpp
|
||||
tests/EqSat.propositional.test.cpp
|
||||
tests/EqSat.slice.test.cpp
|
||||
tests/Error.test.cpp
|
||||
tests/Fixture.cpp
|
||||
tests/Fixture.h
|
||||
|
|
|
@ -127,4 +127,19 @@ TEST_CASE("node_field")
|
|||
CHECK(right != left2);
|
||||
}
|
||||
|
||||
TEST_CASE("language_operands")
|
||||
{
|
||||
Value v1{I32{0}};
|
||||
CHECK(v1.operands().empty());
|
||||
|
||||
Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}};
|
||||
const Add* add = v2.get<Add>();
|
||||
REQUIRE(add);
|
||||
|
||||
EqSat::Slice<EqSat::Id> actual = v2.operands();
|
||||
CHECK(actual.size() == 2);
|
||||
CHECK(actual[0] == add->field<Left>());
|
||||
CHECK(actual[1] == add->field<Right>());
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
||||
|
|
|
@ -1,8 +1,44 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include <doctest.h>
|
||||
|
||||
// Var, Bool, And, Or, Not, Implies
|
||||
#include "Luau/EGraph.h"
|
||||
#include "Luau/Language.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
LUAU_EQSAT_ATOM(Var, std::string);
|
||||
LUAU_EQSAT_ATOM(Bool, bool);
|
||||
|
||||
LUAU_EQSAT_FIELD(Negated);
|
||||
LUAU_EQSAT_UNARY_NODE(Not, Negated);
|
||||
|
||||
LUAU_EQSAT_FIELD(Left);
|
||||
LUAU_EQSAT_FIELD(Right);
|
||||
LUAU_EQSAT_BINARY_NODE(And, Left, Right);
|
||||
LUAU_EQSAT_BINARY_NODE(Or, Left, Right);
|
||||
|
||||
LUAU_EQSAT_FIELD(Antecedent);
|
||||
LUAU_EQSAT_FIELD(Consequent);
|
||||
LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent);
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>;
|
||||
|
||||
struct ConstantFold
|
||||
{
|
||||
using Data = std::optional<bool>;
|
||||
};
|
||||
|
||||
using EGraph = EqSat::EGraph<PropositionalLogic, ConstantFold>;
|
||||
|
||||
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
|
||||
|
||||
TEST_CASE("egraph_hashconsing")
|
||||
{
|
||||
EGraph egraph;
|
||||
|
||||
CHECK(egraph.shoveItIn(Bool{true}) == egraph.shoveItIn(Bool{true}));
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
||||
|
|
58
tests/EqSat.slice.test.cpp
Normal file
58
tests/EqSat.slice.test.cpp
Normal file
|
@ -0,0 +1,58 @@
|
|||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||
#include <doctest.h>
|
||||
|
||||
#include "Luau/Slice.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
using namespace Luau;
|
||||
|
||||
TEST_SUITE_BEGIN("EqSatLanguage");
|
||||
|
||||
TEST_CASE("slice_is_a_view_over_array")
|
||||
{
|
||||
std::array<int, 8> a{1, 2, 3, 4, 5, 6, 7, 8};
|
||||
|
||||
EqSat::Slice<int> slice{a};
|
||||
|
||||
CHECK(slice.data() == a.data());
|
||||
CHECK(slice.size() == a.size());
|
||||
|
||||
for (size_t i = 0; i < a.size(); ++i)
|
||||
{
|
||||
CHECK(slice[i] == a[i]);
|
||||
CHECK(&slice[i] == &a[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("slice_is_a_view_over_vector")
|
||||
{
|
||||
std::vector<int> vector{1, 2, 3, 4, 5, 6, 7, 8};
|
||||
|
||||
EqSat::Slice<int> slice{vector.data(), vector.size()};
|
||||
|
||||
CHECK(slice.data() == vector.data());
|
||||
CHECK(slice.size() == vector.size());
|
||||
|
||||
for (size_t i = 0; i < vector.size(); ++i)
|
||||
{
|
||||
CHECK(slice[i] == vector[i]);
|
||||
CHECK(&slice[i] == &vector[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("mutate_via_slice")
|
||||
{
|
||||
std::array<int, 2> a{1, 2};
|
||||
CHECK(a[0] == 1);
|
||||
CHECK(a[1] == 2);
|
||||
|
||||
EqSat::Slice<int> slice{a};
|
||||
slice[0] = 42;
|
||||
slice[1] = 37;
|
||||
|
||||
CHECK(a[0] == 42);
|
||||
CHECK(a[1] == 37);
|
||||
}
|
||||
|
||||
TEST_SUITE_END();
|
Loading…
Add table
Reference in a new issue