diff --git a/.gitignore b/.gitignore index 8de6d91d..8e5c95dd 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ /luau-analyze /luau-compile __pycache__ +.cache +.clangd diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 401784ae..dfcf8b6b 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -499,10 +499,11 @@ struct ClassType Tags tags; std::shared_ptr userData; ModuleName definitionModuleName; + std::optional definitionLocation; std::optional indexer; ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, - std::shared_ptr userData, ModuleName definitionModuleName) + std::shared_ptr userData, ModuleName definitionModuleName, std::optional definitionLocation) : name(name) , props(props) , parent(parent) @@ -510,11 +511,13 @@ struct ClassType , tags(tags) , userData(userData) , definitionModuleName(definitionModuleName) + , definitionLocation(definitionLocation) { } ClassType(Name name, Props props, std::optional parent, std::optional metatable, Tags tags, - std::shared_ptr userData, ModuleName definitionModuleName, std::optional indexer) + std::shared_ptr userData, ModuleName definitionModuleName, std::optional definitionLocation, + std::optional indexer) : name(name) , props(props) , parent(parent) @@ -522,6 +525,7 @@ struct ClassType , tags(tags) , userData(userData) , definitionModuleName(definitionModuleName) + , definitionLocation(definitionLocation) , indexer(indexer) { } diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index aa5e4ca8..bcfc21dd 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -1349,7 +1349,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas Name className(declaredClass->name.value); - TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, module->name)); + TypeId classTy = arena->addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, module->name, declaredClass->location)); ClassType* ctv = getMutable(classTy); TypeId metaTy = arena->addType(TableType{TableState::Sealed, scope->level, scope.get()}); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index fae4b60c..a0128b41 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -117,7 +117,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a { if (alwaysClone) { - ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.indexer}; + ClassType clone{a.name, a.props, a.parent, a.metatable, a.tags, a.userData, a.definitionModuleName, a.definitionLocation, a.indexer}; return dest.addType(std::move(clone)); } else diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 6735e367..9479d0ef 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -969,7 +969,7 @@ BuiltinTypes::BuiltinTypes() , threadType(arena->addType(Type{PrimitiveType{PrimitiveType::Thread}, /*persistent*/ true})) , bufferType(arena->addType(Type{PrimitiveType{PrimitiveType::Buffer}, /*persistent*/ true})) , functionType(arena->addType(Type{PrimitiveType{PrimitiveType::Function}, /*persistent*/ true})) - , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}}, /*persistent*/ true})) + , classType(arena->addType(Type{ClassType{"class", {}, std::nullopt, std::nullopt, {}, {}, {}, {}}, /*persistent*/ true})) , tableType(arena->addType(Type{PrimitiveType{PrimitiveType::Table}, /*persistent*/ true})) , emptyTableType(arena->addType(Type{TableType{TableState::Sealed, TypeLevel{}, nullptr}, /*persistent*/ true})) , trueType(arena->addType(Type{SingletonType{BooleanSingleton{true}}, /*persistent*/ true})) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index d4c25c34..00d683dd 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1733,7 +1733,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& de Name className(declaredClass.name.value); - TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModule->name)); + TypeId classTy = addType(ClassType(className, {}, superTy, std::nullopt, {}, {}, currentModule->name, declaredClass.location)); ClassType* ctv = getMutable(classTy); TypeId metaTy = addType(TableType{TableState::Sealed, scope->level}); diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b7e551e..34e104e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,6 +28,7 @@ add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Config STATIC) add_library(Luau.Analysis STATIC) +add_library(Luau.EqSat STATIC) add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) add_library(isocline STATIC) @@ -83,7 +84,11 @@ target_link_libraries(Luau.Config PUBLIC Luau.Ast) target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) -target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.Config) +target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config) + +target_compile_features(Luau.EqSat PUBLIC cxx_std_17) +target_include_directories(Luau.EqSat PUBLIC EqSat/include) +target_link_libraries(Luau.EqSat PUBLIC Luau.Common) target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) @@ -141,6 +146,7 @@ endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.EqSat PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.CLI.lib PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.CodeGen PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) @@ -263,13 +269,13 @@ endif() add_subdirectory(fuzz) # validate dependencies for internal libraries -foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.CodeGen Luau.VM) +foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.CodeGen Luau.VM) if(TARGET ${LIB}) get_target_property(DEPENDS ${LIB} LINK_LIBRARIES) if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler") message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components") endif() - if(LIB MATCHES "Ast|Analysis|Compiler" AND DEPENDS MATCHES "CodeGen|VM") + if(LIB MATCHES "Ast|Analysis|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components") endif() if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config") diff --git a/Analysis/include/Luau/Variant.h b/Common/include/Luau/Variant.h similarity index 100% rename from Analysis/include/Luau/Variant.h rename to Common/include/Luau/Variant.h diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h new file mode 100644 index 00000000..abccd70c --- /dev/null +++ b/EqSat/include/Luau/EGraph.h @@ -0,0 +1,228 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Id.h" +#include "Luau/Language.h" +#include "Luau/UnionFind.h" +#include "Luau/VecDeque.h" + +#include +#include +#include + +namespace Luau::EqSat +{ + +template +struct EGraph; + +template +struct Analysis final +{ + N analysis; + + using D = typename N::Data; + + template + static D fnMake(const N& analysis, const EGraph& egraph, const L& enode) + { + return analysis.make(egraph, *enode.template get()); + } + + template + D make(const EGraph& egraph, const Language& enode) const + { + using FnMake = D (*)(const N&, const EGraph&, const L&); + static constexpr FnMake tableMake[sizeof...(Ts)] = {&fnMake...}; + + return tableMake[enode.index()](analysis, egraph, enode); + } + + void join(D& a, const D& b) const + { + return analysis.join(a, b); + } +}; + +/// Each e-class is a set of e-nodes representing equivalent terms from a given language, +/// and an e-node is a function symbol paired with a list of children e-classes. +template +struct EClass final +{ + Id id; + std::vector nodes; + D data; + std::vector> parents; +}; + +/// See . +template +struct EGraph final +{ + Id find(Id id) const + { + return unionfind.find(id); + } + + std::optional lookup(const L& enode) const + { + LUAU_ASSERT(isCanonical(enode)); + + if (auto it = hashcons.find(enode); it != hashcons.end()) + return it->second; + + return std::nullopt; + } + + Id add(L enode) + { + canonicalize(enode); + + if (auto id = lookup(enode)) + return *id; + + Id id = makeEClass(enode); + return id; + } + + void merge(Id id1, Id id2) + { + id1 = find(id1); + id2 = find(id2); + if (id1 == id2) + return; + + unionfind.merge(id1, id2); + + EClass& eclass1 = get(id1); + EClass eclass2 = std::move(get(id2)); + classes.erase(id2); + + worklist.reserve(worklist.size() + eclass2.parents.size()); + for (auto [enode, id] : eclass2.parents) + worklist.push_back({std::move(enode), id}); + + analysis.join(eclass1.data, eclass2.data); + } + + void rebuild() + { + while (!worklist.empty()) + { + auto [enode, id] = worklist.back(); + worklist.pop_back(); + repair(get(find(id))); + } + } + + size_t size() const + { + return classes.size(); + } + + EClass& operator[](Id id) + { + return get(find(id)); + } + + const EClass& operator[](Id id) const + { + return const_cast(this)->get(find(id)); + } + +private: + Analysis analysis; + + /// A union-find data structure 𝑈 stores an equivalence relation over e-class ids. + UnionFind unionfind; + + /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same + /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the + /// e-class 𝑀[find(𝑎)]. + std::unordered_map> classes; + + /// The hashcons 𝐻 is a map from e-nodes to e-class ids. + std::unordered_map hashcons; + + VecDeque> worklist; + +private: + void canonicalize(L& enode) + { + // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where + // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). + for (Id& id : enode.operands()) + id = find(id); + } + + bool isCanonical(const L& enode) const + { + bool canonical = true; + for (Id id : enode.operands()) + canonical &= (id == find(id)); + return canonical; + } + + Id makeEClass(const L& enode) + { + LUAU_ASSERT(isCanonical(enode)); + + Id id = unionfind.makeSet(); + + classes.insert_or_assign(id, EClass{ + id, + {enode}, + analysis.make(*this, enode), + {}, + }); + + for (Id operand : enode.operands()) + get(operand).parents.push_back({enode, id}); + + worklist.push_back({enode, id}); + hashcons.insert_or_assign(enode, id); + + return id; + } + + // Looks up for an eclass from a given non-canonicalized `id`. + // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. + EClass& get(Id id) + { + return classes.at(id); + } + + void repair(EClass& eclass) + { + // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` + // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. + // And then in another loop that follows, deduplicate it. + // + // Here, we unify the two loops. I think it's equivalent? + + // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. + std::unordered_map map; + for (auto& [enode, id] : eclass.parents) + { + // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. + hashcons.erase(enode); + canonicalize(enode); + hashcons.insert_or_assign(enode, find(id)); + + if (auto it = map.find(enode); it != map.end()) + merge(id, it->second); + + map.insert_or_assign(enode, find(id)); + } + + eclass.parents.clear(); + for (auto it = map.begin(); it != map.end();) + { + auto node = map.extract(it++); + eclass.parents.emplace_back(std::move(node.key()), node.mapped()); + } + } +}; + +} // namespace Luau::EqSat diff --git a/EqSat/include/Luau/Id.h b/EqSat/include/Luau/Id.h new file mode 100644 index 00000000..c56a6ab6 --- /dev/null +++ b/EqSat/include/Luau/Id.h @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau::EqSat +{ + +struct Id final +{ + explicit Id(size_t id); + + explicit operator size_t() const; + + bool operator==(Id rhs) const; + bool operator!=(Id rhs) const; + +private: + size_t id; +}; + +} // namespace Luau::EqSat + +template<> +struct std::hash +{ + size_t operator()(Luau::EqSat::Id id) const; +}; diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h new file mode 100644 index 00000000..c17ac577 --- /dev/null +++ b/EqSat/include/Luau/Language.h @@ -0,0 +1,304 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Id.h" +#include "Luau/LanguageHash.h" +#include "Luau/Slice.h" +#include "Luau/Variant.h" + +#include +#include +#include +#include + +#define LUAU_EQSAT_ATOM(name, t) \ + struct name : public ::Luau::EqSat::Atom \ + { \ + static constexpr const char* tag = #name; \ + using Atom::Atom; \ + } + +#define LUAU_EQSAT_NODE_ARRAY(name, ops) \ + struct name : public ::Luau::EqSat::NodeVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeVector::NodeVector; \ + } + +#define LUAU_EQSAT_NODE_VECTOR(name) \ + struct name : public ::Luau::EqSat::NodeVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeVector::NodeVector; \ + } + +#define LUAU_EQSAT_FIELD(name) \ + struct name : public ::Luau::EqSat::Field \ + { \ + } + +#define LUAU_EQSAT_NODE_FIELDS(name, ...) \ + struct name : public ::Luau::EqSat::NodeFields \ + { \ + static constexpr const char* tag = #name; \ + using NodeFields::NodeFields; \ + } + +namespace Luau::EqSat +{ + +template +struct Atom +{ + Atom(const T& value) + : _value(value) + { + } + + const T& value() const + { + return _value; + } + +public: + Slice operands() + { + return {}; + } + + Slice operands() const + { + return {}; + } + + bool operator==(const Atom& rhs) const + { + return _value == rhs._value; + } + + bool operator!=(const Atom& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const Atom& value) const + { + return languageHash(value._value); + } + }; + +private: + T _value; +}; + +template +struct NodeVector +{ + template + NodeVector(Args&&... args) + : vector{std::forward(args)...} + { + } + + Id operator[](size_t i) const + { + return vector[i]; + } + +public: + Slice operands() + { + return Slice{vector.data(), vector.size()}; + } + + Slice operands() const + { + return Slice{vector.data(), vector.size()}; + } + + bool operator==(const NodeVector& rhs) const + { + return vector == rhs.vector; + } + + bool operator!=(const NodeVector& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const NodeVector& value) const + { + return languageHash(value.vector); + } + }; + +private: + T vector; +}; + +/// Empty base class just for static_asserts. +struct FieldBase +{ + FieldBase() = delete; + + FieldBase(FieldBase&&) = delete; + FieldBase& operator=(FieldBase&&) = delete; + + FieldBase(const FieldBase&) = delete; + FieldBase& operator=(const FieldBase&) = delete; +}; + +template +struct Field : FieldBase +{ +}; + +template +struct NodeFields +{ + static_assert(std::conjunction...>::value); + + template + static constexpr int getIndex() + { + constexpr int N = sizeof...(Fields); + constexpr bool is[N] = {std::is_same_v, Fields>...}; + + for (int i = 0; i < N; ++i) + if (is[i]) + return i; + + return -1; + } + +public: + template + NodeFields(Args&&... args) + : array{std::forward(args)...} + { + } + + Slice operands() + { + return Slice{array}; + } + + Slice operands() const + { + return Slice{array.data(), array.size()}; + } + + template + Id field() const + { + static_assert(std::disjunction_v, Fields>...>); + return array[getIndex()]; + } + + bool operator==(const NodeFields& rhs) const + { + return array == rhs.array; + } + + bool operator!=(const NodeFields& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const NodeFields& value) const + { + return languageHash(value.array); + } + }; + +private: + std::array array; +}; + +template +struct Language final +{ + template + using WithinDomain = std::disjunction, Ts>...>; + + template + Language(T&& t, std::enable_if_t::value>* = 0) noexcept + : v(std::forward(t)) + { + } + + Language(const Language&) noexcept = default; + Language& operator=(const Language&) noexcept = default; + + Language(Language&&) noexcept = default; + Language& operator=(Language&&) noexcept = default; + + int index() const noexcept + { + return v.index(); + } + + /// You should never call this function with the intention of mutating the `Id`. + /// Reading is ok, but you should also never assume that these `Id`s are stable. + Slice operands() noexcept + { + return visit([](auto&& v) -> Slice { + return v.operands(); + }, v); + } + + Slice operands() const noexcept + { + return visit([](auto&& v) -> Slice { + return v.operands(); + }, v); + } + + template + T* get() noexcept + { + static_assert(WithinDomain::value); + return v.template get_if(); + } + + template + const T* get() const noexcept + { + static_assert(WithinDomain::value); + return v.template get_if(); + } + + bool operator==(const Language& rhs) const noexcept + { + return v == rhs.v; + } + + bool operator!=(const Language& rhs) const noexcept + { + return !(*this == rhs); + } + +public: + struct Hash + { + size_t operator()(const Language& language) const + { + size_t seed = std::hash{}(language.index()); + hashCombine(seed, visit([](auto&& v) { + return typename std::decay_t::Hash{}(v); + }, language.v)); + return seed; + } + }; + +private: + Variant v; +}; + +} // namespace Luau::EqSat diff --git a/EqSat/include/Luau/LanguageHash.h b/EqSat/include/Luau/LanguageHash.h new file mode 100644 index 00000000..8c5f837c --- /dev/null +++ b/EqSat/include/Luau/LanguageHash.h @@ -0,0 +1,56 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau::EqSat +{ + +template +struct LanguageHash +{ + size_t operator()(const T& t, decltype(std::hash{}(std::declval()))* = 0) const + { + return std::hash{}(t); + } +}; + +template +size_t languageHash(const T& lang) +{ + return LanguageHash{}(lang); +} + +inline void hashCombine(size_t& seed, size_t hash) +{ + // Golden Ratio constant used for better hash scattering + // See https://softwareengineering.stackexchange.com/a/402543 + seed ^= hash + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +template +struct LanguageHash> +{ + size_t operator()(const std::array& array) const + { + size_t seed = 0; + for (const T& t : array) + hashCombine(seed, languageHash(t)); + return seed; + } +}; + +template +struct LanguageHash> +{ + size_t operator()(const std::vector& vector) const + { + size_t seed = 0; + for (const T& t : vector) + hashCombine(seed, languageHash(t)); + return seed; + } +}; + +} // namespace Luau::EqSat diff --git a/EqSat/include/Luau/Slice.h b/EqSat/include/Luau/Slice.h new file mode 100644 index 00000000..c1c8f098 --- /dev/null +++ b/EqSat/include/Luau/Slice.h @@ -0,0 +1,78 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include +#include + +namespace Luau::EqSat +{ + +template +struct Slice final +{ + Slice() + : _data(nullptr) + , _size(0) + { + } + + /// Use this constructor if you have a dynamically sized vector. + /// The slice is valid for as long as the backing vector has not moved + /// elsewhere in memory. + /// + /// In general, a slice should never be used from vectors except for + /// any vectors whose size are statically unknown, but remains fixed + /// upon the construction of such a slice over a vector. + Slice(T* first, size_t last) + : _data(first) + , _size(last) + { + } + + template + explicit Slice(std::array& array) + : _data(array.data()) + , _size(array.size()) + { + } + + T* data() const + { + return _data; + } + + size_t size() const + { + return _size; + } + + bool empty() const + { + return _size == 0; + } + + T& operator[](size_t i) const + { + LUAU_ASSERT(i < _size); + return _data[i]; + } + +public: + T* _data; + size_t _size; + +public: + T* begin() const + { + return _data; + } + + T* end() const + { + return _data + _size; + } +}; + +} // namespace Luau::EqSat diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h new file mode 100644 index 00000000..559ee119 --- /dev/null +++ b/EqSat/include/Luau/UnionFind.h @@ -0,0 +1,27 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Id.h" + +#include + +namespace Luau::EqSat +{ + +/// See . +struct UnionFind final +{ + Id makeSet(); + Id find(Id id) const; + Id find(Id id); + void merge(Id a, Id b); + +private: + std::vector parents; + std::vector ranks; + +private: + Id canonicalize(Id id) const; +}; + +} // namespace Luau::EqSat diff --git a/EqSat/src/Id.cpp b/EqSat/src/Id.cpp new file mode 100644 index 00000000..960249ba --- /dev/null +++ b/EqSat/src/Id.cpp @@ -0,0 +1,32 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Id.h" + +namespace Luau::EqSat +{ + +Id::Id(size_t id) + : id(id) +{ +} + +Id::operator size_t() const +{ + return id; +} + +bool Id::operator==(Id rhs) const +{ + return id == rhs.id; +} + +bool Id::operator!=(Id rhs) const +{ + return id != rhs.id; +} + +} // namespace Luau::EqSat + +size_t std::hash::operator()(Luau::EqSat::Id id) const +{ + return std::hash()(size_t(id)); +} diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp new file mode 100644 index 00000000..5c01e968 --- /dev/null +++ b/EqSat/src/UnionFind.cpp @@ -0,0 +1,68 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/UnionFind.h" + +#include "Luau/Common.h" + +namespace Luau::EqSat +{ + +Id UnionFind::makeSet() +{ + Id id{parents.size()}; + parents.push_back(id); + ranks.push_back(0); + + return id; +} + +Id UnionFind::find(Id id) const +{ + return canonicalize(id); +} + +Id UnionFind::find(Id id) +{ + Id set = canonicalize(id); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + while (id != parents[size_t(id)]) + { + // Note: we don't update the ranks here since a rank + // represents the upper bound on the maximum depth of a tree + Id parent = parents[size_t(id)]; + parents[size_t(id)] = set; + id = parent; + } + + return set; +} + +void UnionFind::merge(Id a, Id b) +{ + Id aSet = find(a); + Id bSet = find(b); + if (aSet == bSet) + return; + + // Ensure that the rank of set A is greater than the rank of set B + if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + std::swap(aSet, bSet); + + parents[size_t(bSet)] = aSet; + + if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) + ranks[size_t(aSet)]++; +} + +Id UnionFind::canonicalize(Id id) const +{ + LUAU_ASSERT(size_t(id) < parents.size()); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + while (id != parents[size_t(id)]) + id = parents[size_t(id)]; + + return id; +} + +} // namespace Luau::EqSat diff --git a/Makefile b/Makefile index 0f7c1927..7eead323 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,10 @@ ANALYSIS_SOURCES=$(wildcard Analysis/src/*.cpp) ANALYSIS_OBJECTS=$(ANALYSIS_SOURCES:%=$(BUILD)/%.o) ANALYSIS_TARGET=$(BUILD)/libluauanalysis.a +EQSAT_SOURCES=$(wildcard EqSat/src/*.cpp) +EQSAT_OBJECTS=$(EQSAT_SOURCES:%=$(BUILD)/%.o) +EQSAT_TARGET=$(BUILD)/libluaueqsat.a + CODEGEN_SOURCES=$(wildcard CodeGen/src/*.cpp) CODEGEN_OBJECTS=$(CODEGEN_SOURCES:%=$(BUILD)/%.o) CODEGEN_TARGET=$(BUILD)/libluaucodegen.a @@ -69,7 +73,7 @@ ifneq ($(opt),) TESTS_ARGS+=-O$(opt) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(BYTECODE_CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(EQSAT_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(BYTECODE_CLI_OBJECTS) $(FUZZ_OBJECTS) EXECUTABLE_ALIASES = luau luau-analyze luau-compile luau-bytecode luau-tests # common flags @@ -138,16 +142,17 @@ endif $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include $(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include -$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include +$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include +$(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include -Iextern +$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern $(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include $(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include -IConfig/include +$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IEqSat/include -IVM/include -ICodeGen/include -IConfig/include $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread @@ -218,9 +223,9 @@ luau-tests: $(TESTS_TARGET) ln -fs $^ $@ # executable targets -$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) +$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) $(CONFIG_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) @@ -228,22 +233,23 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(COMPILE_CLI_TARGET) $ $(CXX) $^ $(LDFLAGS) -o $@ # executable targets for fuzzing -fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) +fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(CXX) $^ $(LDFLAGS) -o $@ -fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator # static library targets $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) $(CONFIG_TARGET): $(CONFIG_OBJECTS) $(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS) +$(EQSAT_TARGET): $(EQSAT_OBJECTS) $(CODEGEN_TARGET): $(CODEGEN_OBJECTS) $(VM_TARGET): $(VM_OBJECTS) $(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS) -$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): +$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): ar rcs $@ $^ # object file targets diff --git a/Sources.cmake b/Sources.cmake index 72038e70..c6bbfdbc 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -7,6 +7,7 @@ if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19") Common/include/Luau/BytecodeUtils.h Common/include/Luau/DenseHash.h Common/include/Luau/ExperimentalFlags.h + Common/include/Luau/Variant.h Common/include/Luau/VecDeque.h ) endif() @@ -232,7 +233,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Unifier.h Analysis/include/Luau/Unifier2.h Analysis/include/Luau/UnifierSharedState.h - Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitType.h Analysis/src/Anyification.cpp @@ -295,6 +295,19 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Unifier2.cpp ) +# Luau.EqSat Sources +target_sources(Luau.EqSat PRIVATE + EqSat/include/Luau/EGraph.h + EqSat/include/Luau/Id.h + EqSat/include/Luau/Language.h + EqSat/include/Luau/LanguageHash.h + EqSat/include/Luau/Slice.h + EqSat/include/Luau/UnionFind.h + + EqSat/src/Id.cpp + EqSat/src/UnionFind.cpp +) + # Luau.VM Sources target_sources(Luau.VM PRIVATE VM/include/lua.h @@ -418,6 +431,9 @@ if(TARGET Luau.UnitTest) tests/DiffAsserts.cpp tests/DiffAsserts.h tests/Differ.test.cpp + tests/EqSat.language.test.cpp + tests/EqSat.propositional.test.cpp + tests/EqSat.slice.test.cpp tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index d06189a4..b9af5247 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -124,7 +124,7 @@ int registerTypes(Luau::Frontend& frontend, Luau::GlobalTypes& globals, bool for // Vector3 stub TypeId vector3MetaType = arena.addType(TableType{}); - TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); + TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test", {}}); getMutable(vector3InstanceType)->props = { {"X", {builtinTypes.numberType}}, {"Y", {builtinTypes.numberType}}, @@ -138,7 +138,7 @@ int registerTypes(Luau::Frontend& frontend, Luau::GlobalTypes& globals, bool for globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub - TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(instanceType)->props = { {"Name", {builtinTypes.stringType}}, }; @@ -146,7 +146,7 @@ int registerTypes(Luau::Frontend& frontend, Luau::GlobalTypes& globals, bool for globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub - TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); + TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test", {}}); getMutable(partType)->props = { {"Position", {vector3InstanceType}}, }; diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 7e35e40a..c369cb30 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -18,9 +18,9 @@ ClassFixture::ClassFixture() unfreeze(arena); - TypeId connectionType = arena.addType(ClassType{"Connection", {}, nullopt, nullopt, {}, {}, "Connection"}); + TypeId connectionType = arena.addType(ClassType{"Connection", {}, nullopt, nullopt, {}, {}, "Connection", {}}); - TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(baseClassInstanceType)->props = { {"BaseMethod", Property::readonly(makeFunction(arena, baseClassInstanceType, {numberType}, {}))}, {"BaseField", {numberType}}, @@ -31,7 +31,7 @@ ClassFixture::ClassFixture() getMutable(connectionType)->props = { {"Connect", {makeFunction(arena, connectionType, {makeFunction(arena, nullopt, {baseClassInstanceType}, {})}, {})}}}; - TypeId baseClassType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId baseClassType = arena.addType(ClassType{"BaseClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(baseClassType)->props = { {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, @@ -40,48 +40,48 @@ ClassFixture::ClassFixture() globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; addGlobalBinding(globals, "BaseClass", baseClassType, "@test"); - TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test", {}}); getMutable(childClassInstanceType)->props = { {"Method", {makeFunction(arena, childClassInstanceType, {}, {stringType})}}, }; - TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); + TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test", {}}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; addGlobalBinding(globals, "ChildClass", childClassType, "@test"); - TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test", {}}); getMutable(grandChildInstanceType)->props = { {"Method", {makeFunction(arena, grandChildInstanceType, {}, {stringType})}}, }; - TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test", {}}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; globals.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; addGlobalBinding(globals, "GrandChild", childClassType, "@test"); - TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); + TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test", {}}); getMutable(anotherChildInstanceType)->props = { {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {stringType})}}, }; - TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); + TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test", {}}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; globals.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; addGlobalBinding(globals, "AnotherChild", childClassType, "@test"); - TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); - TypeId unrelatedClassType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId unrelatedClassType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(unrelatedClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, }; @@ -90,13 +90,13 @@ ClassFixture::ClassFixture() TypeId vector2MetaType = arena.addType(TableType{}); - vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test"}); + vector2InstanceType = arena.addType(ClassType{"Vector2", {}, nullopt, vector2MetaType, {}, {}, "Test", {}}); getMutable(vector2InstanceType)->props = { {"X", {numberType}}, {"Y", {numberType}}, }; - vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test"}); + vector2Type = arena.addType(ClassType{"Vector2", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(vector2Type)->props = { {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, }; @@ -110,7 +110,7 @@ ClassFixture::ClassFixture() addGlobalBinding(globals, "Vector2", vector2Type, "@test"); TypeId callableClassMetaType = arena.addType(TableType{}); - TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); + TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test", {}}); getMutable(callableClassMetaType)->props = { {"__call", {makeFunction(arena, nullopt, {callableClassType, stringType}, {numberType})}}, }; @@ -119,7 +119,7 @@ ClassFixture::ClassFixture() auto addIndexableClass = [&arena, &globals](const char* className, TypeId keyType, TypeId returnType) { TypeId indexableClassMetaType = arena.addType(TableType{}); TypeId indexableClassType = - arena.addType(ClassType{className, {}, nullopt, indexableClassMetaType, {}, {}, "Test", TableIndexer{keyType, returnType}}); + arena.addType(ClassType{className, {}, nullopt, indexableClassMetaType, {}, {}, "Test", {}, TableIndexer{keyType, returnType}}); globals.globalScope->exportedTypeBindings[className] = TypeFun{{}, indexableClassType}; }; diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp new file mode 100644 index 00000000..282d4ad2 --- /dev/null +++ b/tests/EqSat.language.test.cpp @@ -0,0 +1,144 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include +#include + +LUAU_EQSAT_ATOM(I32, int); +LUAU_EQSAT_ATOM(Bool, bool); +LUAU_EQSAT_ATOM(Str, std::string); + +LUAU_EQSAT_FIELD(Left); +LUAU_EQSAT_FIELD(Right); +LUAU_EQSAT_NODE_FIELDS(Add, Left, Right); + +using namespace Luau; + +using Value = EqSat::Language; + +TEST_SUITE_BEGIN("EqSatLanguage"); + +TEST_CASE("atom_equality") +{ + CHECK(I32{0} == I32{0}); + CHECK(I32{0} != I32{1}); +} + +TEST_CASE("node_equality") +{ + CHECK(Add{EqSat::Id{0}, EqSat::Id{0}} == Add{EqSat::Id{0}, EqSat::Id{0}}); + CHECK(Add{EqSat::Id{1}, EqSat::Id{0}} != Add{EqSat::Id{0}, EqSat::Id{0}}); +} + +TEST_CASE("language_get") +{ + Value v{I32{5}}; + + auto i = v.get(); + REQUIRE(i); + CHECK(i->value()); + + CHECK(!v.get()); +} + +TEST_CASE("language_copy_ctor") +{ + Value v1{I32{5}}; + Value v2 = v1; + + auto i1 = v1.get(); + auto i2 = v2.get(); + REQUIRE(i1); + REQUIRE(i2); + CHECK(i1->value() == i2->value()); +} + +TEST_CASE("language_move_ctor") +{ + Value v1{Str{"hello"}}; + { + auto s1 = v1.get(); + REQUIRE(s1); + CHECK(s1->value() == "hello"); + } + + Value v2 = std::move(v1); + + auto s1 = v1.get(); + REQUIRE(s1); + CHECK(s1->value() == ""); // this also tests the dtor. + + auto s2 = v2.get(); + REQUIRE(s2); + CHECK(s2->value() == "hello"); +} + +TEST_CASE("language_equality") +{ + Value v1{I32{0}}; + Value v2{I32{0}}; + Value v3{I32{1}}; + Value v4{Bool{true}}; + Value v5{Add{EqSat::Id{0}, EqSat::Id{1}}}; + + CHECK(v1 == v2); + CHECK(v2 != v3); + CHECK(v3 != v4); + CHECK(v4 != v5); +} + +TEST_CASE("language_is_mappable") +{ + std::unordered_map map; + + Value v1{I32{5}}; + Value v2{I32{5}}; + Value v3{Bool{true}}; + Value v4{Add{EqSat::Id{0}, EqSat::Id{1}}}; + + map[v1] = 1; + map[v2] = 2; + map[v3] = 42; + map[v4] = 37; + + CHECK(map[v1] == 2); + CHECK(map[v2] == 2); + CHECK(map[v3] == 42); + CHECK(map[v4] == 37); +} + +TEST_CASE("node_field") +{ + EqSat::Id left{0}; + EqSat::Id right{1}; + + Add add{left, right}; + + EqSat::Id left2 = add.field(); + EqSat::Id right2 = add.field(); + + CHECK(left == left2); + CHECK(left != right2); + CHECK(right == right2); + CHECK(right != left2); +} + +TEST_CASE("language_operands") +{ + Value v1{I32{0}}; + CHECK(v1.operands().empty()); + + Value v2{Add{EqSat::Id{0}, EqSat::Id{1}}}; + const Add* add = v2.get(); + REQUIRE(add); + + EqSat::Slice actual = v2.operands(); + CHECK(actual.size() == 2); + CHECK(actual[0] == add->field()); + CHECK(actual[1] == add->field()); +} + +TEST_SUITE_END(); diff --git a/tests/EqSat.propositional.test.cpp b/tests/EqSat.propositional.test.cpp new file mode 100644 index 00000000..5b2d34b4 --- /dev/null +++ b/tests/EqSat.propositional.test.cpp @@ -0,0 +1,197 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include + +LUAU_EQSAT_ATOM(Var, std::string); +LUAU_EQSAT_ATOM(Bool, bool); +LUAU_EQSAT_NODE_ARRAY(Not, 1); +LUAU_EQSAT_NODE_ARRAY(And, 2); +LUAU_EQSAT_NODE_ARRAY(Or, 2); +LUAU_EQSAT_NODE_ARRAY(Implies, 2); + +using namespace Luau; + +using PropositionalLogic = EqSat::Language; + +using EGraph = EqSat::EGraph; + +struct ConstantFold +{ + using Data = std::optional; + + Data make(const EGraph& egraph, const Var& var) const + { + return std::nullopt; + } + + Data make(const EGraph& egraph, const Bool& b) const + { + return b.value(); + } + + Data make(const EGraph& egraph, const Not& n) const + { + Data data = egraph[n[0]].data; + if (data) + return !*data; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const And& a) const + { + Data l = egraph[a[0]].data; + Data r = egraph[a[1]].data; + if (l && r) + return *l && *r; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const Or& o) const + { + Data l = egraph[o[0]].data; + Data r = egraph[o[1]].data; + if (l && r) + return *l || *r; + + return std::nullopt; + } + + Data make(const EGraph& egraph, const Implies& i) const + { + Data antecedent = egraph[i[0]].data; + Data consequent = egraph[i[1]].data; + if (antecedent && consequent) + return !*antecedent || *consequent; + + return std::nullopt; + } + + void join(Data& a, const Data& b) const + { + if (!a && b) + a = b; + } +}; + +TEST_SUITE_BEGIN("EqSatPropositionalLogic"); + +TEST_CASE("egraph_hashconsing") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{true}); + EqSat::Id id3 = egraph.add(Bool{false}); + + CHECK(id1 == id2); + CHECK(id2 != id3); +} + +TEST_CASE("egraph_data") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{false}); + + CHECK(egraph[id1].data == true); + CHECK(egraph[id2].data == false); +} + +TEST_CASE("egraph_merge") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Var{"a"}); + EqSat::Id id2 = egraph.add(Bool{true}); + egraph.merge(id1, id2); + + CHECK(egraph[id1].data == true); + CHECK(egraph[id2].data == true); +} + +TEST_CASE("const_fold_true_and_true") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{true}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == true); +} + +TEST_CASE("const_fold_true_and_false") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{true}); + EqSat::Id id2 = egraph.add(Bool{false}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == false); +} + +TEST_CASE("const_fold_false_and_false") +{ + EGraph egraph; + + EqSat::Id id1 = egraph.add(Bool{false}); + EqSat::Id id2 = egraph.add(Bool{false}); + EqSat::Id id3 = egraph.add(And{id1, id2}); + + CHECK(egraph[id3].data == false); +} + +TEST_CASE("implications") +{ + EGraph egraph; + + EqSat::Id t = egraph.add(Bool{true}); + EqSat::Id f = egraph.add(Bool{false}); + + EqSat::Id a = egraph.add(Implies{t, t}); // true + EqSat::Id b = egraph.add(Implies{t, f}); // false + EqSat::Id c = egraph.add(Implies{f, t}); // true + EqSat::Id d = egraph.add(Implies{f, f}); // true + + CHECK(egraph[a].data == true); + CHECK(egraph[b].data == false); + CHECK(egraph[c].data == true); + CHECK(egraph[d].data == true); +} + +TEST_CASE("merge_x_and_y") +{ + EGraph egraph; + + EqSat::Id x = egraph.add(Var{"x"}); + EqSat::Id y = egraph.add(Var{"y"}); + + EqSat::Id a = egraph.add(Var{"a"}); + EqSat::Id ax = egraph.add(And{a, x}); + EqSat::Id ay = egraph.add(And{a, y}); + + egraph.merge(x, y); // [x y] [ax] [ay] [a] + CHECK_EQ(egraph.size(), 4); + CHECK_EQ(egraph.find(x), egraph.find(y)); + CHECK_NE(egraph.find(ax), egraph.find(ay)); + CHECK_NE(egraph.find(a), egraph.find(x)); + CHECK_NE(egraph.find(a), egraph.find(y)); + + egraph.rebuild(); // [x y] [ax ay] [a] + CHECK_EQ(egraph.size(), 3); + CHECK_EQ(egraph.find(x), egraph.find(y)); + CHECK_EQ(egraph.find(ax), egraph.find(ay)); + CHECK_NE(egraph.find(a), egraph.find(x)); + CHECK_NE(egraph.find(a), egraph.find(y)); +} + +TEST_SUITE_END(); diff --git a/tests/EqSat.slice.test.cpp b/tests/EqSat.slice.test.cpp new file mode 100644 index 00000000..26ca3bfd --- /dev/null +++ b/tests/EqSat.slice.test.cpp @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include + +#include "Luau/Slice.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("EqSatSlice"); + +TEST_CASE("slice_is_a_view_over_array") +{ + std::array a{1, 2, 3, 4, 5, 6, 7, 8}; + + EqSat::Slice slice{a}; + + CHECK(slice.data() == a.data()); + CHECK(slice.size() == a.size()); + + for (size_t i = 0; i < a.size(); ++i) + { + CHECK(slice[i] == a[i]); + CHECK(&slice[i] == &a[i]); + } +} + +TEST_CASE("slice_is_a_view_over_vector") +{ + std::vector vector{1, 2, 3, 4, 5, 6, 7, 8}; + + EqSat::Slice slice{vector.data(), vector.size()}; + + CHECK(slice.data() == vector.data()); + CHECK(slice.size() == vector.size()); + + for (size_t i = 0; i < vector.size(); ++i) + { + CHECK(slice[i] == vector[i]); + CHECK(&slice[i] == &vector[i]); + } +} + +TEST_CASE("mutate_via_slice") +{ + std::array a{1, 2}; + CHECK(a[0] == 1); + CHECK(a[1] == 2); + + EqSat::Slice slice{a}; + slice[0] = 42; + slice[1] = 37; + + CHECK(a[0] == 42); + CHECK(a[1] == 37); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 3651dfeb..ef0731fa 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -658,7 +658,7 @@ void createSomeClasses(Frontend* frontend) ScopePtr moduleScope = globals.globalScope; - TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); + TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test", {}}); ClassType* parentClass = getMutable(parentType); parentClass->props["method"] = {makeFunction(arena, parentType, {}, {})}; @@ -668,17 +668,17 @@ void createSomeClasses(Frontend* frontend) addGlobalBinding(globals, "Parent", {parentType}); moduleScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; - TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test", {}}); addGlobalBinding(globals, "Child", {childType}); moduleScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; - TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test"}); + TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test", {}}); addGlobalBinding(globals, "AnotherChild", {anotherChildType}); moduleScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildType}; - TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); + TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test", {}}); addGlobalBinding(globals, "Unrelated", {unrelatedType}); moduleScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 901461ae..dabdf258 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -112,7 +112,7 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_ge { auto [propTy, _] = freshType(); - TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, ""}); + TypeId cursedClass = arena.addType(ClassType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, "", {}}); auto genClass = generalize(cursedClass); REQUIRE(genClass); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 807b5e73..b758764d 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1485,7 +1485,7 @@ TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { unfreeze(frontend.globals.globalTypes); - TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test", {}}); persist(instanceType); frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index dd7538ae..3894881c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -244,13 +244,13 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { {"__add", {builtinTypes->anyType}}, }, - std::nullopt, std::nullopt, {}, {}, "Test"}}; + std::nullopt, std::nullopt, {}, {}, "Test", {}}}; Type exampleClass{ClassType{"ExampleClass", { {"PropOne", {builtinTypes->numberType}}, {"PropTwo", {builtinTypes->stringType}}, }, - std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; + std::nullopt, &exampleMetaClass, {}, {}, "Test", {}}}; TypeArena dest; CloneState cloneState{builtinTypes}; diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index e70ec1ae..832049a4 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -146,7 +146,7 @@ struct SubtypeFixture : Fixture TypeId cls(const std::string& name, std::optional parent = std::nullopt) { - return arena.addType(ClassType{name, {}, parent.value_or(builtinTypes->classType), {}, {}, nullptr, ""}); + return arena.addType(ClassType{name, {}, parent.value_or(builtinTypes->classType), {}, {}, nullptr, "", {}}); } TypeId cls(const std::string& name, ClassType::Props&& props) diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 1a0fe411..16f4dc48 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -21,13 +21,13 @@ struct ToDotClassFixture : Fixture TypeId baseClassMetaType = arena.addType(TableType{}); - TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); + TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test", {}}); getMutable(baseClassInstanceType)->props = { {"BaseField", {builtinTypes->numberType}}, }; frontend.globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); + TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test", {}}); getMutable(childClassInstanceType)->props = { {"ChildField", {builtinTypes->stringType}}, }; diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index f90324d7..90271e29 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -706,19 +706,19 @@ TEST_CASE_FIXTURE(Fixture, "read_write_class_properties") unfreeze(arena); - TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test", {}}); getMutable(instanceType)->props = {{"Parent", Property::rw(instanceType)}}; // - TypeId workspaceType = arena.addType(ClassType{"Workspace", {}, nullopt, nullopt, {}, {}, "Test"}); + TypeId workspaceType = arena.addType(ClassType{"Workspace", {}, nullopt, nullopt, {}, {}, "Test", {}}); TypeId scriptType = - arena.addType(ClassType{"Script", {{"Parent", Property::rw(workspaceType, instanceType)}}, instanceType, nullopt, {}, {}, "Test"}); + arena.addType(ClassType{"Script", {{"Parent", Property::rw(workspaceType, instanceType)}}, instanceType, nullopt, {}, {}, "Test", {}}); TypeId partType = arena.addType( ClassType{"Part", {{"BrickColor", Property::rw(builtinTypes->stringType)}, {"Parent", Property::rw(workspaceType, instanceType)}}, - instanceType, nullopt, {}, {}, "Test"}); + instanceType, nullopt, {}, {}, "Test", {}}); getMutable(workspaceType)->props = {{"Script", Property::readonly(scriptType)}, {"Part", Property::readonly(partType)}}; diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 868aa9f2..b0d509ac 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -66,14 +66,14 @@ struct RefinementClassFixture : BuiltinsFixture std::optional rootSuper = std::make_optional(builtinTypes->classType); unfreeze(arena); - TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); + TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test", {}}); getMutable(vec3)->props = { {"X", Property{builtinTypes->numberType}}, {"Y", Property{builtinTypes->numberType}}, {"Z", Property{builtinTypes->numberType}}, }; - TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); + TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test", {}}); TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); @@ -86,8 +86,8 @@ struct RefinementClassFixture : BuiltinsFixture {"IsA", Property{isA}}, }; - TypeId folder = frontend.globals.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - TypeId part = frontend.globals.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId folder = frontend.globals.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test", {}}); + TypeId part = frontend.globals.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test", {}}); getMutable(part)->props = { {"Position", Property{vec3}}, }; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 674a4155..683e9027 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -314,7 +314,7 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; + Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test", {}}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); CHECK(Luau::hasTag(&base, "foo")); @@ -322,8 +322,8 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}}; - Type derived{ClassType{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test"}}; + Type base{ClassType{"Base", {}, std::nullopt, std::nullopt, {}, nullptr, "Test", {}}}; + Type derived{ClassType{"Derived", {}, &base, std::nullopt, {}, nullptr, "Test", {}}}; CHECK(!Luau::hasTag(&base, "foo")); CHECK(!Luau::hasTag(&derived, "foo"));