Consistently use struct.

This commit is contained in:
Alexander McCord 2024-07-15 16:33:53 -07:00
parent a6b62ed0be
commit 201352eedf
3 changed files with 87 additions and 86 deletions

View file

@ -15,7 +15,7 @@ namespace Luau::EqSat
{
template<typename L, typename N>
class EGraph;
struct EGraph;
template<typename L, typename N>
struct Analysis final
@ -58,8 +58,80 @@ struct EClass final
/// See <https://arxiv.org/pdf/2004.03082>.
template<typename L, typename N>
class EGraph final
struct EGraph final
{
Id find(Id id) const
{
return unionfind.find(id);
}
std::optional<Id> lookup(const L& enode) const
{
LUAU_ASSERT(isCanonical(enode));
if (auto it = hashcons.find(enode); it != hashcons.end())
return it->second;
return std::nullopt;
}
Id add(L enode)
{
canonicalize(enode);
if (auto id = lookup(enode))
return *id;
Id id = makeEClass(enode);
return id;
}
void merge(Id id1, Id id2)
{
id1 = find(id1);
id2 = find(id2);
if (id1 == id2)
return;
unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1);
EClass<L, typename N::Data> eclass2 = std::move(get(id2));
classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size());
for (auto [enode, id] : eclass2.parents)
worklist.push_back({std::move(enode), id});
analysis.join(eclass1.data, eclass2.data);
}
void rebuild()
{
while (!worklist.empty())
{
auto [enode, id] = worklist.back();
worklist.pop_back();
repair(get(find(id)));
}
}
size_t size() const
{
return classes.size();
}
EClass<L, typename N::Data>& operator[](Id id)
{
return get(find(id));
}
const EClass<L, typename N::Data>& operator[](Id id) const
{
return const_cast<EGraph*>(this)->get(find(id));
}
private:
Analysis<L, N> analysis;
/// A union-find data structure 𝑈 stores an equivalence relation over e-class ids.
@ -151,78 +223,6 @@ private:
eclass.parents.emplace_back(std::move(node.key()), node.mapped());
}
}
public:
Id find(Id id) const
{
return unionfind.find(id);
}
std::optional<Id> lookup(const L& enode) const
{
LUAU_ASSERT(isCanonical(enode));
if (auto it = hashcons.find(enode); it != hashcons.end())
return it->second;
return std::nullopt;
}
Id add(L enode)
{
canonicalize(enode);
if (auto id = lookup(enode))
return *id;
Id id = makeEClass(enode);
return id;
}
void merge(Id id1, Id id2)
{
id1 = find(id1);
id2 = find(id2);
if (id1 == id2)
return;
unionfind.merge(id1, id2);
EClass<L, typename N::Data>& eclass1 = get(id1);
EClass<L, typename N::Data> eclass2 = std::move(get(id2));
classes.erase(id2);
worklist.reserve(worklist.size() + eclass2.parents.size());
for (auto [enode, id] : eclass2.parents)
worklist.push_back({std::move(enode), id});
analysis.join(eclass1.data, eclass2.data);
}
void rebuild()
{
while (!worklist.empty())
{
auto [enode, id] = worklist.back();
worklist.pop_back();
repair(get(find(id)));
}
}
size_t size() const
{
return classes.size();
}
EClass<L, typename N::Data>& operator[](Id id)
{
return get(find(id));
}
const EClass<L, typename N::Data>& operator[](Id id) const
{
return const_cast<EGraph*>(this)->get(find(id));
}
};
} // namespace Luau::EqSat

View file

@ -158,12 +158,10 @@ struct Field : FieldBase
};
template<typename Phantom, typename... Fields>
class NodeFields
struct NodeFields
{
static_assert(std::conjunction<std::is_base_of<FieldBase, Fields>...>::value);
std::array<Id, sizeof...(Fields)> array;
template<typename T>
static constexpr int getIndex()
{
@ -218,17 +216,17 @@ public:
return languageHash(value.array);
}
};
private:
std::array<Id, sizeof...(Fields)> array;
};
template<typename... Ts>
class Language final
struct Language final
{
Variant<Ts...> v;
template<typename T>
using WithinDomain = std::disjunction<std::is_same<std::decay_t<T>, Ts>...>;
public:
template<typename T>
Language(T&& t, std::enable_if_t<WithinDomain<T>::value>* = 0) noexcept
: v(std::forward<T>(t))
@ -298,6 +296,9 @@ public:
return seed;
}
};
private:
Variant<Ts...> v;
};
} // namespace Luau::EqSat

View file

@ -10,12 +10,8 @@ namespace Luau::EqSat
{
template<typename T>
class Slice final
struct Slice final
{
T* _data;
size_t _size;
public:
Slice()
: _data(nullptr)
, _size(0)
@ -63,6 +59,10 @@ public:
return _data[i];
}
public:
T* _data;
size_t _size;
public:
T* begin() const
{