diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index 4f9e56fd..e319b5f1 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -19,6 +19,9 @@ struct UnionFind final private: std::vector parents; std::vector ranks; + +private: + Id canonicalize(Id id) const; }; } // namespace Luau::EqSat diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 09418a4a..d6f13a04 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -17,20 +17,14 @@ Id UnionFind::makeSet() Id UnionFind::find(Id id) const { - LUAU_ASSERT(size_t(id) < parents.size()); - - // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - while (id != parents[size_t(id)]) - id = parents[size_t(id)]; - - return id; + return canonicalize(id); } Id UnionFind::find(Id id) { LUAU_ASSERT(size_t(id) < parents.size()); - Id set = const_cast(this)->find(id); + Id set = canonicalize(id); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. while (id != parents[size_t(id)]) @@ -62,4 +56,15 @@ void UnionFind::merge(Id a, Id b) ranks[size_t(aSet)]++; } +Id UnionFind::canonicalize(Id id) const +{ + LUAU_ASSERT(size_t(id) < parents.size()); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + while (id != parents[size_t(id)]) + id = parents[size_t(id)]; + + return id; +} + } // namespace Luau::EqSat