optimize unionfind

This commit is contained in:
birds3345 2024-07-17 01:21:07 -04:00
parent 4f917420d7
commit 06c99b75de
2 changed files with 29 additions and 0 deletions

View file

@ -13,10 +13,12 @@ struct UnionFind final
{
Id makeSet();
Id find(Id id) const;
Id find(Id id);
void merge(Id a, Id b);
private:
std::vector<Id> parents;
std::vector<unsigned int> ranks;
};
} // namespace Luau::EqSat

View file

@ -10,6 +10,8 @@ Id UnionFind::makeSet()
{
Id id{parents.size()};
parents.push_back(id);
ranks.push_back(0);
return id;
}
@ -24,12 +26,37 @@ Id UnionFind::find(Id id) const
return id;
}
Id UnionFind::find(Id id)
{
LUAU_ASSERT(size_t(id) < parents.size());
// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
if (id != parents[size_t(id)])
// Note: we don't update the ranks here since a rank
// represents the upper bound on the maximum depth of a tree
parents[size_t(id)] = find(parents[size_t(id)]);
return parents[size_t(id)];
}
void UnionFind::merge(Id a, Id b)
{
LUAU_ASSERT(size_t(a) < parents.size());
LUAU_ASSERT(size_t(b) < parents.size());
Id aSet = find(a);
Id bSet = find(b);
if (aSet == bSet)
return;
// Ensure that the rank of set A is greater than the rank of set B
if (ranks[size_t(aSet)] < ranks[size_t(bSet)])
std::swap(a, b);
parents[size_t(b)] = a;
if (ranks[size_t(aSet)] == ranks[size_t(bSet)])
ranks[size_t(aSet)]++;
}
} // namespace Luau::EqSat