diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index bd8bdaaf..79128af9 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -1,12 +1,22 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Id.h" + +#include + namespace Luau::EqSat { /// See . struct UnionFind { + Id makeSet(); + Id find(Id id) const; + void merge(Id a, Id b); + +private: + std::vector parents; }; } // namespace Luau::EqSat diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 0bf171b5..9b991422 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -1,9 +1,34 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/UnionFind.h" +#include "Luau/Common.h" + namespace Luau::EqSat { +Id UnionFind::makeSet() +{ + Id id{parents.size()}; + parents.push_back(id); + return id; +} +Id UnionFind::find(Id id) const +{ + LUAU_ASSERT(size_t(id) < parents.size()); + + while (id != parents[size_t(id)]) + id = parents[size_t(id)]; + + return id; +} + +void UnionFind::merge(Id a, Id b) +{ + LUAU_ASSERT(size_t(a) < parents.size()); + LUAU_ASSERT(size_t(b) < parents.size()); + + parents[size_t(b)] = a; +} } // namespace Luau::EqSat