mirror of
https://github.com/luau-lang/luau.git
synced 2025-04-10 22:00:54 +01:00
Implement Language::operands() which means EGraph is able to canonicalize.
This commit is contained in:
parent
d9835952b7
commit
c8a1651c89
8 changed files with 246 additions and 9 deletions
|
@ -4,6 +4,7 @@
|
||||||
#include "Luau/Id.h"
|
#include "Luau/Id.h"
|
||||||
#include "Luau/UnionFind.h"
|
#include "Luau/UnionFind.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -39,13 +40,29 @@ struct EGraph final
|
||||||
return unionfind.find(id);
|
return unionfind.find(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Per the egg paper, definition 2.2 (Canonicalization):
|
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
||||||
//
|
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
||||||
// An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where
|
std::optional<Id> lookup(L enode) const
|
||||||
// canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...).
|
{
|
||||||
//
|
for (Id& id : enode.operands())
|
||||||
// Doing so requires sketching out `Luau::EqSat::Language` which
|
id = find(id);
|
||||||
// I want to do at a later time for the time being. Will revisit.
|
|
||||||
|
if (auto it = hashcons.find(enode); it != hashcons.end())
|
||||||
|
return it->second;
|
||||||
|
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: `add`. For now, we call it shoveItIn so it's obvious it's just for testing.
|
||||||
|
Id shoveItIn(L enode)
|
||||||
|
{
|
||||||
|
if (auto id = lookup(enode))
|
||||||
|
return *id;
|
||||||
|
|
||||||
|
Id id{hashcons.size()};
|
||||||
|
hashcons.insert_or_assign(enode, id);
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
|
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "Luau/Id.h"
|
#include "Luau/Id.h"
|
||||||
|
#include "Luau/Slice.h"
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
@ -63,6 +64,11 @@ struct Atom
|
||||||
{
|
{
|
||||||
T value;
|
T value;
|
||||||
|
|
||||||
|
Slice<Id> operands()
|
||||||
|
{
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
bool operator==(const Atom& rhs) const
|
bool operator==(const Atom& rhs) const
|
||||||
{
|
{
|
||||||
return value == rhs.value;
|
return value == rhs.value;
|
||||||
|
@ -123,8 +129,13 @@ public:
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Slice<Id> operands()
|
||||||
|
{
|
||||||
|
return Slice{array};
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
Id field() const
|
const Id& field() const
|
||||||
{
|
{
|
||||||
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);
|
static_assert(std::disjunction_v<std::is_same<std::decay_t<T>, Fields>...>);
|
||||||
return array[getIndex<T>()];
|
return array[getIndex<T>()];
|
||||||
|
@ -175,6 +186,7 @@ private:
|
||||||
using FnDtor = void (*)(void*);
|
using FnDtor = void (*)(void*);
|
||||||
using FnPred = bool (*)(const void*, const void*);
|
using FnPred = bool (*)(const void*, const void*);
|
||||||
using FnHash = size_t (*)(const void*);
|
using FnHash = size_t (*)(const void*);
|
||||||
|
using FnOper = Slice<Id> (*)(void*);
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
static void fnCopy(void* dst, const void* src)
|
static void fnCopy(void* dst, const void* src)
|
||||||
|
@ -206,11 +218,18 @@ private:
|
||||||
return typename T::Hash{}(*static_cast<const T*>(buffer));
|
return typename T::Hash{}(*static_cast<const T*>(buffer));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static Slice<Id> fnOper(void* buffer)
|
||||||
|
{
|
||||||
|
return static_cast<T*>(buffer)->operands();
|
||||||
|
}
|
||||||
|
|
||||||
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
|
static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy<Ts>...};
|
||||||
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
|
static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove<Ts>...};
|
||||||
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};
|
static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor<Ts>...};
|
||||||
static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred<Ts>...};
|
static constexpr FnPred tablePred[sizeof...(Ts)] = {&fnPred<Ts>...};
|
||||||
static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash<Ts>...};
|
static constexpr FnHash tableHash[sizeof...(Ts)] = {&fnHash<Ts>...};
|
||||||
|
static constexpr FnOper tableOper[sizeof...(Ts)] = {&fnOper<Ts>...};
|
||||||
|
|
||||||
static constexpr int getIndexFromTag(const char* tag)
|
static constexpr int getIndexFromTag(const char* tag)
|
||||||
{
|
{
|
||||||
|
@ -267,6 +286,18 @@ public:
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int index() const
|
||||||
|
{
|
||||||
|
return getIndexFromTag(tag);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// You should never call this function with the intention of mutating the `Id`.
|
||||||
|
/// Reading is ok, but you should also never assume that these `Id`s are stable.
|
||||||
|
Slice<Id> operands()
|
||||||
|
{
|
||||||
|
return tableOper[getIndexFromTag(tag)](&buffer);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
const T* get() const
|
const T* get() const
|
||||||
{
|
{
|
||||||
|
|
77
EqSat/include/Luau/Slice.h
Normal file
77
EqSat/include/Luau/Slice.h
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Luau/Common.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
namespace Luau::EqSat
|
||||||
|
{
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
class Slice
|
||||||
|
{
|
||||||
|
T* _data;
|
||||||
|
size_t _size;
|
||||||
|
|
||||||
|
public:
|
||||||
|
Slice()
|
||||||
|
: _data(nullptr)
|
||||||
|
, _size(0)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use this constructor if you have a dynamically sized vector.
|
||||||
|
/// The slice is valid for as long as the backing vector has not moved
|
||||||
|
/// elsewhere in memory.
|
||||||
|
///
|
||||||
|
/// In general, a slice should never be used from vectors except for
|
||||||
|
/// any vectors whose size are statically unknown, but remains fixed
|
||||||
|
/// upon the construction of such a slice over a vector.
|
||||||
|
Slice(T* first, size_t last)
|
||||||
|
: _data(first)
|
||||||
|
, _size(last)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
template<size_t I>
|
||||||
|
explicit Slice(std::array<T, I>& array)
|
||||||
|
: _data(array.data())
|
||||||
|
, _size(array.size())
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
T* data() const
|
||||||
|
{
|
||||||
|
return _data;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const
|
||||||
|
{
|
||||||
|
return _size;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool empty() const
|
||||||
|
{
|
||||||
|
return _size == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
T& operator[](size_t i) const
|
||||||
|
{
|
||||||
|
LUAU_ASSERT(i < _size);
|
||||||
|
return _data[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
T* begin() const
|
||||||
|
{
|
||||||
|
return _data;
|
||||||
|
}
|
||||||
|
|
||||||
|
T* end() const
|
||||||
|
{
|
||||||
|
return _data + _size;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
|
@ -17,6 +17,7 @@ Id UnionFind::find(Id id) const
|
||||||
{
|
{
|
||||||
LUAU_ASSERT(size_t(id) < parents.size());
|
LUAU_ASSERT(size_t(id) < parents.size());
|
||||||
|
|
||||||
|
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
|
||||||
while (id != parents[size_t(id)])
|
while (id != parents[size_t(id)])
|
||||||
id = parents[size_t(id)];
|
id = parents[size_t(id)];
|
||||||
|
|
||||||
|
|
|
@ -300,6 +300,7 @@ target_sources(Luau.EqSat PRIVATE
|
||||||
EqSat/include/Luau/EGraph.h
|
EqSat/include/Luau/EGraph.h
|
||||||
EqSat/include/Luau/Id.h
|
EqSat/include/Luau/Id.h
|
||||||
EqSat/include/Luau/Language.h
|
EqSat/include/Luau/Language.h
|
||||||
|
EqSat/include/Luau/Slice.h
|
||||||
EqSat/include/Luau/UnionFind.h
|
EqSat/include/Luau/UnionFind.h
|
||||||
|
|
||||||
EqSat/src/Id.cpp
|
EqSat/src/Id.cpp
|
||||||
|
@ -431,6 +432,7 @@ if(TARGET Luau.UnitTest)
|
||||||
tests/Differ.test.cpp
|
tests/Differ.test.cpp
|
||||||
tests/EqSat.language.test.cpp
|
tests/EqSat.language.test.cpp
|
||||||
tests/EqSat.propositional.test.cpp
|
tests/EqSat.propositional.test.cpp
|
||||||
|
tests/EqSat.slice.test.cpp
|
||||||
tests/Error.test.cpp
|
tests/Error.test.cpp
|
||||||
tests/Fixture.cpp
|
tests/Fixture.cpp
|
||||||
tests/Fixture.h
|
tests/Fixture.h
|
||||||
|
|
|
@ -127,4 +127,19 @@ TEST_CASE("node_field")
|
||||||
CHECK(right != left2);
|
CHECK(right != left2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("language_operands")
|
||||||
|
{
|
||||||
|
Value v1{I32{0}};
|
||||||
|
CHECK(v1.operands().empty());
|
||||||
|
|
||||||
|
Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}};
|
||||||
|
const Add* add = v2.get<Add>();
|
||||||
|
REQUIRE(add);
|
||||||
|
|
||||||
|
EqSat::Slice<EqSat::Id> actual = v2.operands();
|
||||||
|
CHECK(actual.size() == 2);
|
||||||
|
CHECK(actual[0] == add->field<Left>());
|
||||||
|
CHECK(actual[1] == add->field<Right>());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_SUITE_END();
|
TEST_SUITE_END();
|
||||||
|
|
|
@ -1,8 +1,44 @@
|
||||||
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||||
#include <doctest.h>
|
#include <doctest.h>
|
||||||
|
|
||||||
// Var, Bool, And, Or, Not, Implies
|
#include "Luau/EGraph.h"
|
||||||
|
#include "Luau/Language.h"
|
||||||
|
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
|
LUAU_EQSAT_ATOM(Var, std::string);
|
||||||
|
LUAU_EQSAT_ATOM(Bool, bool);
|
||||||
|
|
||||||
|
LUAU_EQSAT_FIELD(Negated);
|
||||||
|
LUAU_EQSAT_UNARY_NODE(Not, Negated);
|
||||||
|
|
||||||
|
LUAU_EQSAT_FIELD(Left);
|
||||||
|
LUAU_EQSAT_FIELD(Right);
|
||||||
|
LUAU_EQSAT_BINARY_NODE(And, Left, Right);
|
||||||
|
LUAU_EQSAT_BINARY_NODE(Or, Left, Right);
|
||||||
|
|
||||||
|
LUAU_EQSAT_FIELD(Antecedent);
|
||||||
|
LUAU_EQSAT_FIELD(Consequent);
|
||||||
|
LUAU_EQSAT_BINARY_NODE(Implies, Antecedent, Consequent);
|
||||||
|
|
||||||
|
using namespace Luau;
|
||||||
|
|
||||||
|
using PropositionalLogic = EqSat::Language<Var, Bool, Not, And, Or, Implies>;
|
||||||
|
|
||||||
|
struct ConstantFold
|
||||||
|
{
|
||||||
|
using Data = std::optional<bool>;
|
||||||
|
};
|
||||||
|
|
||||||
|
using EGraph = EqSat::EGraph<PropositionalLogic, ConstantFold>;
|
||||||
|
|
||||||
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
|
TEST_SUITE_BEGIN("EqSatPropositionalLogic");
|
||||||
|
|
||||||
|
TEST_CASE("egraph_hashconsing")
|
||||||
|
{
|
||||||
|
EGraph egraph;
|
||||||
|
|
||||||
|
CHECK(egraph.shoveItIn(Bool{true}) == egraph.shoveItIn(Bool{true}));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_SUITE_END();
|
TEST_SUITE_END();
|
||||||
|
|
58
tests/EqSat.slice.test.cpp
Normal file
58
tests/EqSat.slice.test.cpp
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
|
||||||
|
#include <doctest.h>
|
||||||
|
|
||||||
|
#include "Luau/Slice.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using namespace Luau;
|
||||||
|
|
||||||
|
TEST_SUITE_BEGIN("EqSatLanguage");
|
||||||
|
|
||||||
|
TEST_CASE("slice_is_a_view_over_array")
|
||||||
|
{
|
||||||
|
std::array<int, 8> a{1, 2, 3, 4, 5, 6, 7, 8};
|
||||||
|
|
||||||
|
EqSat::Slice<int> slice{a};
|
||||||
|
|
||||||
|
CHECK(slice.data() == a.data());
|
||||||
|
CHECK(slice.size() == a.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < a.size(); ++i)
|
||||||
|
{
|
||||||
|
CHECK(slice[i] == a[i]);
|
||||||
|
CHECK(&slice[i] == &a[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("slice_is_a_view_over_vector")
|
||||||
|
{
|
||||||
|
std::vector<int> vector{1, 2, 3, 4, 5, 6, 7, 8};
|
||||||
|
|
||||||
|
EqSat::Slice<int> slice{vector.data(), vector.size()};
|
||||||
|
|
||||||
|
CHECK(slice.data() == vector.data());
|
||||||
|
CHECK(slice.size() == vector.size());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < vector.size(); ++i)
|
||||||
|
{
|
||||||
|
CHECK(slice[i] == vector[i]);
|
||||||
|
CHECK(&slice[i] == &vector[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("mutate_via_slice")
|
||||||
|
{
|
||||||
|
std::array<int, 2> a{1, 2};
|
||||||
|
CHECK(a[0] == 1);
|
||||||
|
CHECK(a[1] == 2);
|
||||||
|
|
||||||
|
EqSat::Slice<int> slice{a};
|
||||||
|
slice[0] = 42;
|
||||||
|
slice[1] = 37;
|
||||||
|
|
||||||
|
CHECK(a[0] == 42);
|
||||||
|
CHECK(a[1] == 37);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_SUITE_END();
|
Loading…
Add table
Reference in a new issue