// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include "Luau/EGraph.h" #include "Luau/Id.h" #include "Luau/Language.h" #include "Luau/Lexer.h" // For Allocator #include "Luau/NotNull.h" #include "Luau/TypeArena.h" #include "Luau/TypeFwd.h" namespace Luau { struct TypeFunction; } namespace Luau::EqSatSimplification { using StringId = uint32_t; using Id = Luau::EqSat::Id; LUAU_EQSAT_UNIT(TNil); LUAU_EQSAT_UNIT(TBoolean); LUAU_EQSAT_UNIT(TNumber); LUAU_EQSAT_UNIT(TString); LUAU_EQSAT_UNIT(TThread); LUAU_EQSAT_UNIT(TTopFunction); LUAU_EQSAT_UNIT(TTopTable); LUAU_EQSAT_UNIT(TTopClass); LUAU_EQSAT_UNIT(TBuffer); // Used for any type that eqsat can't do anything interesting with. LUAU_EQSAT_ATOM(TOpaque, TypeId); LUAU_EQSAT_ATOM(SBoolean, bool); LUAU_EQSAT_ATOM(SString, StringId); LUAU_EQSAT_ATOM(TFunction, TypeId); LUAU_EQSAT_ATOM(TImportedTable, TypeId); LUAU_EQSAT_ATOM(TClass, TypeId); LUAU_EQSAT_UNIT(TAny); LUAU_EQSAT_UNIT(TError); LUAU_EQSAT_UNIT(TUnknown); LUAU_EQSAT_UNIT(TNever); LUAU_EQSAT_NODE_SET(Union); LUAU_EQSAT_NODE_SET(Intersection); LUAU_EQSAT_NODE_ARRAY(Negation, 1); LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*); LUAU_EQSAT_UNIT(TNoRefine); LUAU_EQSAT_UNIT(Invalid); // enodes are immutable, but types are cyclic. We need a way to tie the knot. // We handle this by generating TBound nodes at points where we encounter cycles. // Each TBound has an ordinal that we later map onto the type. // We use a substitution rule to replace all TBound nodes with their referrent. LUAU_EQSAT_ATOM(TBound, size_t); // Tables are sufficiently unlike other enodes that the Language.h macros won't cut it. struct TTable { explicit TTable(Id basis); TTable(Id basis, std::vector propNames_, std::vector propTypes_); // All TTables extend some other table. This may be TTopTable. // // It will frequently be a TImportedTable, in which case we can reuse things // like source location and documentation info. Id getBasis() const; EqSat::Slice propTypes() const; // TODO: Also support read-only table props // TODO: Indexer type, index result type. std::vector propNames; // The enode interface EqSat::Slice mutableOperands(); EqSat::Slice operands() const; bool operator==(const TTable& rhs) const; bool operator!=(const TTable& rhs) const { return !(*this == rhs); } struct Hash { size_t operator()(const TTable& value) const; }; private: // The first element of this vector is the basis. Subsequent elements are // property types. As we add other things like read-only properties and // indexers, the structure of this array is likely to change. // // We encode our data in this way so that the operands() method can properly // return a Slice. std::vector storage; }; using EType = EqSat::Language< TNil, TBoolean, TNumber, TString, TThread, TTopFunction, TTopTable, TTopClass, TBuffer, TOpaque, SBoolean, SString, TFunction, TTable, TImportedTable, TClass, TAny, TError, TUnknown, TNever, Union, Intersection, Negation, TTypeFun, Invalid, TNoRefine, TBound>; struct StringCache { Allocator allocator; DenseHashMap strings{{}}; std::vector views; StringId add(std::string_view s); std::string_view asStringView(StringId id) const; std::string asString(StringId id) const; }; using EGraph = Luau::EqSat::EGraph; struct Simplify { using Data = bool; template Data make(const EGraph&, const T&) const; void join(Data& left, const Data& right) const; }; struct Subst { Id eclass; Id newClass; std::string desc; Subst(Id eclass, Id newClass, std::string desc = ""); }; struct Simplifier { NotNull arena; NotNull builtinTypes; EGraph egraph; StringCache stringCache; // enodes are immutable but types can be cyclic, so we need some way to // encode the cycle. This map is used to connect TBound nodes to the right // eclass. // // The cyclicIntersection rewrite rule uses this to sense when a cycle can // be deleted from an intersection or union. std::unordered_map mappingIdToClass; std::vector substs; using RewriteRuleFn = void (Simplifier::*)(Id id); Simplifier(NotNull arena, NotNull builtinTypes); // Utilities const EqSat::EClass& get(Id id) const; Id find(Id id) const; Id add(EType enode); template const Tag* isTag(Id id) const; template const Tag* isTag(const EType& enode) const; void subst(Id from, Id to); void subst(Id from, Id to, const std::string& ruleName); void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); void unionClasses(std::vector& hereParts, Id there); // Rewrite rules void simplifyUnion(Id id); void uninhabitedIntersection(Id id); void intersectWithNegatedClass(Id id); void intersectWithNoRefine(Id id); void cyclicIntersectionOfUnion(Id id); void cyclicUnionOfIntersection(Id id); void expandNegation(Id id); void intersectionOfUnion(Id id); void intersectTableProperty(Id id); void uninhabitedTable(Id id); void unneededTableModification(Id id); void builtinTypeFunctions(Id id); void iffyTypeFunctions(Id id); }; template struct QueryIterator { QueryIterator(); QueryIterator(EGraph* egraph, Id eclass); bool operator==(const QueryIterator& other) const; bool operator!=(const QueryIterator& other) const; std::pair operator*() const; QueryIterator& operator++(); QueryIterator& operator++(int); private: EGraph* egraph = nullptr; Id eclass; size_t index = 0; }; template struct Query { EGraph* egraph; Id eclass; Query(EGraph* egraph, Id eclass) : egraph(egraph) , eclass(eclass) { } QueryIterator begin() { return QueryIterator{egraph, eclass}; } QueryIterator end() { return QueryIterator{}; } }; template QueryIterator::QueryIterator() : egraph(nullptr) , eclass(Id{0}) , index(0) { } template QueryIterator::QueryIterator(EGraph* egraph_, Id eclass) : egraph(egraph_) , eclass(eclass) , index(0) { const auto& ecl = (*egraph)[eclass]; static constexpr const int idx = EType::VariantTy::getTypeId(); for (const auto& enode : ecl.nodes) { if (enode.index() < idx) ++index; else break; } if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx) { egraph = nullptr; index = 0; } } template bool QueryIterator::operator==(const QueryIterator& rhs) const { if (egraph == nullptr && rhs.egraph == nullptr) return true; return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index; } template bool QueryIterator::operator!=(const QueryIterator& rhs) const { return !(*this == rhs); } template std::pair QueryIterator::operator*() const { LUAU_ASSERT(egraph != nullptr); EGraph::EClassT& ecl = (*egraph)[eclass]; LUAU_ASSERT(index < ecl.nodes.size()); auto& enode = ecl.nodes[index]; Tag* result = enode.template get(); LUAU_ASSERT(result); return {result, index}; } // pre-increment template QueryIterator& QueryIterator::operator++() { const auto& ecl = (*egraph)[eclass]; ++index; if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId()) { egraph = nullptr; index = 0; } return *this; } // post-increment template QueryIterator& QueryIterator::operator++(int) { QueryIterator res = *this; ++res; return res; } } // namespace Luau::EqSatSimplification