Consistently use struct.

This commit is contained in:
Alexander McCord 2024-07-15 16:33:53 -07:00
parent a6b62ed0be
commit 201352eedf
3 changed files with 87 additions and 86 deletions

View file

@ -15,7 +15,7 @@ namespace Luau::EqSat
{ {
template<typename L, typename N> template<typename L, typename N>
class EGraph; struct EGraph;
template<typename L, typename N> template<typename L, typename N>
struct Analysis final struct Analysis final
@ -58,8 +58,80 @@ struct EClass final
/// See <https://arxiv.org/pdf/2004.03082>. /// See <https://arxiv.org/pdf/2004.03082>.
template<typename L, typename N> template<typename L, typename N>
class EGraph final struct EGraph final
{ {
Id find(Id id) const
{
return unionfind.find(id);
}
std::optional<Id> lookup(const L& enode) const
{
LUAU_ASSERT(isCanonical(enode));
if (auto it = hashcons.find(enode); it != hashcons.end())
return it->second;
return std::nullopt;
}
Id add(L enode)
{
canonicalize(enode);
if (auto id = lookup(enode))
return *id;
Id id = makeEClass(enode);
return id;
}
void merge(Id id1, Id id2)
{
id1 = find(id1);
id2 = find(id2);
if (id1 == id2)
return;
unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1);
EClass<L, typename N::Data> eclass2 = std::move(get(id2));
classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size());
for (auto [enode, id] : eclass2.parents)
worklist.push_back({std::move(enode), id});
analysis.join(eclass1.data, eclass2.data);
}
void rebuild()
{
while (!worklist.empty())
{
auto [enode, id] = worklist.back();
worklist.pop_back();
repair(get(find(id)));
}
}
size_t size() const
{
return classes.size();
}
EClass<L, typename N::Data>& operator[](Id id)
{
return get(find(id));
}
const EClass<L, typename N::Data>& operator[](Id id) const
{
return const_cast<EGraph*>(this)->get(find(id));
}
private:
Analysis<L, N> analysis; Analysis<L, N> analysis;
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
@ -151,78 +223,6 @@ private:
eclass.parents.emplace_back(std::move(node.key()), node.mapped()); eclass.parents.emplace_back(std::move(node.key()), node.mapped());
} }
} }
public:
Id find(Id id) const
{
return unionfind.find(id);
}
std::optional<Id> lookup(const L& enode) const
{
LUAU_ASSERT(isCanonical(enode));
if (auto it = hashcons.find(enode); it != hashcons.end())
return it->second;
return std::nullopt;
}
Id add(L enode)
{
canonicalize(enode);
if (auto id = lookup(enode))
return *id;
Id id = makeEClass(enode);
return id;
}
void merge(Id id1, Id id2)
{
id1 = find(id1);
id2 = find(id2);
if (id1 == id2)
return;
unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1);
EClass<L, typename N::Data> eclass2 = std::move(get(id2));
classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size());
for (auto [enode, id] : eclass2.parents)
worklist.push_back({std::move(enode), id});
analysis.join(eclass1.data, eclass2.data);
}
void rebuild()
{
while (!worklist.empty())
{
auto [enode, id] = worklist.back();
worklist.pop_back();
repair(get(find(id)));
}
}
size_t size() const
{
return classes.size();
}
EClass<L, typename N::Data>& operator[](Id id)
{
return get(find(id));
}
const EClass<L, typename N::Data>& operator[](Id id) const
{
return const_cast<EGraph*>(this)->get(find(id));
}
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View file

@ -158,12 +158,10 @@ struct Field : FieldBase
}; };
template<typename Phantom, typename... Fields> template<typename Phantom, typename... Fields>
class NodeFields struct NodeFields
{ {
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value); static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
std::array<Id, sizeof...(Fields)> array;
template<typename T> template<typename T>
static constexpr int getIndex() static constexpr int getIndex()
{ {
@ -218,17 +216,17 @@ public:
return languageHash(value.array); return languageHash(value.array);
} }
}; };
private:
std::array<Id, sizeof...(Fields)> array;
}; };
template<typename... Ts> template<typename... Ts>
class Language final struct Language final
{ {
Variant<Ts...> v;
template<typename T> template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>; using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
public:
template<typename T> template<typename T>
Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept
: v(std::forward<T>(t)) : v(std::forward<T>(t))
@ -298,6 +296,9 @@ public:
return seed; return seed;
} }
}; };
private:
Variant<Ts...> v;
}; };
} // namespace Luau::EqSat } // namespace Luau::EqSat

View file

@ -10,12 +10,8 @@ namespace Luau::EqSat
{ {
template<typename T> template<typename T>
class Slice final struct Slice final
{ {
T* _data;
size_t _size;
public:
Slice() Slice()
: _data(nullptr) : _data(nullptr)
, _size(0) , _size(0)
@ -63,6 +59,10 @@ public:
return _data[i]; return _data[i];
} }
public:
T* _data;
size_t _size;
public: public:
T* begin() const T* begin() const
{ {