diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index dd886a44..4f9e56fd 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -13,10 +13,12 @@ struct UnionFind final { Id makeSet(); Id find(Id id) const; + Id find(Id id); void merge(Id a, Id b); private: std::vector parents; + std::vector ranks; }; } // namespace Luau::EqSat diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 04d9ba74..f42b8a17 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -10,6 +10,8 @@ Id UnionFind::makeSet() { Id id{parents.size()}; parents.push_back(id); + ranks.push_back(0); + return id; } @@ -24,12 +26,37 @@ Id UnionFind::find(Id id) const return id; } +Id UnionFind::find(Id id) +{ + LUAU_ASSERT(size_t(id) < parents.size()); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + if (id != parents[size_t(id)]) + // Note: we don't update the ranks here since a rank + // represents the upper bound on the maximum depth of a tree + parents[size_t(id)] = find(parents[size_t(id)]); + + return parents[size_t(id)]; +} + void UnionFind::merge(Id a, Id b) { LUAU_ASSERT(size_t(a) < parents.size()); LUAU_ASSERT(size_t(b) < parents.size()); + + Id aSet = find(a); + Id bSet = find(b); + if (aSet == bSet) + return; + + // Ensure that the rank of set A is greater than the rank of set B + if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + std::swap(a, b); parents[size_t(b)] = a; + + if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) + ranks[size_t(aSet)]++; } } // namespace Luau::EqSat