From 0f0c0e4d28db9bfecb641080c6a1ff577a335398 Mon Sep 17 00:00:00 2001 From: Vyacheslav Egorov Date: Fri, 12 Apr 2024 13:44:40 +0300 Subject: [PATCH] Sync to upstream/release/621 --- Analysis/include/Luau/Clone.h | 2 - Analysis/include/Luau/ConstraintGenerator.h | 8 +- Analysis/include/Luau/ConstraintSolver.h | 5 +- Analysis/include/Luau/Error.h | 22 +- Analysis/include/Luau/Instantiation2.h | 15 +- Analysis/include/Luau/Normalize.h | 10 +- Analysis/include/Luau/Substitution.h | 25 +- Analysis/include/Luau/Subtyping.h | 2 +- Analysis/include/Luau/TableLiteralInference.h | 14 +- Analysis/include/Luau/TypeFamily.h | 19 +- .../include/Luau/TypeFamilyReductionGuesser.h | 2 +- Analysis/include/Luau/TypeUtils.h | 2 + Analysis/src/Clone.cpp | 580 +----------------- Analysis/src/ConstraintGenerator.cpp | 148 ++--- Analysis/src/ConstraintSolver.cpp | 111 ++-- Analysis/src/DataFlowGraph.cpp | 4 +- Analysis/src/Error.cpp | 40 ++ Analysis/src/Frontend.cpp | 3 - Analysis/src/Instantiation2.cpp | 30 +- Analysis/src/IostreamHelpers.cpp | 28 + Analysis/src/Normalize.cpp | 237 ++++++- Analysis/src/Simplify.cpp | 10 +- Analysis/src/Substitution.cpp | 44 +- Analysis/src/Subtyping.cpp | 6 +- Analysis/src/TableLiteralInference.cpp | 48 +- Analysis/src/ToDot.cpp | 9 +- Analysis/src/Type.cpp | 4 + Analysis/src/TypeChecker2.cpp | 184 ++++-- Analysis/src/TypeFamily.cpp | 197 ++++-- Analysis/src/TypeFamilyReductionGuesser.cpp | 6 +- Analysis/src/TypeInfer.cpp | 36 +- Analysis/src/TypeUtils.cpp | 22 +- Analysis/src/Unifier.cpp | 222 ++++--- CLI/Repl.cpp | 152 ++--- CodeGen/include/Luau/CodeGen.h | 30 + CodeGen/include/Luau/IrVisitUseDef.h | 6 +- CodeGen/include/Luau/SharedCodeAllocator.h | 20 +- CodeGen/src/AssemblyBuilderA64.cpp | 46 +- CodeGen/src/CodeGen.cpp | 240 ++++---- CodeGen/src/CodeGenContext.cpp | 159 ++++- CodeGen/src/CodeGenContext.h | 43 +- CodeGen/src/CodeGenLower.h | 4 +- CodeGen/src/EmitBuiltinsX64.cpp | 12 +- CodeGen/src/IrLoweringA64.cpp | 14 +- CodeGen/src/IrTranslateBuiltins.cpp | 6 +- CodeGen/src/IrValueLocationTracking.cpp | 2 +- CodeGen/src/OptimizeConstProp.cpp | 12 +- CodeGen/src/OptimizeDeadStore.cpp | 239 +++++++- CodeGen/src/SharedCodeAllocator.cpp | 70 ++- Common/include/Luau/ExperimentalFlags.h | 1 - Sources.cmake | 3 +- VM/src/ldblib.cpp | 7 +- VM/src/lgcdebug.cpp | 13 +- tests/AssemblyBuilderA64.test.cpp | 4 - tests/Conformance.test.cpp | 3 - tests/DataFlowGraph.test.cpp | 20 + tests/Error.test.cpp | 4 +- tests/Instantiation2.test.cpp | 53 ++ tests/IrBuilder.test.cpp | 505 ++++++++++++++- tests/IrLowering.test.cpp | 20 +- tests/Module.test.cpp | 44 +- tests/Normalize.test.cpp | 113 ++-- tests/RequireByString.test.cpp | 22 - tests/SharedCodeAllocator.test.cpp | 116 ++++ tests/Simplify.test.cpp | 11 + tests/TypeFamily.test.cpp | 58 +- tests/TypeInfer.aliases.test.cpp | 53 ++ tests/TypeInfer.cfa.test.cpp | 55 +- tests/TypeInfer.functions.test.cpp | 59 ++ tests/TypeInfer.oop.test.cpp | 4 - tests/TypeInfer.provisional.test.cpp | 32 +- tests/TypeInfer.singletons.test.cpp | 12 + tests/TypeInfer.tables.test.cpp | 37 ++ tests/TypeInfer.test.cpp | 39 +- tests/TypeInfer.tryUnify.test.cpp | 25 - tests/TypeInfer.unionTypes.test.cpp | 5 - tools/faillist.txt | 11 +- 77 files changed, 2800 insertions(+), 1679 deletions(-) create mode 100644 tests/Instantiation2.test.cpp diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 28ab9931..103b5bbd 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -20,8 +20,6 @@ struct CloneState SeenTypes seenTypes; SeenTypePacks seenTypePacks; - - int recursionCount = 0; }; TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 48a1b77a..4f217142 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -95,10 +95,6 @@ struct ConstraintGenerator // will enqueue them during solving. std::vector unqueuedConstraints; - // Type family instances created by the generator. This is used to ensure - // that these instances are reduced fully by the solver. - std::vector familyInstances; - // The private scope of type aliases for which the type parameters belong to. DenseHashMap astTypeAliasDefiningScopes{nullptr}; @@ -264,8 +260,8 @@ private: std::optional assignedTy; }; - LValueBounds checkLValue(const ScopePtr& scope, AstExpr* expr, bool transform); - LValueBounds checkLValue(const ScopePtr& scope, AstExprLocal* local, bool transform); + LValueBounds checkLValue(const ScopePtr& scope, AstExpr* expr); + LValueBounds checkLValue(const ScopePtr& scope, AstExprLocal* local); LValueBounds checkLValue(const ScopePtr& scope, AstExprGlobal* global); LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexName* indexName); LValueBounds checkLValue(const ScopePtr& scope, AstExprIndexExpr* indexExpr); diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 3d00ec06..9ad885a7 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -75,9 +75,6 @@ struct ConstraintSolver // A constraint can be both blocked and unsolved, for instance. std::vector> unsolvedConstraints; - // This is a set of type families that need to be reduced after all constraints have been dispatched. - DenseHashSet familyInstances{nullptr}; - // A mapping of constraint pointer to how many things the constraint is // blocked on. Can be empty or 0 for constraints that are not blocked on // anything. @@ -137,7 +134,7 @@ struct ConstraintSolver bool tryDispatch(const HasPropConstraint& c, NotNull constraint); bool tryDispatch(const SetPropConstraint& c, NotNull constraint); - bool tryDispatchHasIndexer(int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType); + bool tryDispatchHasIndexer(int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set& seen); bool tryDispatch(const HasIndexerConstraint& c, NotNull constraint); /// (dispatched, found) where diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 4fbb4089..2ba284a1 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -409,6 +409,26 @@ struct CheckedFunctionIncorrectArgs bool operator==(const CheckedFunctionIncorrectArgs& rhs) const; }; +struct CannotAssignToNever +{ + // type of the rvalue being assigned + TypeId rhsType; + + // Originating type. + std::vector cause; + + enum class Reason + { + // when assigning to a property in a union of tables, the properties type + // is narrowed to the intersection of its type in each variant. + PropertyNarrowed, + }; + + Reason reason; + + bool operator==(const CannotAssignToNever& rhs) const; +}; + struct UnexpectedTypeInSubtyping { TypeId ty; @@ -427,7 +447,7 @@ using TypeErrorData = Variant instantiate2( + TypeArena* arena, DenseHashMap genericSubstitutions, DenseHashMap genericPackSubstitutions, TypeId ty); +std::optional instantiate2(TypeArena* arena, DenseHashMap genericSubstitutions, + DenseHashMap genericPackSubstitutions, TypePackId tp); + } // namespace Luau diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index 9b992a5e..35e0c7a1 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -283,6 +283,11 @@ struct NormalizedType // The generic/free part of the type. NormalizedTyvars tyvars; + // Free types, blocked types, and certain other types change shape as type + // inference is done. If we were to cache the normalization of these types, + // we'd be reusing bad, stale data. + bool isCacheable = true; + NormalizedType(NotNull builtinTypes); NormalizedType() = delete; @@ -330,7 +335,7 @@ struct NormalizedType class Normalizer { - std::unordered_map> cachedNormals; + std::unordered_map> cachedNormals; std::unordered_map cachedIntersections; std::unordered_map cachedUnions; std::unordered_map> cachedTypeIds; @@ -355,7 +360,8 @@ public: Normalizer& operator=(Normalizer&) = delete; // If this returns null, the typechecker should emit a "too complex" error - const NormalizedType* normalize(TypeId ty); + const NormalizedType* DEPRECATED_normalize(TypeId ty); + std::shared_ptr normalize(TypeId ty); void clearNormal(NormalizedType& norm); // ------- Cached TypeIds diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 8e3bdcd5..16e36e09 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -183,13 +183,21 @@ struct Tarjan struct Substitution : Tarjan { protected: - Substitution(const TxnLog* log_, TypeArena* arena) - : arena(arena) - { - log = log_; - LUAU_ASSERT(log); - LUAU_ASSERT(arena); - } + Substitution(const TxnLog* log_, TypeArena* arena); + + /* + * By default, Substitution assumes that the types produced by clean() are + * freshly allocated types that are safe to mutate. + * + * If your clean() implementation produces a type that is not safe to + * mutate, you must call dontTraverseInto on this type (or type pack) to + * prevent Substitution from attempting to perform substitutions within the + * cleaned type. + * + * See the test weird_cyclic_instantiation for an example. + */ + void dontTraverseInto(TypeId ty); + void dontTraverseInto(TypePackId tp); public: TypeArena* arena; @@ -198,6 +206,9 @@ public: DenseHashSet replacedTypes{nullptr}; DenseHashSet replacedTypePacks{nullptr}; + DenseHashSet noTraverseTypes{nullptr}; + DenseHashSet noTraverseTypePacks{nullptr}; + std::optional substitute(TypeId ty); std::optional substitute(TypePackId tp); diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 0a2c3c7f..649b76b5 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -208,7 +208,7 @@ private: SubtypingResult isCovariantWith(SubtypingEnvironment& env, const TableIndexer& subIndexer, const TableIndexer& superIndexer); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const Property& subProperty, const Property& superProperty, const std::string& name); - SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedType* subNorm, const NormalizedType* superNorm); + SubtypingResult isCovariantWith(SubtypingEnvironment& env, const std::shared_ptr& subNorm, const std::shared_ptr& superNorm); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const NormalizedClassType& superClass); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedClassType& subClass, const TypeIds& superTables); SubtypingResult isCovariantWith(SubtypingEnvironment& env, const NormalizedStringType& subString, const NormalizedStringType& superString); diff --git a/Analysis/include/Luau/TableLiteralInference.h b/Analysis/include/Luau/TableLiteralInference.h index 1a6d51ea..6f541a00 100644 --- a/Analysis/include/Luau/TableLiteralInference.h +++ b/Analysis/include/Luau/TableLiteralInference.h @@ -14,15 +14,7 @@ struct BuiltinTypes; struct Unifier2; class AstExpr; -TypeId matchLiteralType( - NotNull> astTypes, - NotNull> astExpectedTypes, - NotNull builtinTypes, - NotNull arena, - NotNull unifier, - TypeId expectedType, - TypeId exprType, - const AstExpr* expr -); - +TypeId matchLiteralType(NotNull> astTypes, NotNull> astExpectedTypes, + NotNull builtinTypes, NotNull arena, NotNull unifier, TypeId expectedType, TypeId exprType, + const AstExpr* expr, std::vector& toBlock); } diff --git a/Analysis/include/Luau/TypeFamily.h b/Analysis/include/Luau/TypeFamily.h index 99f4f446..eef26e78 100644 --- a/Analysis/include/Luau/TypeFamily.h +++ b/Analysis/include/Luau/TypeFamily.h @@ -19,6 +19,22 @@ struct TypeArena; struct TxnLog; class Normalizer; +struct TypeFamilyQueue +{ + NotNull> queuedTys; + NotNull> queuedTps; + + void add(TypeId instanceTy); + void add(TypePackId instanceTp); + + template + void add(const std::vector& ts) + { + for (const T& t : ts) + enqueue(t); + } +}; + struct TypeFamilyContext { NotNull arena; @@ -60,6 +76,7 @@ struct TypeFamilyContext NotNull pushConstraint(ConstraintV&& c); }; + /// Represents a reduction result, which may have successfully reduced the type, /// may have concretely failed to reduce the type, or may simply be stuck /// without more information. @@ -83,7 +100,7 @@ struct TypeFamilyReductionResult template using ReducerFunction = - std::function(T, const std::vector&, const std::vector&, NotNull)>; + std::function(T, NotNull, const std::vector&, const std::vector&, NotNull)>; /// Represents a type function that may be applied to map a series of types and /// type packs to a single output type. diff --git a/Analysis/include/Luau/TypeFamilyReductionGuesser.h b/Analysis/include/Luau/TypeFamilyReductionGuesser.h index 9903e381..29114f72 100644 --- a/Analysis/include/Luau/TypeFamilyReductionGuesser.h +++ b/Analysis/include/Luau/TypeFamilyReductionGuesser.h @@ -68,7 +68,7 @@ private: bool operandIsAssignable(TypeId ty); std::optional tryAssignOperandType(TypeId ty); - const NormalizedType* normalize(TypeId ty); + std::shared_ptr normalize(TypeId ty); void step(); void infer(); bool done(); diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index b4726b5c..b02319ce 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -62,6 +62,8 @@ std::optional findTablePropertyRespectingMeta( std::optional findTablePropertyRespectingMeta( NotNull builtinTypes, ErrorVec& errors, TypeId ty, const std::string& name, ValueContext context, Location location); +bool occursCheck(TypeId needle, TypeId haystack); + // Returns the minimum and maximum number of types the argument list can accept. std::pair> getParameterExtents(const TxnLog* log, TypePackId tp, bool includeHiddenVariadics = false); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index bf02f743..a96e5866 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -2,16 +2,13 @@ #include "Luau/Clone.h" #include "Luau/NotNull.h" -#include "Luau/RecursionCounter.h" -#include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAGVARIABLE(LuauStacklessTypeClone3, false) +// For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) namespace Luau @@ -28,7 +25,7 @@ const T* get(const Kind& kind) return get_if(&kind); } -class TypeCloner2 +class TypeCloner { NotNull arena; NotNull builtinTypes; @@ -44,7 +41,7 @@ class TypeCloner2 int steps = 0; public: - TypeCloner2(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + TypeCloner(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) : arena(arena) , builtinTypes(builtinTypes) , types(types) @@ -204,15 +201,14 @@ private: if (auto ty = p.writeTy) cloneWriteTy = shallowClone(*ty); - std::optional cloned = Property::create(cloneReadTy, cloneWriteTy); - LUAU_ASSERT(cloned); - cloned->deprecated = p.deprecated; - cloned->deprecatedSuggestion = p.deprecatedSuggestion; - cloned->location = p.location; - cloned->tags = p.tags; - cloned->documentationSymbol = p.documentationSymbol; - cloned->typeLocation = p.typeLocation; - return *cloned; + Property cloned = Property::create(cloneReadTy, cloneWriteTy); + cloned.deprecated = p.deprecated; + cloned.deprecatedSuggestion = p.deprecatedSuggestion; + cloned.location = p.location; + cloned.tags = p.tags; + cloned.documentationSymbol = p.documentationSymbol; + cloned.typeLocation = p.typeLocation; + return cloned; } else { @@ -453,469 +449,13 @@ private: } // namespace -namespace -{ - -Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) -{ - if (FFlag::DebugLuauDeferredConstraintResolution) - { - std::optional cloneReadTy; - if (auto ty = prop.readTy) - cloneReadTy = clone(*ty, dest, cloneState); - - std::optional cloneWriteTy; - if (auto ty = prop.writeTy) - cloneWriteTy = clone(*ty, dest, cloneState); - - std::optional cloned = Property::create(cloneReadTy, cloneWriteTy); - LUAU_ASSERT(cloned); - cloned->deprecated = prop.deprecated; - cloned->deprecatedSuggestion = prop.deprecatedSuggestion; - cloned->location = prop.location; - cloned->tags = prop.tags; - cloned->documentationSymbol = prop.documentationSymbol; - cloned->typeLocation = prop.typeLocation; - return *cloned; - } - else - { - return Property{ - clone(prop.type(), dest, cloneState), - prop.deprecated, - prop.deprecatedSuggestion, - prop.location, - prop.tags, - prop.documentationSymbol, - prop.typeLocation, - }; - } -} - -static TableIndexer clone(const TableIndexer& indexer, TypeArena& dest, CloneState& cloneState) -{ - return TableIndexer{clone(indexer.indexType, dest, cloneState), clone(indexer.indexResultType, dest, cloneState)}; -} - -struct TypePackCloner; - -/* - * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. - * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. - */ - -struct TypeCloner -{ - TypeCloner(TypeArena& dest, TypeId typeId, CloneState& cloneState) - : dest(dest) - , typeId(typeId) - , seenTypes(cloneState.seenTypes) - , seenTypePacks(cloneState.seenTypePacks) - , cloneState(cloneState) - { - } - - TypeArena& dest; - TypeId typeId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - template - void defaultClone(const T& t); - - void operator()(const FreeType& t); - void operator()(const LocalType& t); - void operator()(const GenericType& t); - void operator()(const BoundType& t); - void operator()(const ErrorType& t); - void operator()(const BlockedType& t); - void operator()(const PendingExpansionType& t); - void operator()(const PrimitiveType& t); - void operator()(const SingletonType& t); - void operator()(const FunctionType& t); - void operator()(const TableType& t); - void operator()(const MetatableType& t); - void operator()(const ClassType& t); - void operator()(const AnyType& t); - void operator()(const UnionType& t); - void operator()(const IntersectionType& t); - void operator()(const LazyType& t); - void operator()(const UnknownType& t); - void operator()(const NeverType& t); - void operator()(const NegationType& t); - void operator()(const TypeFamilyInstanceType& t); -}; - -struct TypePackCloner -{ - TypeArena& dest; - TypePackId typePackId; - SeenTypes& seenTypes; - SeenTypePacks& seenTypePacks; - CloneState& cloneState; - - TypePackCloner(TypeArena& dest, TypePackId typePackId, CloneState& cloneState) - : dest(dest) - , typePackId(typePackId) - , seenTypes(cloneState.seenTypes) - , seenTypePacks(cloneState.seenTypePacks) - , cloneState(cloneState) - { - } - - template - void defaultClone(const T& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{t}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const FreeTypePack& t) - { - defaultClone(t); - } - - void operator()(const GenericTypePack& t) - { - defaultClone(t); - } - - void operator()(const ErrorTypePack& t) - { - defaultClone(t); - } - - void operator()(const BlockedTypePack& t) - { - defaultClone(t); - } - - // While we are a-cloning, we can flatten out bound Types and make things a bit tighter. - // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. - void operator()(const Unifiable::Bound& t) - { - TypePackId cloned = clone(t.boundTo, dest, cloneState); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const VariadicTypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, cloneState), /*hidden*/ t.hidden}}); - seenTypePacks[typePackId] = cloned; - } - - void operator()(const TypePack& t) - { - TypePackId cloned = dest.addTypePack(TypePack{}); - TypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp != nullptr); - seenTypePacks[typePackId] = cloned; - - for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, cloneState)); - - if (t.tail) - destTp->tail = clone(*t.tail, dest, cloneState); - } - - void operator()(const TypeFamilyInstanceTypePack& t) - { - TypePackId cloned = dest.addTypePack(TypeFamilyInstanceTypePack{t.family, {}, {}}); - TypeFamilyInstanceTypePack* destTp = getMutable(cloned); - LUAU_ASSERT(destTp); - seenTypePacks[typePackId] = cloned; - - destTp->typeArguments.reserve(t.typeArguments.size()); - for (TypeId ty : t.typeArguments) - destTp->typeArguments.push_back(clone(ty, dest, cloneState)); - - destTp->packArguments.reserve(t.packArguments.size()); - for (TypePackId tp : t.packArguments) - destTp->packArguments.push_back(clone(tp, dest, cloneState)); - } -}; - -template -void TypeCloner::defaultClone(const T& t) -{ - TypeId cloned = dest.addType(t); - seenTypes[typeId] = cloned; -} - -void TypeCloner::operator()(const FreeType& t) -{ - if (FFlag::DebugLuauDeferredConstraintResolution) - { - FreeType ft{nullptr, clone(t.lowerBound, dest, cloneState), clone(t.upperBound, dest, cloneState)}; - TypeId res = dest.addType(ft); - seenTypes[typeId] = res; - } - else - defaultClone(t); -} - -void TypeCloner::operator()(const LocalType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const GenericType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const Unifiable::Bound& t) -{ - TypeId boundTo = clone(t.boundTo, dest, cloneState); - seenTypes[typeId] = boundTo; -} - -void TypeCloner::operator()(const Unifiable::Error& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const BlockedType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const PendingExpansionType& t) -{ - TypeId res = dest.addType(PendingExpansionType{t.prefix, t.name, t.typeArguments, t.packArguments}); - PendingExpansionType* petv = getMutable(res); - LUAU_ASSERT(petv); - - seenTypes[typeId] = res; - - std::vector typeArguments; - for (TypeId arg : t.typeArguments) - typeArguments.push_back(clone(arg, dest, cloneState)); - - std::vector packArguments; - for (TypePackId arg : t.packArguments) - packArguments.push_back(clone(arg, dest, cloneState)); - - petv->typeArguments = std::move(typeArguments); - petv->packArguments = std::move(packArguments); -} - -void TypeCloner::operator()(const PrimitiveType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const SingletonType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const FunctionType& t) -{ - // FISHY: We always erase the scope when we clone things. clone() was - // originally written so that we could copy a module's type surface into an - // export arena. This probably dates to that. - TypeId result = dest.addType(FunctionType{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); - FunctionType* ftv = getMutable(result); - LUAU_ASSERT(ftv != nullptr); - - seenTypes[typeId] = result; - - for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, cloneState)); - - for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, cloneState)); - - ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, cloneState); - ftv->argNames = t.argNames; - ftv->retTypes = clone(t.retTypes, dest, cloneState); - ftv->hasNoFreeOrGenericTypes = t.hasNoFreeOrGenericTypes; - ftv->isCheckedFunction = t.isCheckedFunction; -} - -void TypeCloner::operator()(const TableType& t) -{ - // If table is now bound to another one, we ignore the content of the original - if (t.boundTo) - { - TypeId boundTo = clone(*t.boundTo, dest, cloneState); - seenTypes[typeId] = boundTo; - return; - } - - TypeId result = dest.addType(TableType{}); - TableType* ttv = getMutable(result); - LUAU_ASSERT(ttv != nullptr); - - *ttv = t; - - seenTypes[typeId] = result; - - ttv->level = TypeLevel{0, 0}; - - for (const auto& [name, prop] : t.props) - ttv->props[name] = clone(prop, dest, cloneState); - - if (t.indexer) - ttv->indexer = clone(*t.indexer, dest, cloneState); - - for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, cloneState); - - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, cloneState); - - ttv->definitionModuleName = t.definitionModuleName; - ttv->definitionLocation = t.definitionLocation; - ttv->tags = t.tags; -} - -void TypeCloner::operator()(const MetatableType& t) -{ - TypeId result = dest.addType(MetatableType{}); - MetatableType* mtv = getMutable(result); - seenTypes[typeId] = result; - - mtv->table = clone(t.table, dest, cloneState); - mtv->metatable = clone(t.metatable, dest, cloneState); -} - -void TypeCloner::operator()(const ClassType& t) -{ - TypeId result = dest.addType(ClassType{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData, t.definitionModuleName}); - ClassType* ctv = getMutable(result); - - seenTypes[typeId] = result; - - for (const auto& [name, prop] : t.props) - ctv->props[name] = clone(prop, dest, cloneState); - - if (t.parent) - ctv->parent = clone(*t.parent, dest, cloneState); - - if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, cloneState); - - if (t.indexer) - ctv->indexer = clone(*t.indexer, dest, cloneState); -} - -void TypeCloner::operator()(const AnyType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const UnionType& t) -{ - // We're just using this FreeType as a placeholder until we've finished - // cloning the parts of this union so it is okay that its bounds are - // nullptr. We'll never indirect them. - TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/ nullptr, /*upperBound*/ nullptr}); - seenTypes[typeId] = result; - - std::vector options; - options.reserve(t.options.size()); - - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, cloneState)); - - asMutable(result)->ty.emplace(std::move(options)); -} - -void TypeCloner::operator()(const IntersectionType& t) -{ - TypeId result = dest.addType(IntersectionType{}); - seenTypes[typeId] = result; - - IntersectionType* option = getMutable(result); - LUAU_ASSERT(option != nullptr); - - for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, cloneState)); -} - -void TypeCloner::operator()(const LazyType& t) -{ - if (TypeId unwrapped = t.unwrapped.load()) - { - seenTypes[typeId] = clone(unwrapped, dest, cloneState); - } - else - { - defaultClone(t); - } -} - -void TypeCloner::operator()(const UnknownType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const NeverType& t) -{ - defaultClone(t); -} - -void TypeCloner::operator()(const NegationType& t) -{ - TypeId result = dest.addType(AnyType{}); - seenTypes[typeId] = result; - - TypeId ty = clone(t.ty, dest, cloneState); - asMutable(result)->ty = NegationType{ty}; -} - -void TypeCloner::operator()(const TypeFamilyInstanceType& t) -{ - TypeId result = dest.addType(TypeFamilyInstanceType{ - t.family, - {}, - {}, - }); - - seenTypes[typeId] = result; - - TypeFamilyInstanceType* tfit = getMutable(result); - LUAU_ASSERT(tfit != nullptr); - - tfit->typeArguments.reserve(t.typeArguments.size()); - for (TypeId p : t.typeArguments) - tfit->typeArguments.push_back(clone(p, dest, cloneState)); - - tfit->packArguments.reserve(t.packArguments.size()); - for (TypePackId p : t.packArguments) - tfit->packArguments.push_back(clone(p, dest, cloneState)); -} - -} // anonymous namespace - TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - if (FFlag::LuauStacklessTypeClone3) - { - TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; - return cloner.clone(tp); - } - else - { - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = cloneState.seenTypePacks[tp]; - - if (res == nullptr) - { - TypePackCloner cloner{dest, tp, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. - } - - return res; - } + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(tp); } TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) @@ -923,91 +463,35 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - if (FFlag::LuauStacklessTypeClone3) - { - TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; - return cloner.clone(typeId); - } - else - { - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = cloneState.seenTypes[typeId]; - - if (res == nullptr) - { - TypeCloner cloner{dest, typeId, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - { - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - } - - return res; - } + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(typeId); } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - if (FFlag::LuauStacklessTypeClone3) + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + + TypeFun copy = typeFun; + + for (auto& param : copy.typeParams) { - TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + param.ty = cloner.clone(param.ty); - TypeFun copy = typeFun; - - for (auto& param : copy.typeParams) - { - param.ty = cloner.clone(param.ty); - - if (param.defaultValue) - param.defaultValue = cloner.clone(*param.defaultValue); - } - - for (auto& param : copy.typePackParams) - { - param.tp = cloner.clone(param.tp); - - if (param.defaultValue) - param.defaultValue = cloner.clone(*param.defaultValue); - } - - copy.type = cloner.clone(copy.type); - - return copy; + if (param.defaultValue) + param.defaultValue = cloner.clone(*param.defaultValue); } - else + + for (auto& param : copy.typePackParams) { - TypeFun result; + param.tp = cloner.clone(param.tp); - for (auto param : typeFun.typeParams) - { - TypeId ty = clone(param.ty, dest, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); - - result.typeParams.push_back({ty, defaultValue}); - } - - for (auto param : typeFun.typePackParams) - { - TypePackId tp = clone(param.tp, dest, cloneState); - std::optional defaultValue; - - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); - - result.typePackParams.push_back({tp, defaultValue}); - } - - result.type = clone(typeFun.type, dest, cloneState); - - return result; + if (param.defaultValue) + param.defaultValue = cloner.clone(*param.defaultValue); } + + copy.type = cloner.clone(copy.type); + + return copy; } } // namespace Luau diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 4a4c8c3c..91be2a3a 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -8,15 +8,16 @@ #include "Luau/ControlFlow.h" #include "Luau/DcrLogger.h" #include "Luau/DenseHash.h" -#include "Luau/InsertionOrderedMap.h" #include "Luau/ModuleResolver.h" #include "Luau/RecursionCounter.h" #include "Luau/Refinement.h" #include "Luau/Scope.h" #include "Luau/Simplify.h" +#include "Luau/StringUtils.h" #include "Luau/TableLiteralInference.h" #include "Luau/Type.h" #include "Luau/TypeFamily.h" +#include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" #include "Luau/VisitType.h" @@ -27,7 +28,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); -LUAU_FASTFLAG(LuauLoopControlFlowAnalysis); namespace Luau { @@ -641,9 +641,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStat* stat) else if (auto s = stat->as()) return visit(scope, s); else if (stat->is()) - return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Breaks : ControlFlow::None; + return ControlFlow::Breaks; else if (stat->is()) - return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Continues : ControlFlow::None; + return ControlFlow::Continues; else if (auto r = stat->as()) return visit(scope, r); else if (auto e = stat->as()) @@ -989,9 +989,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f else scope->bindings[localName->local] = Binding{generalizedType, localName->location}; - sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; - sig.bodyScope->lvalueTypes[def] = sig.signature; - sig.bodyScope->rvalueRefinements[def] = sig.signature; + scope->bindings[localName->local] = Binding{sig.signature, localName->location}; + scope->lvalueTypes[def] = sig.signature; + scope->rvalueRefinements[def] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { @@ -1001,9 +1001,9 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f if (!sigFullyDefined) generalizedType = *existingFunctionTy; - sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; - sig.bodyScope->lvalueTypes[def] = sig.signature; - sig.bodyScope->rvalueRefinements[def] = sig.signature; + scope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + scope->lvalueTypes[def] = sig.signature; + scope->rvalueRefinements[def] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { @@ -1121,53 +1121,11 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass Checkpoint lvalueBeginCheckpoint = checkpoint(this); - size_t i = 0; for (AstExpr* lvalue : assign->vars) { - // This is a really weird thing to do, but it's critically important for some kinds of - // assignments with the current type state behavior. Consider this code: - // local function f(l, r) - // local i = l - // for _ = l, r do - // i = i + 1 - // end - // end - // - // With type states now, we will not create a new state for `i` within the loop. This means - // that, in the absence of the analysis below, we would infer a too-broad bound for i: the - // cyclic type t1 where t1 = add. In order to stop this, we say that - // assignments to a definition with a self-referential binary expression do not transform - // the type of the definition. This will only apply for loops, where the definition is - // shared in more places; for non-loops, there will be a separate DefId for the lvalue in - // the assignment, so we will deem the expression to be transformative. - // - // Deeming the addition in the code sample above as non-transformative means that i is known - // to be exactly number further on, ensuring the type family reduces down to number, as is - // expected for this code snippet. - // - // There is a potential for spurious errors here if the expression is more complex than a - // simple binary expression, e.g. i = (i + 1) * 2. At the time of writing, this case hasn't - // materialized. - bool transform = true; - - if (assign->values.size > i) - { - AstExpr* value = assign->values.data[i]; - if (auto bexp = value->as()) - { - DefId lvalueDef = dfg->getDef(lvalue); - DefId lDef = dfg->getDef(bexp->left); - DefId rDef = dfg->getDef(bexp->right); - - if (lvalueDef == lDef || lvalueDef == rDef) - transform = false; - } - } - - auto [upperBound, typeState] = checkLValue(scope, lvalue, transform); + auto [upperBound, typeState] = checkLValue(scope, lvalue); upperBounds.push_back(upperBound.value_or(builtinTypes->unknownType)); typeStates.push_back(typeState.value_or(builtinTypes->unknownType)); - ++i; } Checkpoint lvalueEndCheckpoint = checkpoint(this); @@ -1196,7 +1154,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAss AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; TypeId resultTy = check(scope, &binop).ty; - auto [upperBound, typeState] = checkLValue(scope, assign->var, true); + auto [upperBound, typeState] = checkLValue(scope, assign->var); Constraint* sc = nullptr; if (upperBound) @@ -1246,7 +1204,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifState if (elsecf == ControlFlow::None) scope->inheritAssignments(elseScope); - if (FFlag::LuauLoopControlFlowAnalysis && thencf == elsecf) + if (thencf == elsecf) return thencf; else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) return ControlFlow::Returns; @@ -1254,25 +1212,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatIf* ifState return ControlFlow::None; } -static bool occursCheck(TypeId needle, TypeId haystack) -{ - LUAU_ASSERT(get(needle)); - haystack = follow(haystack); - - auto checkHaystack = [needle](TypeId haystack) { - return occursCheck(needle, haystack); - }; - - if (needle == haystack) - return true; - else if (auto ut = get(haystack)) - return std::any_of(begin(ut), end(ut), checkHaystack); - else if (auto it = get(haystack)) - return std::any_of(begin(it), end(it), checkHaystack); - - return false; -} - ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* alias) { if (alias->name == kParseNameError) @@ -1298,11 +1237,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* if (bindingIt == typeBindings->end() || defnScope == nullptr) return ControlFlow::None; - TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false); + TypeId ty = resolveType(*defnScope, alias->type, /* inTypeArguments */ false, /* replaceErrorWithFresh */ false); TypeId aliasTy = bindingIt->second.type; LUAU_ASSERT(get(aliasTy)); - if (occursCheck(aliasTy, ty)) { asMutable(aliasTy)->ty.emplace(builtinTypes->anyType); @@ -2377,10 +2315,10 @@ std::tuple ConstraintGenerator::checkBinary( } } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExpr* expr, bool transform) +ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExpr* expr) { if (auto local = expr->as()) - return checkLValue(scope, local, transform); + return checkLValue(scope, local); else if (auto global = expr->as()) return checkLValue(scope, global); else if (auto indexName = expr->as()) @@ -2396,7 +2334,7 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt ice->ice("checkLValue is inexhaustive"); } -ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprLocal* local, bool transform) +ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePtr& scope, AstExprLocal* local) { std::optional annotatedTy = scope->lookup(local->local); LUAU_ASSERT(annotatedTy); @@ -2406,16 +2344,13 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt if (ty) { - if (transform) + if (auto lt = getMutable(*ty)) + ++lt->blockCount; + else if (auto ut = getMutable(*ty)) { - if (auto lt = getMutable(*ty)) - ++lt->blockCount; - else if (auto ut = getMutable(*ty)) - { - for (TypeId optTy : ut->options) - if (auto lt = getMutable(optTy)) - ++lt->blockCount; - } + for (TypeId optTy : ut->options) + if (auto lt = getMutable(optTy)) + ++lt->blockCount; } } else @@ -2441,27 +2376,22 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::checkLValue(const ScopePt } // TODO: Need to clip this, but this requires more code to be reworked first before we can clip this. - std::optional assignedTy; + std::optional assignedTy = arena->addType(BlockedType{}); - if (transform) + auto unpackC = addConstraint(scope, local->location, + UnpackConstraint{arena->addTypePack({*ty}), arena->addTypePack({*assignedTy}), + /*resultIsLValue*/ true}); + + if (auto blocked = get(*ty)) { - assignedTy = arena->addType(BlockedType{}); - - auto unpackC = addConstraint(scope, local->location, - UnpackConstraint{arena->addTypePack({*ty}), arena->addTypePack({*assignedTy}), - /*resultIsLValue*/ true}); - - if (auto blocked = get(*ty)) - { - if (blocked->getOwner()) - unpackC->dependencies.push_back(NotNull{blocked->getOwner()}); - else if (auto blocked = getMutable(*ty)) - blocked->setOwner(unpackC); - } - - recordInferredBinding(local->local, *ty); + if (blocked->getOwner()) + unpackC->dependencies.push_back(NotNull{blocked->getOwner()}); + else if (auto blocked = getMutable(*ty)) + blocked->setOwner(unpackC); } + recordInferredBinding(local->local, *ty); + return {annotatedTy, assignedTy}; } @@ -2518,7 +2448,8 @@ ConstraintGenerator::LValueBounds ConstraintGenerator::updateProperty(const Scop TypeId subjectType = check(scope, indexExpr->expr).ty; TypeId indexType = check(scope, indexExpr->index).ty; TypeId assignedTy = arena->addType(BlockedType{}); - addConstraint(scope, expr->location, SetIndexerConstraint{subjectType, indexType, assignedTy}); + auto sic = addConstraint(scope, expr->location, SetIndexerConstraint{subjectType, indexType, assignedTy}); + getMutable(assignedTy)->setOwner(sic); module->astTypes[expr] = assignedTy; @@ -2696,7 +2627,9 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, if (expectedType) { Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; - matchLiteralType(NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr); + std::vector toBlock; + matchLiteralType( + NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock); } return Inference{ty}; @@ -3472,7 +3405,6 @@ TypeId ConstraintGenerator::createFamilyInstance(TypeFamilyInstanceType instance { TypeId result = arena->addType(std::move(instance)); addConstraint(scope, location, ReduceConstraint{result}); - familyInstances.push_back(result); return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 7d337c19..948722c6 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -276,7 +276,6 @@ struct InstantiationQueuer : TypeOnceVisitor bool visit(TypeId ty, const TypeFamilyInstanceType&) override { solver->pushConstraint(scope, location, ReduceConstraint{ty}); - solver->familyInstances.insert(ty); return true; } @@ -455,16 +454,6 @@ void ConstraintSolver::run() progress |= runSolverPass(true); } while (progress); - for (TypeId instance : familyInstances) - { - if (FFlag::DebugLuauLogSolver) - printf("Post-solve family reduction of %s\n", toString(instance).c_str()); - - TypeCheckLimits limits{}; - FamilyGraphReductionResult result = - reduceFamilies(instance, Location{}, TypeFamilyContext{arena, builtinTypes, rootScope, normalizer, NotNull{&iceReporter}, NotNull{&limits}}, false); - } - if (FFlag::DebugLuauLogSolver || FFlag::DebugLuauLogBindings) dumpBindings(rootScope, opts); @@ -843,6 +832,18 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } + // Due to how pending expansion types and TypeFun's are created + // If this check passes, we have created a cyclic / corecursive type alias + // of size 0 + TypeId lhs = c.target; + TypeId rhs = tf->type; + if (occursCheck(lhs, rhs)) + { + reportError(OccursCheckFailed{}, constraint->location); + bindResult(errorRecoveryType()); + return true; + } + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); bool sameTypes = std::equal(typeArguments.begin(), typeArguments.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& p) { @@ -1106,8 +1107,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull subst = instantiation.substitute(result); + std::optional subst = instantiate2(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions), result); if (!subst) { reportError(CodeTooComplex{}, constraint->location); @@ -1183,6 +1183,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNull res = replacer.substitute(fn); if (res) { + if (*res != fn) + { + FunctionType* ftvMut = getMutable(*res); + LUAU_ASSERT(ftvMut); + ftvMut->generics.clear(); + ftvMut->genericPacks.clear(); + } + fn = *res; ftv = get(*res); LUAU_ASSERT(ftv); @@ -1233,7 +1241,12 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullis()) { Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; - (void) matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr); + std::vector toBlock; + (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); + for (auto t : toBlock) + block(t, constraint); + if (!toBlock.empty()) + return false; } } @@ -1449,13 +1462,17 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType) +bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull constraint, TypeId subjectType, TypeId indexType, TypeId resultType, Set& seen) { RecursionLimiter _rl{&recursionDepth, FInt::LuauSolverRecursionLimit}; subjectType = follow(subjectType); indexType = follow(indexType); + if (seen.contains(subjectType)) + return false; + seen.insert(subjectType); + LUAU_ASSERT(get(resultType)); LUAU_ASSERT(canMutate(resultType, constraint)); @@ -1496,7 +1513,7 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNull(subjectType)) - return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType); + return tryDispatchHasIndexer(recursionDepth, constraint, mt->table, indexType, resultType, seen); else if (auto ct = get(subjectType)) { if (auto indexer = ct->indexer) @@ -1531,10 +1548,10 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNulladdType(BlockedType{}); getMutable(r)->setOwner(const_cast(constraint.get())); - bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r); - // FIXME: It's too late to stop and block now I think? We should - // scan for blocked types before we actually do anything. - LUAU_ASSERT(ok); + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); + // If we've cut a recursive loop short, skip it. + if (!ok) + continue; r = follow(r); if (!get(r)) @@ -1563,9 +1580,10 @@ bool ConstraintSolver::tryDispatchHasIndexer(int& recursionDepth, NotNulladdType(BlockedType{}); getMutable(r)->setOwner(const_cast(constraint.get())); - bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r); - // We should have found all the blocked types ahead of time (see BlockedTypeFinder below) - LUAU_ASSERT(ok); + bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); + // If we've cut a recursive loop short, skip it. + if (!ok) + continue; r = follow(r); if (!get(r)) @@ -1628,7 +1646,9 @@ bool ConstraintSolver::tryDispatch(const HasIndexerConstraint& c, NotNull seen{nullptr}; + + return tryDispatchHasIndexer(recursionDepth, constraint, subjectType, indexType, c.resultType, seen); } std::pair ConstraintSolver::tryDispatchSetIndexer(NotNull constraint, TypeId subjectType, TypeId indexType, TypeId propType, bool expandFreeTypeBounds) @@ -1641,24 +1661,16 @@ std::pair ConstraintSolver::tryDispatchSetIndexer(NotNullindexer) { unify(constraint, indexType, tt->indexer->indexType); - - // We have a `BoundType` check here because we must mutate only our owning `BlockedType`, not some other constraint's `BlockedType`. - // TODO: We should rather have a `bool mutateProp` parameter that is set to false if we're traversing a union or intersection type. - // The union or intersection type themselves should be the one to mutate the `propType`, not each or first `TableType` in a union/intersection type. - // - // Fixing this requires fixing other ones first. - if (!get(propType) && get(propType)) - emplaceType(asMutable(propType), tt->indexer->indexResultType); + bindBlockedType(propType, tt->indexer->indexResultType, subjectType, constraint); return {true, true}; } else if (tt->state == TableState::Free || tt->state == TableState::Unsealed) { + bindBlockedType(propType, freshType(arena, builtinTypes, constraint->scope.get()), subjectType, constraint); tt->indexer = TableIndexer{indexType, propType}; return {true, true}; } - else - return {true, false}; } else if (auto ft = getMutable(subjectType); ft && expandFreeTypeBounds) { @@ -1669,6 +1681,10 @@ std::pair ConstraintSolver::tryDispatchSetIndexer(NotNullupperBound, indexType, propType, /*expandFreeTypeBounds=*/ false); if (dispatched && !found) { + // Despite that we haven't found a table type, adding a table type causes us to have one that we can /now/ find. + found = true; + bindBlockedType(propType, freshType(arena, builtinTypes, constraint->scope.get()), subjectType, constraint); + TypeId tableTy = arena->addType(TableType{TableState::Sealed, TypeLevel{}, constraint->scope.get()}); TableType* tt2 = getMutable(tableTy); tt2->indexer = TableIndexer{indexType, propType}; @@ -1690,6 +1706,11 @@ std::pair ConstraintSolver::tryDispatchSetIndexer(NotNull(subjectType) && expandFreeTypeBounds) + { + bindBlockedType(propType, subjectType, subjectType, constraint); + return {true, true}; + } return {true, false}; } @@ -1701,8 +1722,14 @@ bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNullerrorRecoveryType(), subjectType, constraint); + unblock(c.propType, constraint->location); + } + return dispatched; } @@ -1833,10 +1860,10 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullblockCount) asMutable(resultTy)->ty.emplace(lt->domain); } - else if (get(*resultIter) || get(*resultIter)) + else if (get(resultTy) || get(resultTy)) { - asMutable(*resultIter)->ty.emplace(builtinTypes->nilType); - unblock(*resultIter, constraint->location); + asMutable(resultTy)->ty.emplace(builtinTypes->nilType); + unblock(resultTy, constraint->location); } ++resultIter; @@ -2321,8 +2348,16 @@ std::pair, std::optional> ConstraintSolver::lookupTa { TypeId one = *begin(options); TypeId two = *(++begin(options)); + + // if we're in an lvalue context, we need the _common_ type here. + if (context == ValueContext::LValue) + return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; } + // if we're in an lvalue context, we need the _common_ type here. + else if (context == ValueContext::LValue) + return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; else return {{}, arena->addType(UnionType{std::vector(begin(options), end(options))})}; } diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index b67614c4..33b41698 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -6,12 +6,10 @@ #include "Luau/Common.h" #include "Luau/Error.h" -#include #include LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauLoopControlFlowAnalysis) namespace Luau { @@ -403,7 +401,7 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) else if ((thencf | elsecf) == ControlFlow::None) join(scope, thenScope, elseScope); - if (FFlag::LuauLoopControlFlowAnalysis && thencf == elsecf) + if (thencf == elsecf) return thencf; else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) return ControlFlow::Returns; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 68e732e6..dcd591b4 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -591,6 +591,25 @@ struct ErrorConverter { return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); } + + std::string operator()(const CannotAssignToNever& e) const + { + std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; + + switch (e.reason) + { + case CannotAssignToNever::Reason::PropertyNarrowed: + if (!e.cause.empty()) + { + result += "\ncaused by the property being given the following incompatible types:\n"; + for (auto ty : e.cause) + result += " " + toString(ty) + "\n"; + result += "There are no values that could safely satisfy all of these types at once."; + } + } + + return result; + } }; struct InvalidNameChecker @@ -950,6 +969,20 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi return tp == rhs.tp; } +bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const +{ + if (cause.size() != rhs.cause.size()) + return false; + + for (size_t i = 0; i < cause.size(); ++i) + { + if (*cause[i] != *rhs.cause[i]) + return false; + } + + return *rhsType == *rhs.rhsType && reason == rhs.reason; +} + std::string toString(const TypeError& error) { return toString(error, TypeErrorToStringOptions{}); @@ -1140,6 +1173,13 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) e.ty = clone(e.ty); else if constexpr (std::is_same_v) e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + e.rhsType = clone(e.rhsType); + + for (auto& ty : e.cause) + ty = clone(ty); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1ac76835..55cff7f6 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1268,9 +1268,6 @@ ModulePtr check(const SourceModule& sourceModule, Mode mode, const std::vector(ft->lowerBound)) - return ft->upperBound; - // we default to the lower bound which represents the most specific type for the free type. - return ft->lowerBound; + TypeId res = get(ft->lowerBound) + ? ft->upperBound + : ft->lowerBound; + + // Instantiation should not traverse into the type that we are substituting for. + dontTraverseInto(res); + + return res; } TypePackId Instantiation2::clean(TypePackId tp) { - return genericPackSubstitutions[tp]; + TypePackId res = genericPackSubstitutions[tp]; + dontTraverseInto(res); + return res; +} + +std::optional instantiate2( + TypeArena* arena, DenseHashMap genericSubstitutions, DenseHashMap genericPackSubstitutions, TypeId ty) +{ + Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; + return instantiation.substitute(ty); +} + +std::optional instantiate2( + TypeArena* arena, DenseHashMap genericSubstitutions, DenseHashMap genericPackSubstitutions, TypePackId tp) +{ + Instantiation2 instantiation{arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions)}; + return instantiation.substitute(tp); } } // namespace Luau diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index dd392faa..59aa577e 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -225,10 +225,38 @@ static void errorToString(std::ostream& stream, const T& err) stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; else if constexpr (std::is_same_v) stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; + else if constexpr (std::is_same_v) + { + stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; + + bool first = true; + for (TypeId ty : err.cause) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + stream << " } } "; + } else static_assert(always_false_v, "Non-exhaustive type switch"); } +std::ostream& operator<<(std::ostream& stream, const CannotAssignToNever::Reason& reason) +{ + switch (reason) + { + case CannotAssignToNever::Reason::PropertyNarrowed: + return stream << "PropertyNarrowed"; + default: + return stream << "UnknownReason"; + } +} + std::ostream& operator<<(std::ostream& stream, const TypeErrorData& data) { auto cb = [&](const auto& e) { diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 4a3d8b6b..ce95e635 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -16,16 +16,28 @@ #include "Luau/Unifier.h" LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) +LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) +LUAU_FASTFLAGVARIABLE(LuauFixNormalizeCaching, false); // This could theoretically be 2000 on amd64, but x86 requires this. LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAG(LuauTransitiveSubtyping) -LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + +static bool fixNormalizeCaching() +{ + return FFlag::LuauFixNormalizeCaching || FFlag::DebugLuauDeferredConstraintResolution; +} namespace Luau { +// helper to make `FFlag::LuauNormalizeAwayUninhabitableTables` not explicitly required when DCR is enabled. +static bool normalizeAwayUninhabitableTables() +{ + return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::DebugLuauDeferredConstraintResolution; +} + TypeIds::TypeIds(std::initializer_list tys) { for (TypeId ty : tys) @@ -528,8 +540,16 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen) return isInhabited(mtv->metatable, seen); } - const NormalizedType* norm = normalize(ty); - return isInhabited(norm, seen); + if (fixNormalizeCaching()) + { + std::shared_ptr norm = normalize(ty); + return isInhabited(norm.get(), seen); + } + else + { + const NormalizedType* norm = DEPRECATED_normalize(ty); + return isInhabited(norm, seen); + } } NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) @@ -829,7 +849,7 @@ Normalizer::Normalizer(TypeArena* arena, NotNull builtinTypes, Not { } -const NormalizedType* Normalizer::normalize(TypeId ty) +const NormalizedType* Normalizer::DEPRECATED_normalize(TypeId ty) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); @@ -848,12 +868,102 @@ const NormalizedType* Normalizer::normalize(TypeId ty) clearNormal(norm); norm.tops = builtinTypes->unknownType; } - std::unique_ptr uniq = std::make_unique(std::move(norm)); - const NormalizedType* result = uniq.get(); - cachedNormals[ty] = std::move(uniq); + std::shared_ptr shared = std::make_shared(std::move(norm)); + const NormalizedType* result = shared.get(); + cachedNormals[ty] = std::move(shared); return result; } +static bool isCacheable(TypeId ty, Set& seen); + +static bool isCacheable(TypePackId tp, Set& seen) +{ + tp = follow(tp); + + auto it = begin(tp); + auto endIt = end(tp); + for (; it != endIt; ++it) + { + if (!isCacheable(*it, seen)) + return false; + } + + if (auto tail = it.tail()) + { + if (get(*tail) || get(*tail) || get(*tail)) + return false; + } + + return true; +} + +static bool isCacheable(TypeId ty, Set& seen) +{ + if (seen.contains(ty)) + return true; + seen.insert(ty); + + ty = follow(ty); + + if (get(ty) || get(ty) || get(ty)) + return false; + + if (auto tfi = get(ty)) + { + for (TypeId t: tfi->typeArguments) + { + if (!isCacheable(t, seen)) + return false; + } + + for (TypePackId tp: tfi->packArguments) + { + if (!isCacheable(tp, seen)) + return false; + } + } + + return true; +} + +static bool isCacheable(TypeId ty) +{ + if (!fixNormalizeCaching()) + return true; + + Set seen{nullptr}; + return isCacheable(ty, seen); +} + +std::shared_ptr Normalizer::normalize(TypeId ty) +{ + if (!arena) + sharedState->iceHandler->ice("Normalizing types outside a module"); + + auto found = cachedNormals.find(ty); + if (found != cachedNormals.end()) + return found->second; + + NormalizedType norm{builtinTypes}; + Set seenSetTypes{nullptr}; + NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); + if (res != NormalizationResult::True) + return nullptr; + + if (norm.isUnknown()) + { + clearNormal(norm); + norm.tops = builtinTypes->unknownType; + } + + std::shared_ptr shared = std::make_shared(std::move(norm)); + + if (shared->isCacheable) + cachedNormals[ty] = shared; + + return shared; +} + NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType) { if (!arena) @@ -1498,6 +1608,11 @@ void Normalizer::unionFunctionsWithFunction(NormalizedFunctionType& heres, TypeI void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) { // TODO: remove unions of tables where possible + + // we can always skip `never` + if (normalizeAwayUninhabitableTables() && get(there)) + return; + heres.insert(there); } @@ -1539,8 +1654,10 @@ void Normalizer::unionTables(TypeIds& heres, const TypeIds& theres) // That's what you get for having a type system with generics, intersection and union types. NormalizationResult Normalizer::unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { + here.isCacheable &= there.isCacheable; + TypeId tops = unionOfTops(here.tops, there.tops); - if (FFlag::LuauTransitiveSubtyping && get(tops) && (get(here.errors) || get(there.errors))) + if (get(tops) && (get(here.errors) || get(there.errors))) tops = builtinTypes->anyType; if (!get(tops)) { @@ -1617,17 +1734,15 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (get(there) || get(there)) { TypeId tops = unionOfTops(here.tops, there); - if (FFlag::LuauTransitiveSubtyping && get(tops) && get(here.errors)) + if (get(tops) && get(here.errors)) tops = builtinTypes->anyType; clearNormal(here); here.tops = tops; return NormalizationResult::True; } - else if (!FFlag::LuauTransitiveSubtyping && (get(there) || !get(here.tops))) + else if (get(there) || get(here.tops)) return NormalizationResult::True; - else if (FFlag::LuauTransitiveSubtyping && (get(there) || get(here.tops))) - return NormalizationResult::True; - else if (FFlag::LuauTransitiveSubtyping && get(there) && get(here.tops)) + else if (get(there) && get(here.tops)) { here.tops = builtinTypes->anyType; return NormalizationResult::True; @@ -1663,7 +1778,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t } return unionNormals(here, norm); } - else if (FFlag::LuauTransitiveSubtyping && get(here.tops)) + else if (get(here.tops)) return NormalizationResult::True; else if (get(there) || get(there) || get(there) || get(there) || get(there)) @@ -1673,6 +1788,9 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t NormalizedType inter{builtinTypes}; inter.tops = builtinTypes->unknownType; here.tyvars.insert_or_assign(there, std::make_unique(std::move(inter))); + + if (!isCacheable(there)) + here.isCacheable = false; } else if (auto lt = get(there)) { @@ -1734,8 +1852,19 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t } else if (const NegationType* ntv = get(there)) { - const NormalizedType* thereNormal = normalize(ntv->ty); - std::optional tn = negateNormal(*thereNormal); + std::optional tn; + + if (fixNormalizeCaching()) + { + std::shared_ptr thereNormal = normalize(ntv->ty); + tn = negateNormal(*thereNormal); + } + else + { + const NormalizedType* thereNormal = DEPRECATED_normalize(ntv->ty); + tn = negateNormal(*thereNormal); + } + if (!tn) return NormalizationResult::False; @@ -1766,6 +1895,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t std::optional Normalizer::negateNormal(const NormalizedType& here) { NormalizedType result{builtinTypes}; + result.isCacheable = here.isCacheable; + if (!get(here.tops)) { // The negation of unknown or any is never. Easy. @@ -2409,6 +2540,10 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { if (tprop.readTy.has_value()) { + // if the intersection of the read types of a property is uninhabited, the whole table is `never`. + if (normalizeAwayUninhabitableTables() && NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) + return {builtinTypes->neverType}; + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; prop.readTy = ty; hereSubThere &= (ty == hprop.readTy); @@ -2896,6 +3031,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type NormalizedType topNorm{builtinTypes}; topNorm.tops = builtinTypes->unknownType; thereNorm.tyvars.insert_or_assign(there, std::make_unique(std::move(topNorm))); + here.isCacheable = false; return intersectNormals(here, thereNorm); } else if (auto lt = get(there)) @@ -2990,22 +3126,61 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type subtractSingleton(here, follow(ntv->ty)); else if (get(t)) { - const NormalizedType* normal = normalize(t); - std::optional negated = negateNormal(*normal); - if (!negated) - return NormalizationResult::False; - intersectNormals(here, *negated); - } - else if (const UnionType* itv = get(t)) - { - for (TypeId part : itv->options) + if (fixNormalizeCaching()) { - const NormalizedType* normalPart = normalize(part); - std::optional negated = negateNormal(*normalPart); + std::shared_ptr normal = normalize(t); + std::optional negated = negateNormal(*normal); if (!negated) return NormalizationResult::False; intersectNormals(here, *negated); } + else + { + const NormalizedType* normal = DEPRECATED_normalize(t); + std::optional negated = negateNormal(*normal); + if (!negated) + return NormalizationResult::False; + intersectNormals(here, *negated); + } + } + else if (const UnionType* itv = get(t)) + { + if (fixNormalizeCaching()) + { + for (TypeId part : itv->options) + { + std::shared_ptr normalPart = normalize(part); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return NormalizationResult::False; + intersectNormals(here, *negated); + } + } + else + { + if (fixNormalizeCaching()) + { + for (TypeId part : itv->options) + { + std::shared_ptr normalPart = normalize(part); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return NormalizationResult::False; + intersectNormals(here, *negated); + } + } + else + { + for (TypeId part : itv->options) + { + const NormalizedType* normalPart = DEPRECATED_normalize(part); + std::optional negated = negateNormal(*normalPart); + if (!negated) + return NormalizationResult::False; + intersectNormals(here, *negated); + } + } + } } else if (get(t)) { @@ -3185,9 +3360,6 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { - if (!FFlag::LuauTransitiveSubtyping && !FFlag::DebugLuauDeferredConstraintResolution) - return isConsistentSubtype(subTy, superTy, scope, builtinTypes, ice); - UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; @@ -3210,9 +3382,6 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) { - if (!FFlag::LuauTransitiveSubtyping && !FFlag::DebugLuauDeferredConstraintResolution) - return isConsistentSubtype(subPack, superPack, scope, builtinTypes, ice); - UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 8b15e919..d29546a2 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -1265,9 +1265,15 @@ TypeId TypeSimplifier::union_(TypeId left, TypeId right) if (!changed) return left; - if (1 == newParts.size()) + if (0 == newParts.size()) + { + // If the left-side is changed but has no parts, then the left-side union is uninhabited. + return right; + } + else if (1 == newParts.size()) return *begin(newParts); - return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); + else + return arena->addType(UnionType{std::vector{begin(newParts), end(newParts)}}); } else if (get(right)) return union_(right, left); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 0deba9bd..bc899798 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -10,7 +10,6 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauPreallocateTarjanVectors, false); LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); namespace Luau @@ -150,14 +149,11 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a Tarjan::Tarjan() { - if (FFlag::LuauPreallocateTarjanVectors) - { - nodes.reserve(FInt::LuauTarjanPreallocationSize); - stack.reserve(FInt::LuauTarjanPreallocationSize); - edgesTy.reserve(FInt::LuauTarjanPreallocationSize); - edgesTp.reserve(FInt::LuauTarjanPreallocationSize); - worklist.reserve(FInt::LuauTarjanPreallocationSize); - } + nodes.reserve(FInt::LuauTarjanPreallocationSize); + stack.reserve(FInt::LuauTarjanPreallocationSize); + edgesTy.reserve(FInt::LuauTarjanPreallocationSize); + edgesTp.reserve(FInt::LuauTarjanPreallocationSize); + worklist.reserve(FInt::LuauTarjanPreallocationSize); } void Tarjan::visitChildren(TypeId ty, int index) @@ -529,6 +525,24 @@ TarjanResult Tarjan::findDirty(TypePackId tp) return visitRoot(tp); } +Substitution::Substitution(const TxnLog* log_, TypeArena* arena) + : arena(arena) +{ + log = log_; + LUAU_ASSERT(log); + LUAU_ASSERT(arena); +} + +void Substitution::dontTraverseInto(TypeId ty) +{ + noTraverseTypes.insert(ty); +} + +void Substitution::dontTraverseInto(TypePackId tp) +{ + noTraverseTypePacks.insert(tp); +} + std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); @@ -544,7 +558,8 @@ std::optional Substitution::substitute(TypeId ty) { if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - replaceChildren(newTy); + if (!noTraverseTypes.contains(newTy)) + replaceChildren(newTy); replacedTypes.insert(newTy); } } @@ -552,7 +567,8 @@ std::optional Substitution::substitute(TypeId ty) { if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - replaceChildren(newTp); + if (!noTraverseTypePacks.contains(newTp)) + replaceChildren(newTp); replacedTypePacks.insert(newTp); } } @@ -575,7 +591,8 @@ std::optional Substitution::substitute(TypePackId tp) { if (!ignoreChildren(oldTy) && !replacedTypes.contains(newTy)) { - replaceChildren(newTy); + if (!noTraverseTypes.contains(newTy)) + replaceChildren(newTy); replacedTypes.insert(newTy); } } @@ -583,7 +600,8 @@ std::optional Substitution::substitute(TypePackId tp) { if (!ignoreChildren(oldTp) && !replacedTypePacks.contains(newTp)) { - replaceChildren(newTp); + if (!noTraverseTypePacks.contains(newTp)) + replaceChildren(newTp); replacedTypePacks.insert(newTp); } } diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 7ec00266..ade4d2c9 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -346,9 +346,9 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) TypeId lowerBound = makeAggregateType(lb, builtinTypes->neverType); TypeId upperBound = makeAggregateType(ub, builtinTypes->unknownType); - const NormalizedType* nt = normalizer->normalize(upperBound); + std::shared_ptr nt = normalizer->normalize(upperBound); // we say that the result is true if normalization failed because complex types are likely to be inhabited. - NormalizationResult res = nt ? normalizer->isInhabited(nt) : NormalizationResult::True; + NormalizationResult res = nt ? normalizer->isInhabited(nt.get()) : NormalizationResult::True; if (!nt || res == NormalizationResult::HitLimits) result.normalizationTooComplex = true; @@ -1421,7 +1421,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Prop return res; } -SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const NormalizedType* subNorm, const NormalizedType* superNorm) +SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const std::shared_ptr& subNorm, const std::shared_ptr& superNorm) { if (!subNorm || !superNorm) return {false, true}; diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index b23f614a..b93bdfd2 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -116,16 +116,9 @@ static std::optional extractMatchingTableType(std::vector& table return std::nullopt; } -TypeId matchLiteralType( - NotNull> astTypes, - NotNull> astExpectedTypes, - NotNull builtinTypes, - NotNull arena, - NotNull unifier, - TypeId expectedType, - TypeId exprType, - const AstExpr* expr -) +TypeId matchLiteralType(NotNull> astTypes, NotNull> astExpectedTypes, + NotNull builtinTypes, NotNull arena, NotNull unifier, TypeId expectedType, TypeId exprType, + const AstExpr* expr, std::vector& toBlock) { /* * Table types that arise from literal table expressions have some @@ -244,7 +237,7 @@ TypeId matchLiteralType( if (tt) { - TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *tt, exprType, expr); + TypeId res = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *tt, exprType, expr, toBlock); parts.push_back(res); return arena->addType(UnionType{std::move(parts)}); @@ -281,7 +274,8 @@ TypeId matchLiteralType( (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; (*astExpectedTypes)[item.value] = expectedTableTy->indexer->indexResultType; - TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, propTy, item.value); + TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, + expectedTableTy->indexer->indexResultType, propTy, item.value, toBlock); if (tableTy->indexer) unifier->unify(matchedType, tableTy->indexer->indexResultType); @@ -311,19 +305,22 @@ TypeId matchLiteralType( // quadratic in a hurry. if (expectedProp.isShared()) { - matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value); + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); prop.readTy = matchedType; prop.writeTy = matchedType; } else if (expectedReadTy) { - matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value); + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedReadTy, propTy, item.value, toBlock); prop.readTy = matchedType; prop.writeTy.reset(); } else if (expectedWriteTy) { - matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedWriteTy, propTy, item.value); + matchedType = + matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, *expectedWriteTy, propTy, item.value, toBlock); prop.readTy.reset(); prop.writeTy = matchedType; } @@ -351,14 +348,31 @@ TypeId matchLiteralType( LUAU_ASSERT(propTy); unifier->unify(expectedTableTy->indexer->indexType, builtinTypes->numberType); - TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, expectedTableTy->indexer->indexResultType, *propTy, item.value); + TypeId matchedType = matchLiteralType(astTypes, astExpectedTypes, builtinTypes, arena, unifier, + expectedTableTy->indexer->indexResultType, *propTy, item.value, toBlock); tableTy->indexer->indexResultType = matchedType; } } else if (item.kind == AstExprTable::Item::General) { - LUAU_ASSERT(!"TODO"); + + // We have { ..., [blocked] : somePropExpr, ...} + // If blocked resolves to a string, we will then take care of this above + // If it resolves to some other kind of expression, we don't have a way of folding this information into indexer + // because there is no named prop to remove + // We should just block here + const TypeId* keyTy = astTypes->find(item.key); + LUAU_ASSERT(keyTy); + TypeId tKey = follow(*keyTy); + if (get(tKey)) + toBlock.push_back(tKey); + + const TypeId* propTy = astTypes->find(item.value); + LUAU_ASSERT(propTy); + TypeId tProp = follow(*propTy); + if (get(tProp)) + toBlock.push_back(tProp); } else LUAU_ASSERT(!"Unexpected"); diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index c4241711..9093b38a 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -4,6 +4,7 @@ #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/Type.h" +#include "Luau/TypeFamily.h" #include "Luau/StringUtils.h" #include @@ -352,9 +353,15 @@ void StateDot::visitChildren(TypeId ty, int index) } else if constexpr (std::is_same_v) { - formatAppend(result, "TypeFamilyInstanceType %d", index); + formatAppend(result, "TypeFamilyInstanceType %s %d", t.family->name.c_str(), index); finishNodeLabel(ty); finishNode(); + + for (TypeId tyParam : t.typeArguments) + visitChild(tyParam, index); + + for (TypePackId tpParam : t.packArguments) + visitChild(tpParam, index); } else static_assert(always_false_v, "unknown type kind"); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index 718e9e8f..7454be32 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -9,6 +9,7 @@ #include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include "Luau/TypeFamily.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/VecDeque.h" @@ -422,6 +423,9 @@ bool maybeSingleton(TypeId ty) for (TypeId part : itv) if (maybeSingleton(part)) // will i regret this? return true; + if (const TypeFamilyInstanceType* tfit = get(ty)) + if (tfit->family->name == "keyof" || tfit->family->name == "rawkeyof") + return true; return false; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index c82bb6e9..cfb49f21 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -916,7 +916,7 @@ struct TypeChecker2 } }; - const NormalizedType* iteratorNorm = normalizer.normalize(iteratorTy); + std::shared_ptr iteratorNorm = normalizer.normalize(iteratorTy); if (!iteratorNorm) reportError(NormalizationTooComplex{}, firstValue->location); @@ -1042,6 +1042,37 @@ struct TypeChecker2 return std::nullopt; } + // this should only be called if the type of `lhs` is `never`. + void reportErrorsFromAssigningToNever(AstExpr* lhs, TypeId rhsType) + { + + if (auto indexName = lhs->as()) + { + TypeId indexedType = lookupType(indexName->expr); + + // if it's already never, I don't think we have anything to do here. + if (get(indexedType)) + return; + + std::string prop = indexName->index.value; + + std::shared_ptr norm = normalizer.normalize(indexedType); + if (!norm) + { + reportError(NormalizationTooComplex{}, lhs->location); + return; + } + + // if the type is error suppressing, we don't actually have any work left to do. + if (norm->shouldSuppressErrors()) + return; + + const auto propTypes = lookupProp(norm.get(), prop, ValueContext::LValue, lhs->location, builtinTypes->stringType, module->errors); + + reportError(CannotAssignToNever{rhsType, propTypes.typesOfProp, CannotAssignToNever::Reason::PropertyNarrowed}, lhs->location); + } + } + void visit(AstStatAssign* assign) { size_t count = std::min(assign->vars.size, assign->values.size); @@ -1057,7 +1088,10 @@ struct TypeChecker2 TypeId rhsType = lookupType(rhs); if (get(lhsType)) + { + reportErrorsFromAssigningToNever(lhs, rhsType); continue; + } bool ok = testIsSubtype(rhsType, lhsType, rhs->location); @@ -1352,7 +1386,7 @@ struct TypeChecker2 auto norm = normalizer.normalize(fnTy); if (!norm) reportError(NormalizationTooComplex{}, call->func->location); - auto isInhabited = normalizer.isInhabited(norm); + auto isInhabited = normalizer.isInhabited(norm.get()); if (isInhabited == NormalizationResult::HitLimits) reportError(NormalizationTooComplex{}, call->func->location); @@ -1554,7 +1588,7 @@ struct TypeChecker2 TypeId inferredFnTy = lookupType(fn); functionDeclStack.push_back(inferredFnTy); - const NormalizedType* normalizedFnTy = normalizer.normalize(inferredFnTy); + std::shared_ptr normalizedFnTy = normalizer.normalize(inferredFnTy); const FunctionType* inferredFtv = get(normalizedFnTy->functions.parts.front()); if (!normalizedFnTy) { @@ -1731,7 +1765,7 @@ struct TypeChecker2 { DenseHashSet seen{nullptr}; int recursionCount = 0; - const NormalizedType* nty = normalizer.normalize(operandType); + std::shared_ptr nty = normalizer.normalize(operandType); if (nty && nty->shouldSuppressErrors()) return; @@ -1783,8 +1817,8 @@ struct TypeChecker2 leftType = stripNil(builtinTypes, module->internalTypes, leftType); } - const NormalizedType* normLeft = normalizer.normalize(leftType); - const NormalizedType* normRight = normalizer.normalize(rightType); + std::shared_ptr normLeft = normalizer.normalize(leftType); + std::shared_ptr normRight = normalizer.normalize(rightType); bool isStringOperation = (normLeft ? normLeft->isSubtypeOfString() : isString(leftType)) && (normRight ? normRight->isSubtypeOfString() : isString(rightType)); @@ -2584,6 +2618,30 @@ struct TypeChecker2 reportError(std::move(e)); } + struct PropertyTypes + { + // a vector of all the types assigned to the given property. + std::vector typesOfProp; + + // a vector of all the types that are missing the given property. + std::vector missingProp; + + bool foundOneProp() const + { + return !typesOfProp.empty(); + } + + bool noneMissingProp() const + { + return missingProp.empty(); + } + + bool foundMissingProp() const + { + return !missingProp.empty(); + } + }; + /* A helper for checkIndexTypeFromType. * * Returns a pair: @@ -2591,10 +2649,10 @@ struct TypeChecker2 * contains the prop, and * * A vector of types that do not contain the prop. */ - std::pair> lookupProp(const NormalizedType* norm, const std::string& prop, ValueContext context, - const Location& location, TypeId astIndexExprType, std::vector& errors) + PropertyTypes lookupProp(const NormalizedType* norm, const std::string& prop, ValueContext context, const Location& location, + TypeId astIndexExprType, std::vector& errors) { - bool foundOneProp = false; + std::vector typesOfProp; std::vector typesMissingTheProp; // this is `false` if we ever hit the resource limits during any of our uses of `fetch`. @@ -2608,16 +2666,18 @@ struct TypeChecker2 return; DenseHashSet seen{nullptr}; - NormalizationResult found = hasIndexTypeFromType(ty, prop, context, location, seen, astIndexExprType, errors); + PropertyType res = hasIndexTypeFromType(ty, prop, context, location, seen, astIndexExprType, errors); - if (found == NormalizationResult::HitLimits) + if (res.present == NormalizationResult::HitLimits) { normValid = false; return; } - foundOneProp |= found == NormalizationResult::True; - if (found == NormalizationResult::False) + if (res.present == NormalizationResult::True && res.result) + typesOfProp.emplace_back(*res.result); + + if (res.present == NormalizationResult::False) typesMissingTheProp.push_back(ty); }; @@ -2631,6 +2691,9 @@ struct TypeChecker2 for (const auto& [ty, _negations] : norm->classes.classes) { fetch(ty); + + if (!normValid) + break; } } @@ -2644,9 +2707,18 @@ struct TypeChecker2 fetch(builtinTypes->stringType); if (normValid) fetch(norm->threads); - for (TypeId ty : norm->tables) - if (normValid) + + if (normValid) + { + for (TypeId ty : norm->tables) + { fetch(ty); + + if (!normValid) + break; + } + } + if (normValid && norm->functions.isTop) fetch(builtinTypes->functionType); else if (normValid && !norm->functions.isNever()) @@ -2672,16 +2744,19 @@ struct TypeChecker2 } else fetch(tyvar); + + if (!normValid) + break; } } - return {foundOneProp, typesMissingTheProp}; + return {typesOfProp, typesMissingTheProp}; } // If the provided type does not have the named property, report an error. void checkIndexTypeFromType(TypeId tableTy, const std::string& prop, ValueContext context, const Location& location, TypeId astIndexExprType) { - const NormalizedType* norm = normalizer.normalize(tableTy); + std::shared_ptr norm = normalizer.normalize(tableTy); if (!norm) { reportError(NormalizationTooComplex{}, location); @@ -2693,20 +2768,20 @@ struct TypeChecker2 return; std::vector dummy; - const auto [foundOneProp, typesMissingTheProp] = lookupProp(norm, prop, context, location, astIndexExprType, module->errors); + const auto propTypes = lookupProp(norm.get(), prop, context, location, astIndexExprType, module->errors); - if (!typesMissingTheProp.empty()) + if (propTypes.foundMissingProp()) { - if (foundOneProp) - reportError(MissingUnionProperty{tableTy, typesMissingTheProp, prop}, location); + if (propTypes.foundOneProp()) + reportError(MissingUnionProperty{tableTy, propTypes.missingProp, prop}, location); // For class LValues, we don't want to report an extension error, // because classes come into being with full knowledge of their // shape. We instead want to report the unknown property error of // the `else` branch. else if (context == ValueContext::LValue && !get(tableTy)) { - const auto [lvFoundOneProp, lvTypesMissingTheProp] = lookupProp(norm, prop, ValueContext::RValue, location, astIndexExprType, dummy); - if (lvFoundOneProp && lvTypesMissingTheProp.empty()) + const auto lvPropTypes = lookupProp(norm.get(), prop, ValueContext::RValue, location, astIndexExprType, dummy); + if (lvPropTypes.foundOneProp() && lvPropTypes.noneMissingProp()) reportError(PropertyAccessViolation{tableTy, prop, PropertyAccessViolation::CannotWrite}, location); else if (get(tableTy) || get(tableTy)) reportError(NotATable{tableTy}, location); @@ -2715,8 +2790,8 @@ struct TypeChecker2 } else if (context == ValueContext::RValue && !get(tableTy)) { - const auto [rvFoundOneProp, rvTypesMissingTheProp] = lookupProp(norm, prop, ValueContext::LValue, location, astIndexExprType, dummy); - if (rvFoundOneProp && rvTypesMissingTheProp.empty()) + const auto rvPropTypes = lookupProp(norm.get(), prop, ValueContext::LValue, location, astIndexExprType, dummy); + if (rvPropTypes.foundOneProp() && rvPropTypes.noneMissingProp()) reportError(PropertyAccessViolation{tableTy, prop, PropertyAccessViolation::CannotRead}, location); else reportError(UnknownProperty{tableTy, prop}, location); @@ -2726,18 +2801,24 @@ struct TypeChecker2 } } - NormalizationResult hasIndexTypeFromType(TypeId ty, const std::string& prop, ValueContext context, const Location& location, - DenseHashSet& seen, TypeId astIndexExprType, std::vector& errors) + struct PropertyType + { + NormalizationResult present; + std::optional result; + }; + + PropertyType hasIndexTypeFromType(TypeId ty, const std::string& prop, ValueContext context, const Location& location, DenseHashSet& seen, + TypeId astIndexExprType, std::vector& errors) { // If we have already encountered this type, we must assume that some // other codepath will do the right thing and signal false if the // property is not present. if (seen.contains(ty)) - return NormalizationResult::True; + return {NormalizationResult::True, {}}; seen.insert(ty); if (get(ty) || get(ty) || get(ty)) - return NormalizationResult::True; + return {NormalizationResult::True, {ty}}; if (isString(ty)) { @@ -2748,24 +2829,24 @@ struct TypeChecker2 if (auto tt = getTableType(ty)) { - if (findTablePropertyRespectingMeta(builtinTypes, errors, ty, prop, context, location)) - return NormalizationResult::True; + if (auto resTy = findTablePropertyRespectingMeta(builtinTypes, errors, ty, prop, context, location)) + return {NormalizationResult::True, resTy}; if (tt->indexer) { TypeId indexType = follow(tt->indexer->indexType); if (isPrim(indexType, PrimitiveType::String)) - return NormalizationResult::True; + return {NormalizationResult::True, {tt->indexer->indexResultType}}; // If the indexer looks like { [any] : _} - the prop lookup should be allowed! else if (get(indexType) || get(indexType)) - return NormalizationResult::True; + return {NormalizationResult::True, {tt->indexer->indexResultType}}; } // if we are in a conditional context, we treat the property as present and `unknown` because // we may be _refining_ `tableTy` to include that property. we will want to revisit this a bit // in the future once luau has support for exact tables since this only applies when inexact. - return inConditional(typeContext) ? NormalizationResult::True : NormalizationResult::False; + return {inConditional(typeContext) ? NormalizationResult::True : NormalizationResult::False, {builtinTypes->unknownType}}; } else if (const ClassType* cls = get(ty)) { @@ -2774,40 +2855,51 @@ struct TypeChecker2 // is compatible with the indexer's indexType // Construct the intersection and test inhabitedness! if (auto property = lookupClassProp(cls, prop)) - return NormalizationResult::True; + return {NormalizationResult::True, context == ValueContext::LValue ? property->writeTy : property->readTy}; if (cls->indexer) { TypeId inhabitatedTestType = module->internalTypes.addType(IntersectionType{{cls->indexer->indexType, astIndexExprType}}); - return normalizer.isInhabited(inhabitatedTestType); + return {normalizer.isInhabited(inhabitatedTestType), {cls->indexer->indexResultType}}; } - return NormalizationResult::False; + return {NormalizationResult::False, {}}; } else if (const UnionType* utv = get(ty)) { + std::vector parts; + parts.reserve(utv->options.size()); + for (TypeId part : utv) { - NormalizationResult result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); - if (result != NormalizationResult::True) - return result; + PropertyType result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); + if (result.present != NormalizationResult::True) + return {result.present, {}}; + if (result.result) + parts.emplace_back(*result.result); } - return NormalizationResult::True; + TypeId propTy; + if (context == ValueContext::LValue) + module->internalTypes.addType(IntersectionType{parts}); + else + module->internalTypes.addType(UnionType{parts}); + + return {NormalizationResult::True, propTy}; } else if (const IntersectionType* itv = get(ty)) { for (TypeId part : itv) { - NormalizationResult result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); - if (result != NormalizationResult::False) + PropertyType result = hasIndexTypeFromType(part, prop, context, location, seen, astIndexExprType, errors); + if (result.present != NormalizationResult::False) return result; } - return NormalizationResult::False; + return {NormalizationResult::False, {}}; } else if (const PrimitiveType* pt = get(ty)) - return (inConditional(typeContext) && pt->type == PrimitiveType::Table) ? NormalizationResult::True : NormalizationResult::False; + return {(inConditional(typeContext) && pt->type == PrimitiveType::Table) ? NormalizationResult::True : NormalizationResult::False, {ty}}; else - return NormalizationResult::False; + return {NormalizationResult::False, {}}; } void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const diff --git a/Analysis/src/TypeFamily.cpp b/Analysis/src/TypeFamily.cpp index f1ee929a..59302085 100644 --- a/Analysis/src/TypeFamily.cpp +++ b/Analysis/src/TypeFamily.cpp @@ -24,9 +24,15 @@ #include "Luau/VecDeque.h" #include "Luau/VisitType.h" +#include + // used to control emitting CodeTooComplex warnings on type family reduction LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); +// used to control the limits of type family application over union type arguments +// e.g. `mul` blows up into `mul | mul | mul | mul` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'000); + // used to control falling back to a more conservative reduction based on guessing // when this value is set to a negative value, guessing will be totally disabled. LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); @@ -339,7 +345,8 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyReductionResult result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + TypeFamilyQueue queue{NotNull{&queuedTys}, NotNull{&queuedTps}}; + TypeFamilyReductionResult result = tfit->family->reducer(subject, NotNull{&queue}, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -363,7 +370,8 @@ struct FamilyReducer if (tryGuessing(subject)) return; - TypeFamilyReductionResult result = tfit->family->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); + TypeFamilyQueue queue{NotNull{&queuedTys}, NotNull{&queuedTps}}; + TypeFamilyReductionResult result = tfit->family->reducer(subject, NotNull{&queue}, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleFamilyReduction(subject, result); } } @@ -436,13 +444,25 @@ FamilyGraphReductionResult reduceFamilies(TypePackId entrypoint, Location locati return reduceFamiliesInternal(std::move(collector.tys), std::move(collector.tps), std::move(collector.shouldGuess), std::move(collector.cyclicInstance), location, ctx, force); } +void TypeFamilyQueue::add(TypeId instanceTy) +{ + LUAU_ASSERT(get(instanceTy)); + queuedTys->push_back(instanceTy); +} + +void TypeFamilyQueue::add(TypePackId instanceTp) +{ + LUAU_ASSERT(get(instanceTp)); + queuedTps->push_back(instanceTp); +} + bool isPending(TypeId ty, ConstraintSolver* solver) { return is(ty) || (solver && solver->hasUnresolvedConstraints(ty)); } TypeFamilyReductionResult notFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -460,7 +480,7 @@ TypeFamilyReductionResult notFamilyFn( } TypeFamilyReductionResult lenFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -475,7 +495,7 @@ TypeFamilyReductionResult lenFamilyFn( if (isPending(operandTy, ctx->solver) || get(operandTy)) return {std::nullopt, false, {operandTy}, {}}; - const NormalizedType* normTy = ctx->normalizer->normalize(operandTy); + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) @@ -536,7 +556,7 @@ TypeFamilyReductionResult lenFamilyFn( } TypeFamilyReductionResult unmFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -550,7 +570,7 @@ TypeFamilyReductionResult unmFamilyFn( if (isPending(operandTy, ctx->solver)) return {std::nullopt, false, {operandTy}, {}}; - const NormalizedType* normTy = ctx->normalizer->normalize(operandTy); + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) @@ -609,6 +629,7 @@ TypeFamilyReductionResult unmFamilyFn( NotNull TypeFamilyContext::pushConstraint(ConstraintV&& c) { + LUAU_ASSERT(solver); NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); // Every constraint that is blocked on the current constraint must also be @@ -619,7 +640,7 @@ NotNull TypeFamilyContext::pushConstraint(ConstraintV&& c) return newConstraint; } -TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const std::vector& typeParams, +TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx, const std::string metamethod) { if (typeParams.size() != 2 || !packParams.empty()) @@ -631,6 +652,14 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st TypeId lhsTy = follow(typeParams.at(0)); TypeId rhsTy = follow(typeParams.at(1)); + // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. + if (lhsTy == instance || rhsTy == instance) + return {ctx->builtins->neverType, false, {}, {}}; + + // if we have a `never`, we can never observe that the math operator is unreachable. + if (is(lhsTy) || is(rhsTy)) + return {ctx->builtins->neverType, false, {}, {}}; + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; // check to see if both operand types are resolved enough, and wait to reduce if not @@ -639,8 +668,9 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st else if (isPending(rhsTy, ctx->solver)) return {std::nullopt, false, {rhsTy}, {}}; - const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); - const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); + // TODO: Normalization needs to remove cyclic type families from a `NormalizedType`. + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) @@ -650,14 +680,79 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) return {ctx->builtins->anyType, false, {}, {}}; - // if we have a `never`, we can never observe that the numeric operator didn't work. - if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; - // if we're adding two `number` types, the result is `number`. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) return {ctx->builtins->numberType, false, {}, {}}; + // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) + std::vector results; + bool uninhabited = false; + std::vector blockedTypes; + std::vector arguments = typeParams; + auto distributeFamilyApp = [&](const UnionType* ut, size_t argumentIndex) { + // Returning true here means we completed the loop without any problems. + for (TypeId option : ut) + { + arguments[argumentIndex] = option; + + TypeFamilyReductionResult result = numericBinopFamilyFn(instance, queue, arguments, packParams, ctx, metamethod); + blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + uninhabited |= result.uninhabited; + + if (result.uninhabited) + return false; + else if (!result.result) + return false; + else + results.push_back(*result.result); + } + + return true; + }; + + const UnionType* lhsUnion = get(lhsTy); + const UnionType* rhsUnion = get(rhsTy); + if (lhsUnion || rhsUnion) + { + // TODO: We'd like to report that the type family application is too complex here. + size_t lhsUnionSize = lhsUnion ? std::distance(begin(lhsUnion), end(lhsUnion)) : 1; + size_t rhsUnionSize = rhsUnion ? std::distance(begin(rhsUnion), end(rhsUnion)) : 1; + if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= lhsUnionSize * rhsUnionSize) + return {std::nullopt, true, {}, {}}; + + if (lhsUnion && !distributeFamilyApp(lhsUnion, 0)) + return {std::nullopt, uninhabited, std::move(blockedTypes), {}}; + + if (rhsUnion && !distributeFamilyApp(rhsUnion, 1)) + return {std::nullopt, uninhabited, std::move(blockedTypes), {}}; + + if (results.empty()) + { + // If this happens, it means `distributeFamilyApp` has improperly returned `true` even + // though there exists no arm of the union that is inhabited or have a reduced type. + ctx->ice->ice("`distributeFamilyApp` failed to add any types to the results vector?"); + } + else if (results.size() == 1) + return {results[0], false, {}, {}}; + else if (results.size() == 2) + { + TypeId resultTy = ctx->arena->addType(TypeFamilyInstanceType{ + NotNull{&kBuiltinTypeFamilies.unionFamily}, + std::move(results), + {}, + }); + + queue->add(resultTy); + return {resultTy, false, {}, {}}; + } + else + { + // TODO: We need to generalize `union<...>` type family to be variadic. + TypeId resultTy = ctx->arena->addType(UnionType{std::move(results)}); + return {resultTy, false, {}, {}}; + } + } + // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. ErrorVec dummy; @@ -700,7 +795,7 @@ TypeFamilyReductionResult numericBinopFamilyFn(TypeId instance, const st } TypeFamilyReductionResult addFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -708,11 +803,11 @@ TypeFamilyReductionResult addFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__add"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__add"); } TypeFamilyReductionResult subFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -720,11 +815,11 @@ TypeFamilyReductionResult subFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__sub"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__sub"); } TypeFamilyReductionResult mulFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -732,11 +827,11 @@ TypeFamilyReductionResult mulFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__mul"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__mul"); } TypeFamilyReductionResult divFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -744,11 +839,11 @@ TypeFamilyReductionResult divFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__div"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__div"); } TypeFamilyReductionResult idivFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -756,11 +851,11 @@ TypeFamilyReductionResult idivFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__idiv"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__idiv"); } TypeFamilyReductionResult powFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -768,11 +863,11 @@ TypeFamilyReductionResult powFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__pow"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__pow"); } TypeFamilyReductionResult modFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -780,11 +875,11 @@ TypeFamilyReductionResult modFamilyFn( LUAU_ASSERT(false); } - return numericBinopFamilyFn(instance, typeParams, packParams, ctx, "__mod"); + return numericBinopFamilyFn(instance, queue, typeParams, packParams, ctx, "__mod"); } TypeFamilyReductionResult concatFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -801,8 +896,8 @@ TypeFamilyReductionResult concatFamilyFn( else if (isPending(rhsTy, ctx->solver)) return {std::nullopt, false, {rhsTy}, {}}; - const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); - const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) @@ -870,7 +965,7 @@ TypeFamilyReductionResult concatFamilyFn( } TypeFamilyReductionResult andFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -907,7 +1002,7 @@ TypeFamilyReductionResult andFamilyFn( } TypeFamilyReductionResult orFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -942,7 +1037,7 @@ TypeFamilyReductionResult orFamilyFn( return {overallResult.result, false, std::move(blockedTypes), {}}; } -static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, const std::vector& typeParams, +static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx, const std::string metamethod) { @@ -996,8 +1091,8 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con // check to see if both operand types are resolved enough, and wait to reduce if not - const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); - const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) @@ -1059,7 +1154,7 @@ static TypeFamilyReductionResult comparisonFamilyFn(TypeId instance, con } TypeFamilyReductionResult ltFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1067,11 +1162,11 @@ TypeFamilyReductionResult ltFamilyFn( LUAU_ASSERT(false); } - return comparisonFamilyFn(instance, typeParams, packParams, ctx, "__lt"); + return comparisonFamilyFn(instance, queue, typeParams, packParams, ctx, "__lt"); } TypeFamilyReductionResult leFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1079,11 +1174,11 @@ TypeFamilyReductionResult leFamilyFn( LUAU_ASSERT(false); } - return comparisonFamilyFn(instance, typeParams, packParams, ctx, "__le"); + return comparisonFamilyFn(instance, queue, typeParams, packParams, ctx, "__le"); } TypeFamilyReductionResult eqFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1100,8 +1195,8 @@ TypeFamilyReductionResult eqFamilyFn( else if (isPending(rhsTy, ctx->solver)) return {std::nullopt, false, {rhsTy}, {}}; - const NormalizedType* normLhsTy = ctx->normalizer->normalize(lhsTy); - const NormalizedType* normRhsTy = ctx->normalizer->normalize(rhsTy); + std::shared_ptr normLhsTy = ctx->normalizer->normalize(lhsTy); + std::shared_ptr normRhsTy = ctx->normalizer->normalize(rhsTy); // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) @@ -1188,7 +1283,7 @@ struct FindRefinementBlockers : TypeOnceVisitor TypeFamilyReductionResult refineFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1229,8 +1324,8 @@ TypeFamilyReductionResult refineFamilyFn( return {targetTy, false, {}, {}}; TypeId intersection = ctx->arena->addType(IntersectionType{{targetTy, discriminantTy}}); - const NormalizedType* normIntersection = ctx->normalizer->normalize(intersection); - const NormalizedType* normType = ctx->normalizer->normalize(targetTy); + std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); + std::shared_ptr normType = ctx->normalizer->normalize(targetTy); // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normIntersection || !normType) @@ -1246,7 +1341,7 @@ TypeFamilyReductionResult refineFamilyFn( } TypeFamilyReductionResult unionFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1277,7 +1372,7 @@ TypeFamilyReductionResult unionFamilyFn( TypeFamilyReductionResult intersectFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 2 || !packParams.empty()) { @@ -1380,7 +1475,7 @@ TypeFamilyReductionResult keyofFamilyImpl( TypeId operandTy = follow(typeParams.at(0)); - const NormalizedType* normTy = ctx->normalizer->normalize(operandTy); + std::shared_ptr normTy = ctx->normalizer->normalize(operandTy); // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) @@ -1487,7 +1582,7 @@ TypeFamilyReductionResult keyofFamilyImpl( } TypeFamilyReductionResult keyofFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { @@ -1499,7 +1594,7 @@ TypeFamilyReductionResult keyofFamilyFn( } TypeFamilyReductionResult rawkeyofFamilyFn( - TypeId instance, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) + TypeId instance, NotNull queue, const std::vector& typeParams, const std::vector& packParams, NotNull ctx) { if (typeParams.size() != 1 || !packParams.empty()) { diff --git a/Analysis/src/TypeFamilyReductionGuesser.cpp b/Analysis/src/TypeFamilyReductionGuesser.cpp index f1130f82..7f865998 100644 --- a/Analysis/src/TypeFamilyReductionGuesser.cpp +++ b/Analysis/src/TypeFamilyReductionGuesser.cpp @@ -245,7 +245,7 @@ bool TypeFamilyReductionGuesser::operandIsAssignable(TypeId ty) return false; } -const NormalizedType* TypeFamilyReductionGuesser::normalize(TypeId ty) +std::shared_ptr TypeFamilyReductionGuesser::normalize(TypeId ty) { return normalizer->normalize(ty); } @@ -379,8 +379,8 @@ TypeFamilyInferenceResult TypeFamilyReductionGuesser::inferOrAndFamily(const Typ rhsTy = follow(*ty); TypeFamilyInferenceResult defaultAndOrInference{{builtins->unknownType, builtins->unknownType}, builtins->booleanType}; - const NormalizedType* lty = normalize(lhsTy); - const NormalizedType* rty = normalize(lhsTy); + std::shared_ptr lty = normalize(lhsTy); + std::shared_ptr rty = normalize(lhsTy); bool lhsTruthy = lty ? lty->isTruthy() : false; bool rhsTruthy = rty ? rty->isTruthy() : false; // If at the end, we still don't have good substitutions, return the default type diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 66c3cc20..2f8fad49 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,12 +3,10 @@ #include "Luau/ApplyTypeFunction.h" #include "Luau/Cancellation.h" -#include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Instantiation.h" #include "Luau/ModuleResolver.h" #include "Luau/Normalize.h" -#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" @@ -36,11 +34,11 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) -LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) LUAU_FASTFLAGVARIABLE(LuauForbidAliasNamedTypeof, false) LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) +LUAU_FASTFLAG(LuauFixNormalizeCaching) namespace Luau { @@ -351,9 +349,9 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStat& program) else if (auto repeat = program.as()) return check(scope, *repeat); else if (program.is()) - return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Breaks : ControlFlow::None; + return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Breaks : ControlFlow::None; else if (program.is()) - return FFlag::LuauLoopControlFlowAnalysis ? ControlFlow::Continues : ControlFlow::None; + return FFlag::LuauTinyControlFlowAnalysis ? ControlFlow::Continues : ControlFlow::None; else if (auto return_ = program.as()) return check(scope, *return_); else if (auto expr = program.as()) @@ -756,7 +754,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) scope->inheritRefinements(thenScope); - if (FFlag::LuauLoopControlFlowAnalysis && thencf == elsecf) + if (FFlag::LuauTinyControlFlowAnalysis && thencf == elsecf) return thencf; else if (matches(thencf, ControlFlow::Returns | ControlFlow::Throws) && matches(elsecf, ControlFlow::Returns | ControlFlow::Throws)) return ControlFlow::Returns; @@ -2648,12 +2646,28 @@ static std::optional areEqComparable(NotNull arena, NotNulladdType(IntersectionType{{a, b}}); - const NormalizedType* n = normalizer->normalize(c); - if (!n) - return std::nullopt; + NormalizationResult nr; - switch (normalizer->isInhabited(n)) + if (FFlag::LuauFixNormalizeCaching) + { + TypeId c = arena->addType(IntersectionType{{a, b}}); + std::shared_ptr n = normalizer->normalize(c); + if (!n) + return std::nullopt; + + nr = normalizer->isInhabited(n.get()); + } + else + { + TypeId c = arena->addType(IntersectionType{{a, b}}); + const NormalizedType* n = normalizer->DEPRECATED_normalize(c); + if (!n) + return std::nullopt; + + nr = normalizer->isInhabited(n); + } + + switch (nr) { case NormalizationResult::HitLimits: return std::nullopt; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 6a03039a..588b1da1 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -19,6 +19,25 @@ bool inConditional(const TypeContext& context) return context == TypeContext::Condition; } +bool occursCheck(TypeId needle, TypeId haystack) +{ + LUAU_ASSERT(get(needle) || get(needle)); + haystack = follow(haystack); + + auto checkHaystack = [needle](TypeId haystack) { + return occursCheck(needle, haystack); + }; + + if (needle == haystack) + return true; + else if (auto ut = get(haystack)) + return std::any_of(begin(ut), end(ut), checkHaystack); + else if (auto it = get(haystack)) + return std::any_of(begin(it), end(it), checkHaystack); + + return false; +} + std::optional findMetatableEntry( NotNull builtinTypes, ErrorVec& errors, TypeId type, const std::string& entry, Location location) { @@ -330,7 +349,8 @@ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty) ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty) { - const NormalizedType* normType = normalizer->normalize(ty); + LUAU_ASSERT(FFlag::DebugLuauDeferredConstraintResolution); + std::shared_ptr normType = normalizer->normalize(ty); if (!normType) return ErrorSuppression::NormalizationFailed; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index a3129969..67f49722 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) +LUAU_FASTFLAG(LuauFixNormalizeCaching) namespace Luau { @@ -575,33 +576,38 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (log.get(superTy)) return tryUnifyWithAny(subTy, builtinTypes->anyType); - if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->errorType); - - if (!FFlag::LuauTransitiveSubtyping && log.get(superTy)) - return tryUnifyWithAny(subTy, builtinTypes->unknownType); - if (log.get(subTy)) { - if (FFlag::LuauTransitiveSubtyping && normalize) + if (normalize) { - // TODO: there are probably cheaper ways to check if any <: T. - const NormalizedType* superNorm = normalizer->normalize(superTy); + if (FFlag::LuauFixNormalizeCaching) + { + // TODO: there are probably cheaper ways to check if any <: T. + std::shared_ptr superNorm = normalizer->normalize(superTy); - if (!superNorm) - return reportError(location, NormalizationTooComplex{}); + if (!superNorm) + return reportError(location, NormalizationTooComplex{}); - if (!log.get(superNorm->tops)) - failure = true; + if (!log.get(superNorm->tops)) + failure = true; + } + else + { + // TODO: there are probably cheaper ways to check if any <: T. + const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); + + if (!superNorm) + return reportError(location, NormalizationTooComplex{}); + + if (!log.get(superNorm->tops)) + failure = true; + } } else failure = true; return tryUnifyWithAny(superTy, builtinTypes->anyType); } - if (!FFlag::LuauTransitiveSubtyping && log.get(subTy)) - return tryUnifyWithAny(superTy, builtinTypes->errorType); - if (log.get(subTy)) return tryUnifyWithAny(superTy, builtinTypes->neverType); @@ -649,32 +655,32 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } - else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + else if (log.get(subTy)) { tryUnifyWithAny(superTy, builtinTypes->unknownType); failure = true; } - else if (FFlag::LuauTransitiveSubtyping && log.get(subTy) && log.get(superTy)) + else if (log.get(subTy) && log.get(superTy)) { // error <: error } - else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + else if (log.get(superTy)) { tryUnifyWithAny(subTy, builtinTypes->errorType); failure = true; } - else if (FFlag::LuauTransitiveSubtyping && log.get(subTy)) + else if (log.get(subTy)) { tryUnifyWithAny(superTy, builtinTypes->errorType); failure = true; } - else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + else if (log.get(superTy)) { // At this point, all the supertypes of `error` have been handled, // and if `error unknownType); } - else if (FFlag::LuauTransitiveSubtyping && log.get(superTy)) + else if (log.get(superTy)) { tryUnifyWithAny(subTy, builtinTypes->unknownType); } @@ -765,10 +771,10 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; - else if (FFlag::LuauTransitiveSubtyping ? innerState.failure : !innerState.errors.empty()) + else if (innerState.failure) { // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. - if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + if (innerState.errors.empty()) logs.push_back(std::move(innerState.log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' else if (!firstFailedOption && !isNil(type)) @@ -827,7 +833,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ { if (firstFailedOption) reportError(location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption, mismatchContext()}); - else if (!FFlag::LuauTransitiveSubtyping || !errorsSuppressed) + else if (!errorsSuppressed) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); failure = true; } @@ -874,7 +880,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } } - if (FFlag::LuauTransitiveSubtyping && !foundHeuristic) + if (!foundHeuristic) { for (size_t i = 0; i < uv->options.size(); ++i) { @@ -914,7 +920,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp innerState.normalize = false; innerState.tryUnify_(subTy, type, isFunctionCall); - if (FFlag::LuauTransitiveSubtyping ? !innerState.failure : innerState.errors.empty()) + if (!innerState.failure) { found = true; if (useNewSolver) @@ -925,7 +931,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp break; } } - else if (FFlag::LuauTransitiveSubtyping && innerState.errors.empty()) + else if (innerState.errors.empty()) { errorsSuppressed = true; } @@ -949,21 +955,38 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { reportError(*unificationTooComplex); } - else if (FFlag::LuauTransitiveSubtyping && !found && normalize) + else if (!found && normalize) { // It is possible that T <: A | B even though T normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); Unifier innerState = makeChildUnifier(); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( - subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + + if (FFlag::LuauFixNormalizeCaching) + { + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState.tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + { + const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); + const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + innerState.tryUnifyNormalizedTypes( + subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } + if (!innerState.failure) log.concat(std::move(innerState.log)); else if (errorsSuppressed || innerState.errors.empty()) @@ -976,18 +999,32 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // It is possible that T <: A | B even though T normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - reportError(location, NormalizationTooComplex{}); - else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + if (FFlag::LuauFixNormalizeCaching) + { + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } else - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + { + const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); + const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); + if (!subNorm || !superNorm) + reportError(location, NormalizationTooComplex{}); + else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption); + else + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + } } else if (!found) { - if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) + if (errorsSuppressed) failure = true; else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) reportError( @@ -1086,12 +1123,24 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + if (FFlag::LuauFixNormalizeCaching) + { + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(location, NormalizationTooComplex{}); + } else - reportError(location, NormalizationTooComplex{}); + { + const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); + const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(location, NormalizationTooComplex{}); + } return; } @@ -1113,7 +1162,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* { found = true; errorsSuppressed = innerState.failure; - if (useNewSolver || (FFlag::LuauTransitiveSubtyping && innerState.failure)) + if (useNewSolver || innerState.failure) logs.push_back(std::move(innerState.log)); else { @@ -1130,7 +1179,7 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* if (useNewSolver) log.concat(combineLogsIntoIntersection(std::move(logs))); - else if (FFlag::LuauTransitiveSubtyping && errorsSuppressed) + else if (errorsSuppressed) log.concat(std::move(logs.front())); if (unificationTooComplex) @@ -1140,12 +1189,25 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* // It is possible that A & B <: T even though A normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + + if (FFlag::LuauFixNormalizeCaching) + { + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(location, NormalizationTooComplex{}); + } else - reportError(location, NormalizationTooComplex{}); + { + const NormalizedType* subNorm = normalizer->DEPRECATED_normalize(subTy); + const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); + if (subNorm && superNorm) + tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); + else + reportError(location, NormalizationTooComplex{}); + } } else if (!found) { @@ -1158,31 +1220,25 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* void Unifier::tryUnifyNormalizedTypes( TypeId subTy, TypeId superTy, const NormalizedType& subNorm, const NormalizedType& superNorm, std::string reason, std::optional error) { - if (!FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) - return; - else if (get(superNorm.tops)) + if (get(superNorm.tops)) return; else if (get(subNorm.tops)) { failure = true; return; } - else if (!FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) - return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.errors)) if (!get(superNorm.errors)) { failure = true; - if (!FFlag::LuauTransitiveSubtyping) - reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); return; } - if (FFlag::LuauTransitiveSubtyping && get(superNorm.tops)) + if (get(superNorm.tops)) return; - if (FFlag::LuauTransitiveSubtyping && get(subNorm.tops)) + if (get(subNorm.tops)) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); if (get(subNorm.booleans)) @@ -2654,16 +2710,32 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); - const NormalizedType* subNorm = normalizer->normalize(subTy); - const NormalizedType* superNorm = normalizer->normalize(superTy); - if (!subNorm || !superNorm) - return reportError(location, NormalizationTooComplex{}); + if (FFlag::LuauFixNormalizeCaching) + { + std::shared_ptr subNorm = normalizer->normalize(subTy); + std::shared_ptr superNorm = normalizer->normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); - // T DEPRECATED_normalize(subTy); + const NormalizedType* superNorm = normalizer->DEPRECATED_normalize(superTy); + if (!subNorm || !superNorm) + return reportError(location, NormalizationTooComplex{}); + + // T & queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index d1122ae6..7eb7e181 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -41,7 +41,6 @@ LUAU_FASTFLAG(DebugLuauTimeTracing) -LUAU_FASTFLAGVARIABLE(LuauUpdatedRequireByStringSemantics, false) constexpr int MaxTraversalLimit = 50; @@ -121,129 +120,60 @@ static int finishrequire(lua_State* L) static int lua_require(lua_State* L) { - if (FFlag::LuauUpdatedRequireByStringSemantics) - { - std::string name = luaL_checkstring(L, 1); + std::string name = luaL_checkstring(L, 1); - RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name)); + RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name)); - if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached) - return finishrequire(L); - else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound) - luaL_errorL(L, "error requiring module"); - - // module needs to run in a new thread, isolated from the rest - // note: we create ML on main thread so that it doesn't inherit environment of L - lua_State* GL = lua_mainthread(L); - lua_State* ML = lua_newthread(GL); - lua_xmove(GL, L, 1); - - // new thread needs to have the globals sandboxed - luaL_sandboxthread(ML); - - // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts()); - if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) - { - if (codegen) - Luau::CodeGen::compile(ML, -1); - - if (coverageActive()) - coverageTrack(ML, -1); - - int status = lua_resume(ML, L, 0); - - if (status == 0) - { - if (lua_gettop(ML) == 0) - lua_pushstring(ML, "module must return a value"); - else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1)) - lua_pushstring(ML, "module must return a table or function"); - } - else if (status == LUA_YIELD) - { - lua_pushstring(ML, "module can not yield"); - } - else if (!lua_isstring(ML, -1)) - { - lua_pushstring(ML, "unknown error while running module"); - } - } - - // there's now a return value on top of ML; L stack: _MODULES ML - lua_xmove(ML, L, 1); - lua_pushvalue(L, -1); - lua_setfield(L, -4, resolvedRequire.absolutePath.c_str()); - - // L stack: _MODULES ML result + if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached) return finishrequire(L); - } - else + else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound) + luaL_errorL(L, "error requiring module"); + + // module needs to run in a new thread, isolated from the rest + // note: we create ML on main thread so that it doesn't inherit environment of L + lua_State* GL = lua_mainthread(L); + lua_State* ML = lua_newthread(GL); + lua_xmove(GL, L, 1); + + // new thread needs to have the globals sandboxed + luaL_sandboxthread(ML); + + // now we can compile & run module on the new thread + std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts()); + if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { - std::string name = luaL_checkstring(L, 1); - std::string chunkname = "=" + name; + if (codegen) + Luau::CodeGen::compile(ML, -1); - luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + if (coverageActive()) + coverageTrack(ML, -1); - // return the module from the cache - lua_getfield(L, -1, name.c_str()); - if (!lua_isnil(L, -1)) + int status = lua_resume(ML, L, 0); + + if (status == 0) { - // L stack: _MODULES result - return finishrequire(L); + if (lua_gettop(ML) == 0) + lua_pushstring(ML, "module must return a value"); + else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1)) + lua_pushstring(ML, "module must return a table or function"); } - - lua_pop(L, 1); - - std::optional source = readFile(name + ".luau"); - if (!source) + else if (status == LUA_YIELD) { - source = readFile(name + ".lua"); // try .lua if .luau doesn't exist - if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error + lua_pushstring(ML, "module can not yield"); } - - // module needs to run in a new thread, isolated from the rest - // note: we create ML on main thread so that it doesn't inherit environment of L - lua_State* GL = lua_mainthread(L); - lua_State* ML = lua_newthread(GL); - lua_xmove(GL, L, 1); - // new thread needs to have the globals sandboxed - luaL_sandboxthread(ML); - - // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source, copts()); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) + else if (!lua_isstring(ML, -1)) { - if (codegen) - Luau::CodeGen::compile(ML, -1); - if (coverageActive()) - coverageTrack(ML, -1); - int status = lua_resume(ML, L, 0); - if (status == 0) - { - if (lua_gettop(ML) == 0) - lua_pushstring(ML, "module must return a value"); - else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1)) - lua_pushstring(ML, "module must return a table or function"); - } - else if (status == LUA_YIELD) - { - lua_pushstring(ML, "module can not yield"); - } - else if (!lua_isstring(ML, -1)) - { - lua_pushstring(ML, "unknown error while running module"); - } + lua_pushstring(ML, "unknown error while running module"); } - // there's now a return value on top of ML; L stack: _MODULES ML - lua_xmove(ML, L, 1); - lua_pushvalue(L, -1); - lua_setfield(L, -4, name.c_str()); - - // L stack: _MODULES ML result - return finishrequire(L); } + + // there's now a return value on top of ML; L stack: _MODULES ML + lua_xmove(ML, L, 1); + lua_pushvalue(L, -1); + lua_setfield(L, -4, resolvedRequire.absolutePath.c_str()); + + // L stack: _MODULES ML result + return finishrequire(L); } static int lua_collectgarbage(lua_State* L) diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 64522386..22dd000c 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -2,6 +2,8 @@ #pragma once #include +#include +#include #include #include @@ -69,14 +71,39 @@ struct CompilationStats uint32_t functionsTotal = 0; uint32_t functionsCompiled = 0; + uint32_t functionsBound = 0; }; using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize); bool isSupported(); +class SharedCodeGenContext; + +struct SharedCodeGenContextDeleter +{ + void operator()(const SharedCodeGenContext* context) const noexcept; +}; + +using UniqueSharedCodeGenContext = std::unique_ptr; + +// Creates a new SharedCodeGenContext that can be used by multiple Luau VMs +// concurrently, using either the default allocator parameters or custom +// allocator parameters. +[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(); + +[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext); + +[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext( + size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); + +// Destroys the provided SharedCodeGenContext. All Luau VMs using the +// SharedCodeGenContext must be destroyed before this function is called. +void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept; + void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext); void create(lua_State* L); +void create(lua_State* L, SharedCodeGenContext* codeGenContext); // Check if native execution is enabled [[nodiscard]] bool isNativeExecutionEnabled(lua_State* L); @@ -84,9 +111,12 @@ void create(lua_State* L); // Enable or disable native execution according to `enabled` argument void setNativeExecutionEnabled(lua_State* L, bool enabled); +using ModuleId = std::array; + // Builds target function and all inner functions CodeGenCompilationResult compile_DEPRECATED(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); CompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); diff --git a/CodeGen/include/Luau/IrVisitUseDef.h b/CodeGen/include/Luau/IrVisitUseDef.h index 09167ef3..58c88661 100644 --- a/CodeGen/include/Luau/IrVisitUseDef.h +++ b/CodeGen/include/Luau/IrVisitUseDef.h @@ -4,7 +4,7 @@ #include "Luau/Common.h" #include "Luau/IrData.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) namespace Luau { @@ -188,7 +188,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i visitor.def(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) { // This instruction doesn't always redefine Rn, Rn+1, Rn+2, so we have to mark it as implicit use visitor.useRange(vmRegOp(inst.b), 3); @@ -216,7 +216,7 @@ static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrInst& i // After optimizations with DebugLuauAbortingChecks enabled, CHECK_TAG Rn, tag, block instructions are generated case IrCmd::CHECK_TAG: - if (!FFlag::LuauCodegenRemoveDeadStores4) + if (!FFlag::LuauCodegenRemoveDeadStores5) visitor.maybeUse(inst.a); break; diff --git a/CodeGen/include/Luau/SharedCodeAllocator.h b/CodeGen/include/Luau/SharedCodeAllocator.h index 10841893..7796096a 100644 --- a/CodeGen/include/Luau/SharedCodeAllocator.h +++ b/CodeGen/include/Luau/SharedCodeAllocator.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/CodeGen.h" #include "Luau/Common.h" #include "Luau/NativeProtoExecData.h" @@ -8,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -26,7 +28,6 @@ namespace CodeGen // The module is the unit of shared ownership (i.e., it is where the reference // count is maintained). -using ModuleId = std::array; struct CodeAllocator; class NativeModule; @@ -41,7 +42,7 @@ class SharedCodeAllocator; class NativeModule { public: - NativeModule(SharedCodeAllocator* allocator, const ModuleId& moduleId, const uint8_t* moduleBaseAddress, + NativeModule(SharedCodeAllocator* allocator, const std::optional& moduleId, const uint8_t* moduleBaseAddress, std::vector nativeProtos) noexcept; NativeModule(const NativeModule&) = delete; @@ -59,6 +60,8 @@ public: size_t release() const noexcept; [[nodiscard]] size_t getRefcount() const noexcept; + [[nodiscard]] const std::optional& getModuleId() const noexcept; + // Gets the base address of the executable native code for the module. [[nodiscard]] const uint8_t* getModuleBaseAddress() const noexcept; @@ -72,7 +75,7 @@ private: mutable std::atomic refcount = 0; SharedCodeAllocator* allocator = nullptr; - ModuleId moduleId = {}; + std::optional moduleId = {}; const uint8_t* moduleBaseAddress = nullptr; std::vector nativeProtos = {}; @@ -85,7 +88,7 @@ class NativeModuleRef { public: NativeModuleRef() noexcept = default; - NativeModuleRef(NativeModule* nativeModule) noexcept; + NativeModuleRef(const NativeModule* nativeModule) noexcept; NativeModuleRef(const NativeModuleRef& other) noexcept; NativeModuleRef(NativeModuleRef&& other) noexcept; @@ -132,11 +135,14 @@ public: std::pair getOrInsertNativeModule(const ModuleId& moduleId, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize); + NativeModuleRef insertAnonymousNativeModule( + std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize); + // If a NativeModule exists for the given ModuleId and that NativeModule // is no longer referenced, the NativeModule is destroyed. This should // usually only be called by NativeModule::release() when the reference // count becomes zero - void eraseNativeModuleIfUnreferenced(const ModuleId& moduleId); + void eraseNativeModuleIfUnreferenced(const NativeModule& nativeModule); private: struct ModuleIdHash @@ -148,7 +154,9 @@ private: mutable std::mutex mutex; - std::unordered_map, ModuleIdHash, std::equal_to<>> nativeModules; + std::unordered_map, ModuleIdHash, std::equal_to<>> identifiedModules; + + std::atomic anonymousModuleCount = 0; CodeAllocator* codeAllocator = nullptr; }; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index 85d34db8..9d0522c0 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -7,8 +7,6 @@ #include #include -LUAU_FASTFLAG(LuauCodeGenOptVecA64) - namespace Luau { namespace CodeGen @@ -559,42 +557,26 @@ void AssemblyBuilderA64::fmov(RegisterA64 dst, RegisterA64 src) void AssemblyBuilderA64::fmov(RegisterA64 dst, double src) { - if (FFlag::LuauCodeGenOptVecA64) + CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::q); + + int imm = getFmovImm(src); + CODEGEN_ASSERT(imm >= 0 && imm <= 256); + + // fmov can't encode 0, but movi can; movi is otherwise not useful for fp immediates because it encodes repeating patterns + if (dst.kind == KindA64::d) { - CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::q); - - int imm = getFmovImm(src); - CODEGEN_ASSERT(imm >= 0 && imm <= 256); - - // fmov can't encode 0, but movi can; movi is otherwise not useful for fp immediates because it encodes repeating patterns - if (dst.kind == KindA64::d) - { - if (imm == 256) - placeFMOV("movi", dst, src, 0b001'0111100000'000'1110'01'00000); - else - placeFMOV("fmov", dst, src, 0b000'11110'01'1'00000000'100'00000 | (imm << 8)); - } - else - { - if (imm == 256) - placeFMOV("movi.4s", dst, src, 0b010'0111100000'000'0000'01'00000); - else - placeFMOV("fmov.4s", dst, src, 0b010'0111100000'000'1111'0'1'00000 | ((imm >> 5) << 11) | (imm & 31)); - } - } - else - { - CODEGEN_ASSERT(dst.kind == KindA64::d); - - int imm = getFmovImm(src); - CODEGEN_ASSERT(imm >= 0 && imm <= 256); - - // fmov can't encode 0, but movi can; movi is otherwise not useful for 64-bit fp immediates because it encodes repeating patterns if (imm == 256) placeFMOV("movi", dst, src, 0b001'0111100000'000'1110'01'00000); else placeFMOV("fmov", dst, src, 0b000'11110'01'1'00000000'100'00000 | (imm << 8)); } + else + { + if (imm == 256) + placeFMOV("movi.4s", dst, src, 0b010'0111100000'000'0000'01'00000); + else + placeFMOV("fmov.4s", dst, src, 0b010'0111100000'000'1111'0'1'00000 | ((imm >> 5) << 11) | (imm & 31)); + } } void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src) diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index c3c03193..9c78a784 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -15,6 +15,7 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/AssemblyBuilderX64.h" +#include "CodeGenContext.h" #include "NativeState.h" #include "CodeGenA64.h" @@ -58,7 +59,7 @@ LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockLimit, 32'768) // 32 K // Current value is based on some member variables being limited to 16 bits LUAU_FASTINTVARIABLE(CodegenHeuristicsBlockInstructionLimit, 65'536) // 64 K -LUAU_FASTFLAG(LuauCodegenHeapSizeReport) +LUAU_FASTFLAG(LuauCodegenContext) namespace Luau { @@ -87,7 +88,7 @@ struct ExtraExecData static int alignTo(int value, int align) { - CODEGEN_ASSERT(FFlag::LuauCodegenHeapSizeReport); + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); CODEGEN_ASSERT(align > 0 && (align & (align - 1)) == 0); return (value + (align - 1)) & ~(align - 1); } @@ -96,7 +97,7 @@ static int alignTo(int value, int align) // Always a multiple of 4 bytes static int calculateExecDataSize(Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenHeapSizeReport); + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); int size = proto->sizecode * sizeof(uint32_t); size = alignTo(size, 16); @@ -109,7 +110,7 @@ static int calculateExecDataSize(Proto* proto) // Even though 'execdata' is a field in Proto, we require it to support cases where it's not attached to Proto during construction ExtraExecData* getExtraExecData(Proto* proto, void* execdata) { - CODEGEN_ASSERT(FFlag::LuauCodegenHeapSizeReport); + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); int size = proto->sizecode * sizeof(uint32_t); size = alignTo(size, 16); @@ -119,61 +120,43 @@ ExtraExecData* getExtraExecData(Proto* proto, void* execdata) static OldNativeProto createOldNativeProto(Proto* proto, const IrBuilder& ir) { - if (FFlag::LuauCodegenHeapSizeReport) + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + + int execDataSize = calculateExecDataSize(proto); + CODEGEN_ASSERT(execDataSize % 4 == 0); + + uint32_t* execData = new uint32_t[execDataSize / 4]; + uint32_t instTarget = ir.function.entryLocation; + + for (int i = 0; i < proto->sizecode; i++) { - int execDataSize = calculateExecDataSize(proto); - CODEGEN_ASSERT(execDataSize % 4 == 0); + CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); - uint32_t* execData = new uint32_t[execDataSize / 4]; - uint32_t instTarget = ir.function.entryLocation; - - for (int i = 0; i < proto->sizecode; i++) - { - CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); - - execData[i] = ir.function.bcMapping[i].asmLocation - instTarget; - } - - // Set first instruction offset to 0 so that entering this function still executes any generated entry code. - execData[0] = 0; - - ExtraExecData* extra = getExtraExecData(proto, execData); - memset(extra, 0, sizeof(ExtraExecData)); - - extra->execDataSize = execDataSize; - - // entry target will be relocated when assembly is finalized - return {proto, execData, instTarget}; + execData[i] = ir.function.bcMapping[i].asmLocation - instTarget; } - else - { - int sizecode = proto->sizecode; - uint32_t* instOffsets = new uint32_t[sizecode]; - uint32_t instTarget = ir.function.entryLocation; + // Set first instruction offset to 0 so that entering this function still executes any generated entry code. + execData[0] = 0; - for (int i = 0; i < sizecode; i++) - { - CODEGEN_ASSERT(ir.function.bcMapping[i].asmLocation >= instTarget); + ExtraExecData* extra = getExtraExecData(proto, execData); + memset(extra, 0, sizeof(ExtraExecData)); - instOffsets[i] = ir.function.bcMapping[i].asmLocation - instTarget; - } + extra->execDataSize = execDataSize; - // Set first instruction offset to 0 so that entering this function still executes any generated entry code. - instOffsets[0] = 0; - - // entry target will be relocated when assembly is finalized - return {proto, instOffsets, instTarget}; - } + // entry target will be relocated when assembly is finalized + return {proto, execData, instTarget}; } static void destroyExecData(void* execdata) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + delete[] static_cast(execdata); } static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); CODEGEN_ASSERT(p->source); const char* source = getstr(p->source); @@ -190,6 +173,8 @@ template static std::optional createNativeFunction( AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + IrBuilder ir; ir.buildFunctionIr(proto); @@ -210,17 +195,23 @@ static std::optional createNativeFunction( static NativeState* getNativeState(lua_State* L) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + return static_cast(L->global->ecb.context); } static void onCloseState(lua_State* L) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + delete getNativeState(L); L->global->ecb = lua_ExecutionCallbacks(); } static void onDestroyFunction(lua_State* L, Proto* proto) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + destroyExecData(proto->execdata); proto->execdata = nullptr; proto->exectarget = 0; @@ -229,6 +220,8 @@ static void onDestroyFunction(lua_State* L, Proto* proto) static int onEnter(lua_State* L, Proto* proto) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + NativeState* data = getNativeState(L); CODEGEN_ASSERT(proto->execdata); @@ -243,6 +236,8 @@ static int onEnter(lua_State* L, Proto* proto) // used to disable native execution, unconditionally static int onEnterDisabled(lua_State* L, Proto* proto) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); + return 1; } @@ -287,7 +282,7 @@ void onDisable(lua_State* L, Proto* proto) static size_t getMemorySize(lua_State* L, Proto* proto) { - CODEGEN_ASSERT(FFlag::LuauCodegenHeapSizeReport); + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); ExtraExecData* extra = getExtraExecData(proto, proto->execdata); // While execDataSize is exactly the size of the allocation we made and hold for 'execdata' field, the code size is approximate @@ -354,8 +349,9 @@ bool isSupported() #endif } -void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) +static void create_OLD(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); CODEGEN_ASSERT(isSupported()); std::unique_ptr data = std::make_unique(allocationCallback, allocationCallbackContext); @@ -390,29 +386,68 @@ void create(lua_State* L, AllocationCallback* allocationCallback, void* allocati ecb->destroy = onDestroyFunction; ecb->enter = onEnter; ecb->disable = onDisable; + ecb->getmemorysize = getMemorySize; +} - if (FFlag::LuauCodegenHeapSizeReport) - ecb->getmemorysize = getMemorySize; +void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) +{ + if (FFlag::LuauCodegenContext) + { + create_NEW(L, allocationCallback, allocationCallbackContext); + } + else + { + create_OLD(L, allocationCallback, allocationCallbackContext); + } } void create(lua_State* L) { - create(L, nullptr, nullptr); + if (FFlag::LuauCodegenContext) + { + create_NEW(L); + } + else + { + create(L, nullptr, nullptr); + } +} + +void create(lua_State* L, SharedCodeGenContext* codeGenContext) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + create_NEW(L, codeGenContext); } [[nodiscard]] bool isNativeExecutionEnabled(lua_State* L) { - return getNativeState(L) ? (L->global->ecb.enter == onEnter) : false; + if (FFlag::LuauCodegenContext) + { + return isNativeExecutionEnabled_NEW(L); + } + else + { + return getNativeState(L) ? (L->global->ecb.enter == onEnter) : false; + } } void setNativeExecutionEnabled(lua_State* L, bool enabled) { - if (getNativeState(L)) - L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; + if (FFlag::LuauCodegenContext) + { + setNativeExecutionEnabled_NEW(L, enabled); + } + else + { + if (getNativeState(L)) + L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; + } } CodeGenCompilationResult compile_DEPRECATED(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { + CODEGEN_ASSERT(!FFlag::LuauCodegenContext); CODEGEN_ASSERT(!FFlag::LuauCodegenDetailedCompilationResult); CODEGEN_ASSERT(lua_isLfunction(L, idx)); @@ -506,39 +541,20 @@ CodeGenCompilationResult compile_DEPRECATED(lua_State* L, int idx, unsigned int return CodeGenCompilationResult::AllocationFailed; } - if (FFlag::LuauCodegenHeapSizeReport) + if (gPerfLogFn && results.size() > 0) + gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); + + for (size_t i = 0; i < results.size(); ++i) { - if (gPerfLogFn && results.size() > 0) - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); + uint32_t begin = uint32_t(results[i].exectarget); + uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); + CODEGEN_ASSERT(begin < end); - for (size_t i = 0; i < results.size(); ++i) - { - uint32_t begin = uint32_t(results[i].exectarget); - uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); - CODEGEN_ASSERT(begin < end); + if (gPerfLogFn) + logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - if (gPerfLogFn) - logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - - ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata); - extra->codeSize = end - begin; - } - } - else - { - if (gPerfLogFn && results.size() > 0) - { - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); - - for (size_t i = 0; i < results.size(); ++i) - { - uint32_t begin = uint32_t(results[i].exectarget); - uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); - CODEGEN_ASSERT(begin < end); - - logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - } - } + ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata); + extra->codeSize = end - begin; } for (const OldNativeProto& result : results) @@ -567,7 +583,7 @@ CodeGenCompilationResult compile_DEPRECATED(lua_State* L, int idx, unsigned int return codeGenCompilationResult; } -CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +static CompilationResult compile_OLD(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { CODEGEN_ASSERT(FFlag::LuauCodegenDetailedCompilationResult); @@ -667,39 +683,20 @@ CompilationResult compile(lua_State* L, int idx, unsigned int flags, Compilation return compilationResult; } - if (FFlag::LuauCodegenHeapSizeReport) + if (gPerfLogFn && results.size() > 0) + gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); + + for (size_t i = 0; i < results.size(); ++i) { - if (gPerfLogFn && results.size() > 0) - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); + uint32_t begin = uint32_t(results[i].exectarget); + uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); + CODEGEN_ASSERT(begin < end); - for (size_t i = 0; i < results.size(); ++i) - { - uint32_t begin = uint32_t(results[i].exectarget); - uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); - CODEGEN_ASSERT(begin < end); + if (gPerfLogFn) + logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - if (gPerfLogFn) - logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - - ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata); - extra->codeSize = end - begin; - } - } - else - { - if (gPerfLogFn && results.size() > 0) - { - gPerfLogFn(gPerfLogContext, uintptr_t(codeStart), uint32_t(results[0].exectarget), ""); - - for (size_t i = 0; i < results.size(); ++i) - { - uint32_t begin = uint32_t(results[i].exectarget); - uint32_t end = i + 1 < results.size() ? uint32_t(results[i + 1].exectarget) : uint32_t(build.code.size() * sizeof(build.code[0])); - CODEGEN_ASSERT(begin < end); - - logPerfFunction(results[i].p, uintptr_t(codeStart) + begin, end - begin); - } - } + ExtraExecData* extra = getExtraExecData(results[i].p, results[i].execdata); + extra->codeSize = end - begin; } for (const OldNativeProto& result : results) @@ -728,6 +725,25 @@ CompilationResult compile(lua_State* L, int idx, unsigned int flags, Compilation return compilationResult; } +CompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + if (FFlag::LuauCodegenContext) + { + return compile_NEW(L, idx, flags, stats); + } + else + { + return compile_OLD(L, idx, flags, stats); + } +} + +CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + return compile_NEW(moduleId, L, idx, flags, stats); +} + void setPerfLog(void* context, PerfLogFn logFn) { gPerfLogContext = context; diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index 26764a89..f8adf3b7 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -13,6 +13,8 @@ #include "lapi.h" +LUAU_FASTFLAGVARIABLE(LuauCodegenContext, false) + LUAU_FASTINT(LuauCodeGenBlockSize) LUAU_FASTINT(LuauCodeGenMaxTotalSize) @@ -31,6 +33,7 @@ unsigned int getCpuFeaturesA64(); static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(p->source); const char* source = getstr(p->source); @@ -46,6 +49,8 @@ static void logPerfFunction(Proto* p, uintptr_t addr, unsigned size) static void logPerfFunctions( const std::vector& moduleProtos, const uint8_t* nativeModuleBaseAddress, const std::vector& nativeProtos) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + if (gPerfLogFn == nullptr) return; @@ -75,9 +80,11 @@ static void logPerfFunctions( // StandaloneCodeContext). If Release is false, the native proto will not be // removed from the vector (for use with the SharedCodeContext). template -static size_t bindNativeProtos(const std::vector& moduleProtos, NativeProtosVector& nativeProtos) +[[nodiscard]] static uint32_t bindNativeProtos(const std::vector& moduleProtos, NativeProtosVector& nativeProtos) { - size_t protosBound = 0; + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + uint32_t protosBound = 0; auto protoIt = moduleProtos.begin(); @@ -117,6 +124,7 @@ static size_t bindNativeProtos(const std::vector& moduleProtos, NativePr BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) : codeAllocator{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(isSupported()); #if defined(_WIN32) @@ -134,6 +142,8 @@ BaseCodeGenContext::BaseCodeGenContext(size_t blockSize, size_t maxTotalSize, Al [[nodiscard]] bool BaseCodeGenContext::initHeaderFunctions() { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + #if defined(__x86_64__) || defined(_M_X64) if (!X64::initHeaderFunctions(*this)) return false; @@ -153,23 +163,28 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); } -[[nodiscard]] std::optional StandaloneCodeGenContext::tryBindExistingModule(const ModuleId&, const std::vector&) +[[nodiscard]] std::optional StandaloneCodeGenContext::tryBindExistingModule(const ModuleId&, const std::vector&) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + // The StandaloneCodeGenContext does not support sharing of native code return {}; } -[[nodiscard]] CodeGenCompilationResult StandaloneCodeGenContext::bindModule(const ModuleId&, const std::vector& moduleProtos, +[[nodiscard]] ModuleBindResult StandaloneCodeGenContext::bindModule(const std::optional&, const std::vector& moduleProtos, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; uint8_t* codeStart = nullptr; if (!codeAllocator.allocate(data, int(dataSize), code, int(codeSize), nativeData, sizeNativeData, codeStart)) { - return CodeGenCompilationResult::AllocationFailed; + return {CodeGenCompilationResult::AllocationFailed}; } // Relocate the entry offsets to their final executable addresses: @@ -182,13 +197,15 @@ StandaloneCodeGenContext::StandaloneCodeGenContext( logPerfFunctions(moduleProtos, codeStart, nativeProtos); - bindNativeProtos(moduleProtos, nativeProtos); + const uint32_t protosBound = bindNativeProtos(moduleProtos, nativeProtos); - return CodeGenCompilationResult::Success; + return {CodeGenCompilationResult::Success, protosBound}; } void StandaloneCodeGenContext::onCloseState() noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + // The StandaloneCodeGenContext is owned by the one VM that owns it, so when // that VM is destroyed, we destroy *this as well: delete this; @@ -196,6 +213,8 @@ void StandaloneCodeGenContext::onCloseState() noexcept void StandaloneCodeGenContext::onDestroyFunction(void* execdata) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + destroyNativeProtoExecData(static_cast(execdata)); } @@ -205,11 +224,14 @@ SharedCodeGenContext::SharedCodeGenContext( : BaseCodeGenContext{blockSize, maxTotalSize, allocationCallback, allocationCallbackContext} , sharedAllocator{&codeAllocator} { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); } -[[nodiscard]] std::optional SharedCodeGenContext::tryBindExistingModule( +[[nodiscard]] std::optional SharedCodeGenContext::tryBindExistingModule( const ModuleId& moduleId, const std::vector& moduleProtos) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + NativeModuleRef nativeModule = sharedAllocator.tryGetNativeModule(moduleId); if (nativeModule.empty()) { @@ -217,33 +239,47 @@ SharedCodeGenContext::SharedCodeGenContext( } // Bind the native protos and acquire an owning reference for each: - nativeModule->addRefs(bindNativeProtos(moduleProtos, nativeModule->getNativeProtos())); + const uint32_t protosBound = bindNativeProtos(moduleProtos, nativeModule->getNativeProtos()); + nativeModule->addRefs(protosBound); - return CodeGenCompilationResult::Success; + return {{CodeGenCompilationResult::Success, protosBound}}; } -[[nodiscard]] CodeGenCompilationResult SharedCodeGenContext::bindModule(const ModuleId& moduleId, const std::vector& moduleProtos, +[[nodiscard]] ModuleBindResult SharedCodeGenContext::bindModule(const std::optional& moduleId, const std::vector& moduleProtos, std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { - const std::pair insertionResult = - sharedAllocator.getOrInsertNativeModule(moduleId, std::move(nativeProtos), data, dataSize, code, codeSize); + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + const std::pair insertionResult = [&]() -> std::pair { + if (moduleId.has_value()) + { + return sharedAllocator.getOrInsertNativeModule(*moduleId, std::move(nativeProtos), data, dataSize, code, codeSize); + } + else + { + return {sharedAllocator.insertAnonymousNativeModule(std::move(nativeProtos), data, dataSize, code, codeSize), true}; + } + }(); // If we did not get a NativeModule back, allocation failed: if (insertionResult.first.empty()) - return CodeGenCompilationResult::AllocationFailed; + return {CodeGenCompilationResult::AllocationFailed}; // If we allocated a new module, log the function code ranges for perf: if (insertionResult.second) logPerfFunctions(moduleProtos, insertionResult.first->getModuleBaseAddress(), insertionResult.first->getNativeProtos()); // Bind the native protos and acquire an owning reference for each: - insertionResult.first->addRefs(bindNativeProtos(moduleProtos, insertionResult.first->getNativeProtos())); + const uint32_t protosBound = bindNativeProtos(moduleProtos, insertionResult.first->getNativeProtos()); + insertionResult.first->addRefs(protosBound); - return CodeGenCompilationResult::Success; + return {CodeGenCompilationResult::Success, protosBound}; } void SharedCodeGenContext::onCloseState() noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + // The lifetime of the SharedCodeGenContext is managed separately from the // VMs that use it. When a VM is destroyed, we don't need to do anything // here. @@ -251,17 +287,23 @@ void SharedCodeGenContext::onCloseState() noexcept void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + getNativeProtoExecDataHeader(static_cast(execdata)).nativeModule->release(); } [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext() { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return createSharedCodeGenContext(size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return createSharedCodeGenContext( size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } @@ -269,6 +311,8 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept [[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext( size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + UniqueSharedCodeGenContext codeGenContext{new SharedCodeGenContext{blockSize, maxTotalSize, nullptr, nullptr}}; if (!codeGenContext->initHeaderFunctions()) @@ -279,28 +323,38 @@ void SharedCodeGenContext::onDestroyFunction(void* execdata) noexcept void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + delete codeGenContext; } void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGenContext) const noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + destroySharedCodeGenContext(codeGenContext); } [[nodiscard]] static BaseCodeGenContext* getCodeGenContext(lua_State* L) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return static_cast(L->global->ecb.context); } static void onCloseState(lua_State* L) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + getCodeGenContext(L)->onCloseState(); L->global->ecb = lua_ExecutionCallbacks{}; } static void onDestroyFunction(lua_State* L, Proto* proto) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + getCodeGenContext(L)->onDestroyFunction(proto->execdata); proto->execdata = nullptr; proto->exectarget = 0; @@ -309,6 +363,8 @@ static void onDestroyFunction(lua_State* L, Proto* proto) noexcept static int onEnter(lua_State* L, Proto* proto) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + BaseCodeGenContext* codeGenContext = getCodeGenContext(L); CODEGEN_ASSERT(proto->execdata); @@ -322,6 +378,8 @@ static int onEnter(lua_State* L, Proto* proto) static int onEnterDisabled(lua_State* L, Proto* proto) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return 1; } @@ -330,6 +388,8 @@ void onDisable(lua_State* L, Proto* proto); static size_t getMemorySize(lua_State* L, Proto* proto) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + const NativeProtoExecDataHeader& execDataHeader = getNativeProtoExecDataHeader(static_cast(proto->execdata)); const size_t execDataSize = sizeof(NativeProtoExecDataHeader) + execDataHeader.bytecodeInstructionCount * sizeof(Instruction); @@ -342,6 +402,8 @@ static size_t getMemorySize(lua_State* L, Proto* proto) static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeGenContext) noexcept { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + lua_ExecutionCallbacks* ecb = &L->global->ecb; ecb->context = codeGenContext; @@ -354,16 +416,22 @@ static void initializeExecutionCallbacks(lua_State* L, BaseCodeGenContext* codeG void create_NEW(lua_State* L) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), nullptr, nullptr); } void create_NEW(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return create_NEW(L, size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext); } void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + std::unique_ptr codeGenContext = std::make_unique(blockSize, maxTotalSize, allocationCallback, allocationCallbackContext); @@ -375,11 +443,15 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + initializeExecutionCallbacks(L, codeGenContext); } [[nodiscard]] static NativeProtoExecDataPtr createNativeProtoExecData(Proto* proto, const IrBuilder& ir) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + NativeProtoExecDataPtr nativeExecData = createNativeProtoExecData(proto->sizecode); uint32_t instTarget = ir.function.entryLocation; @@ -407,6 +479,8 @@ template [[nodiscard]] static NativeProtoExecDataPtr createNativeFunction( AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, uint32_t& totalIrInstCount, CodeGenCompilationResult& result) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + IrBuilder ir; ir.buildFunctionIr(proto); @@ -428,8 +502,10 @@ template return createNativeProtoExecData(proto, ir); } -CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +[[nodiscard]] static CompilationResult compileInternal( + const std::optional& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); CODEGEN_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); @@ -455,12 +531,20 @@ CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, u if (protos.empty()) return CompilationResult{CodeGenCompilationResult::NothingToCompile}; - if (std::optional existingModuleBindResult = codeGenContext->tryBindExistingModule(moduleId, protos)) - return CompilationResult{*existingModuleBindResult}; - if (stats != nullptr) stats->functionsTotal = uint32_t(protos.size()); + if (moduleId.has_value()) + { + if (std::optional existingModuleBindResult = codeGenContext->tryBindExistingModule(*moduleId, protos)) + { + if (stats != nullptr) + stats->functionsBound = existingModuleBindResult->functionsBound; + + return CompilationResult{existingModuleBindResult->compilationResult}; + } + } + #if defined(__aarch64__) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); @@ -523,7 +607,7 @@ CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, u } stats->functionsCompiled += uint32_t(nativeProtos.size()); - stats->nativeCodeSizeBytes += build.code.size(); + stats->nativeCodeSizeBytes += build.code.size() * sizeof(build.code[0]); stats->nativeDataSizeBytes += build.data.size(); } @@ -533,28 +617,51 @@ CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, u uint32_t begin = uint32_t(reinterpret_cast(header.entryOffsetOrAddress)); uint32_t end = i + 1 < nativeProtos.size() ? uint32_t(uintptr_t(getNativeProtoExecDataHeader(nativeProtos[i + 1].get()).entryOffsetOrAddress)) - : uint32_t(build.code.size()); + : uint32_t(build.code.size() * sizeof(build.code[0])); CODEGEN_ASSERT(begin < end); header.nativeCodeSize = end - begin; } - const CodeGenCompilationResult bindResult = + const ModuleBindResult bindResult = codeGenContext->bindModule(moduleId, protos, std::move(nativeProtos), reinterpret_cast(build.data.data()), build.data.size(), - reinterpret_cast(build.code.data()), build.code.size()); - if (bindResult != CodeGenCompilationResult::Success) - compilationResult.result = bindResult; + reinterpret_cast(build.code.data()), build.code.size() * sizeof(build.code[0])); + + if (stats != nullptr) + stats->functionsBound = bindResult.functionsBound; + + if (bindResult.compilationResult != CodeGenCompilationResult::Success) + compilationResult.result = bindResult.compilationResult; + return compilationResult; } +CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + return compileInternal(moduleId, L, idx, flags, stats); +} + +CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) +{ + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + + return compileInternal({}, L, idx, flags, stats); +} + [[nodiscard]] bool isNativeExecutionEnabled_NEW(lua_State* L) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + return getCodeGenContext(L) != nullptr && L->global->ecb.enter == onEnter; } void setNativeExecutionEnabled_NEW(lua_State* L, bool enabled) { + CODEGEN_ASSERT(FFlag::LuauCodegenContext); + if (getCodeGenContext(L) != nullptr) L->global->ecb.enter = enabled ? onEnter : onEnterDisabled; } diff --git a/CodeGen/src/CodeGenContext.h b/CodeGen/src/CodeGenContext.h index 9e96e3ec..ca338da5 100644 --- a/CodeGen/src/CodeGenContext.h +++ b/CodeGen/src/CodeGenContext.h @@ -21,6 +21,13 @@ namespace CodeGen // multiple Luau VMs concurrently, and allows for sharing of executable native // code and related metadata. +struct ModuleBindResult +{ + CodeGenCompilationResult compilationResult = {}; + + uint32_t functionsBound = 0; +}; + class BaseCodeGenContext { public: @@ -28,10 +35,10 @@ public: [[nodiscard]] bool initHeaderFunctions(); - [[nodiscard]] virtual std::optional tryBindExistingModule( + [[nodiscard]] virtual std::optional tryBindExistingModule( const ModuleId& moduleId, const std::vector& moduleProtos) = 0; - [[nodiscard]] virtual CodeGenCompilationResult bindModule(const ModuleId& moduleId, const std::vector& moduleProtos, + [[nodiscard]] virtual ModuleBindResult bindModule(const std::optional& moduleId, const std::vector& moduleProtos, std::vector nativeExecDatas, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) = 0; virtual void onCloseState() noexcept = 0; @@ -51,10 +58,10 @@ class StandaloneCodeGenContext final : public BaseCodeGenContext public: StandaloneCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); - [[nodiscard]] virtual std::optional tryBindExistingModule( + [[nodiscard]] virtual std::optional tryBindExistingModule( const ModuleId& moduleId, const std::vector& moduleProtos) override; - [[nodiscard]] virtual CodeGenCompilationResult bindModule(const ModuleId& moduleId, const std::vector& moduleProtos, + [[nodiscard]] virtual ModuleBindResult bindModule(const std::optional& moduleId, const std::vector& moduleProtos, std::vector nativeExecDatas, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) override; virtual void onCloseState() noexcept override; @@ -68,10 +75,10 @@ class SharedCodeGenContext final : public BaseCodeGenContext public: SharedCodeGenContext(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); - [[nodiscard]] virtual std::optional tryBindExistingModule( + [[nodiscard]] virtual std::optional tryBindExistingModule( const ModuleId& moduleId, const std::vector& moduleProtos) override; - [[nodiscard]] virtual CodeGenCompilationResult bindModule(const ModuleId& moduleId, const std::vector& moduleProtos, + [[nodiscard]] virtual ModuleBindResult bindModule(const std::optional& moduleId, const std::vector& moduleProtos, std::vector nativeExecDatas, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) override; virtual void onCloseState() noexcept override; @@ -87,29 +94,6 @@ private: // implementation is removed, the _NEW suffix can be dropped from these // functions. -class SharedCodeGenContext; - -struct SharedCodeGenContextDeleter -{ - void operator()(const SharedCodeGenContext* context) const noexcept; -}; - -using UniqueSharedCodeGenContext = std::unique_ptr; - -// Creates a new SharedCodeGenContext that can be used by multiple Luau VMs -// concurrently, using either the default allocator parameters or custom -// allocator parameters. -[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(); - -[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext(AllocationCallback* allocationCallback, void* allocationCallbackContext); - -[[nodiscard]] UniqueSharedCodeGenContext createSharedCodeGenContext( - size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); - -// Destroys the provided SharedCodeGenContext. All Luau VMs using the -// SharedCodeGenContext must be destroyed before this function is called. -void destroySharedCodeGenContext(const SharedCodeGenContext* codeGenContext) noexcept; - // Initializes native code-gen on the provided Luau VM, using a VM-specific // code-gen context and either the default allocator parameters or custom // allocator parameters. @@ -123,6 +107,7 @@ void create_NEW(lua_State* L, size_t blockSize, size_t maxTotalSize, AllocationC // destroyed via lua_close. void create_NEW(lua_State* L, SharedCodeGenContext* codeGenContext); +CompilationResult compile_NEW(lua_State* L, int idx, unsigned int flags, CompilationStats* stats); CompilationResult compile_NEW(const ModuleId& moduleId, lua_State* L, int idx, unsigned int flags, CompilationStats* stats); // Returns true if native execution is currently enabled for this VM diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index e1a2b2a9..efd1034d 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -27,7 +27,7 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) namespace Luau { @@ -312,7 +312,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& } } - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) markDeadStoresInBlockChains(ir); } diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index 8b27f40d..96d22e13 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -12,7 +12,7 @@ #include "lstate.h" -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) namespace Luau { @@ -30,7 +30,7 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(luauRegValue(ra), xmm0); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) @@ -38,7 +38,7 @@ static void emitBuiltinMathFrexp(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vcvtsi2sd(xmm0, xmm0, dword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra + 1), xmm0); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -53,14 +53,14 @@ static void emitBuiltinMathModf(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(xmm1, qword[sTemporarySlot + 0]); build.vmovsd(luauRegValue(ra), xmm1); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) build.mov(luauRegTag(ra), LUA_TNUMBER); if (nresults > 1) { build.vmovsd(luauRegValue(ra + 1), xmm0); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) build.mov(luauRegTag(ra + 1), LUA_TNUMBER); } } @@ -91,7 +91,7 @@ static void emitBuiltinMathSign(IrRegAllocX64& regs, AssemblyBuilderX64& build, build.vmovsd(luauRegValue(ra), tmp0.reg); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) build.mov(luauRegTag(ra), LUA_TNUMBER); } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 060912d1..65a8544d 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -11,9 +11,7 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauCodeGenOptVecA64, false) - -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenCheckTruthyFormB) namespace Luau @@ -203,7 +201,7 @@ static bool emitBuiltin( { case LBF_MATH_FREXP: { - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) { CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); emitInvokeLibm1P(build, offsetof(NativeContext, libm_frexp), arg); @@ -237,7 +235,7 @@ static bool emitBuiltin( } case LBF_MATH_MODF: { - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) { CODEGEN_ASSERT(nparams == 1 && (nresults == 1 || nresults == 2)); emitInvokeLibm1P(build, offsetof(NativeContext, libm_modf), arg); @@ -277,7 +275,7 @@ static bool emitBuiltin( build.fcsel(d0, d1, d0, getConditionFP(IrCondition::Less)); build.str(d0, mem(rBase, res * sizeof(TValue) + offsetof(TValue, value.n))); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) { RegisterA64 temp = regs.allocTemp(KindA64::w); build.mov(temp, LUA_TNUMBER); @@ -1118,7 +1116,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::q, index); - if (FFlag::LuauCodeGenOptVecA64 && inst.a.kind == IrOpKind::Constant) + if (inst.a.kind == IrOpKind::Constant) { float value = float(doubleOp(inst.a)); uint32_t asU32; @@ -1391,7 +1389,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) Label fresh; // used when guard aborts execution or jumps to a VM exit Label& fail = getTargetLabel(inst.c, fresh); - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) { if (tagOp(inst.b) == 0) { diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index b88dca81..bec5deea 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -8,7 +8,7 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) // TODO: when nresults is less than our actual result count, we can skip computing/writing unused results @@ -48,7 +48,7 @@ static BuiltinImplResult translateBuiltinNumberToNumber( builtinCheckDouble(build, build.vmReg(arg), pcpos); build.inst(IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(1)); - if (!FFlag::LuauCodegenRemoveDeadStores4) + if (!FFlag::LuauCodegenRemoveDeadStores5) { if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); @@ -112,7 +112,7 @@ static BuiltinImplResult translateBuiltinNumberTo2Number( build.inst( IrCmd::FASTCALL, build.constUint(bfid), build.vmReg(ra), build.vmReg(arg), args, build.constInt(1), build.constInt(nresults == 1 ? 1 : 2)); - if (!FFlag::LuauCodegenRemoveDeadStores4) + if (!FFlag::LuauCodegenRemoveDeadStores5) { if (ra != arg) build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 3975e25a..3dc72610 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -119,7 +119,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) break; // These instructions read VmReg only after optimizeMemoryOperandsX64 - case IrCmd::CHECK_TAG: // TODO: remove with FFlagLuauCodegenRemoveDeadStores4 + case IrCmd::CHECK_TAG: // TODO: remove with FFlagLuauCodegenRemoveDeadStores5 case IrCmd::CHECK_TRUTHY: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index 974873da..f910a342 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -17,9 +17,10 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenLoadTVTag) LUAU_FASTFLAGVARIABLE(LuauCodegenInferNumTag, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenLoadPropCheckRegLinkInTv, false) namespace Luau { @@ -609,7 +610,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (state.tryGetTag(source) == value) { - if (FFlag::DebugLuauAbortingChecks && !FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::DebugLuauAbortingChecks && !FFlag::LuauCodegenRemoveDeadStores5) replace(function, block, index, {IrCmd::CHECK_TAG, inst.a, inst.b, build.undef()}); else kill(function, inst); @@ -738,9 +739,10 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& IrCmd activeLoadCmd = IrCmd::NOP; uint32_t activeLoadValue = kInvalidInstIdx; - if (tag != 0xff) + // If we know the tag, we can try extracting the value from a register used by LOAD_TVALUE + // To do that, we have to ensure that the register link of the source value is still valid + if (tag != 0xff && (!FFlag::LuauCodegenLoadPropCheckRegLinkInTv || state.tryGetRegLink(inst.b) != nullptr)) { - // If we know the tag, try to extract the value from a register used by LOAD_TVALUE if (IrInst* arg = function.asInstOp(inst.b); arg && arg->cmd == IrCmd::LOAD_TVALUE && arg->a.kind == IrOpKind::VmReg) { std::tie(activeLoadCmd, activeLoadValue) = state.getPreviousVersionedLoadForTag(tag, arg->a); @@ -1100,7 +1102,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::FASTCALL: { - if (FFlag::LuauCodegenRemoveDeadStores4) + if (FFlag::LuauCodegenRemoveDeadStores5) { LuauBuiltinFunction bfid = LuauBuiltinFunction(function.uintOp(inst.a)); int firstReturnReg = vmRegOp(inst.b); diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index a1a0d91d..dc187f4a 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -9,7 +9,7 @@ #include "lobject.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores4, false) +LUAU_FASTFLAGVARIABLE(LuauCodegenRemoveDeadStores5, false) LUAU_FASTFLAG(LuauCodegenLoadTVTag) // TODO: optimization can be improved by knowing which registers are live in at each VM exit @@ -19,6 +19,8 @@ namespace Luau namespace CodeGen { +constexpr uint8_t kUnknownTag = 0xff; + // Luau value structure reminder: // [ TValue ] // [ Value ][ Extra ][ Tag ] @@ -34,6 +36,9 @@ struct StoreRegInfo // This register might contain a GC object bool maybeGco = false; + + // Knowing the last stored tag can help safely remove additional unused partial stores + uint8_t knownTag = kUnknownTag; }; struct RemoveDeadStoreState @@ -66,6 +71,32 @@ struct RemoveDeadStoreState } } + void killTagAndValueStorePair(StoreRegInfo& regInfo) + { + bool tagEstablished = regInfo.tagInstIdx != ~0u || regInfo.knownTag != kUnknownTag; + + // When tag is 'nil', we don't need to remove the unused value store + bool valueEstablished = regInfo.valueInstIdx != ~0u || regInfo.knownTag == LUA_TNIL; + + // Partial stores can only be removed if the whole pair is established + if (tagEstablished && valueEstablished) + { + if (regInfo.tagInstIdx != ~0u) + { + kill(function, function.instructions[regInfo.tagInstIdx]); + regInfo.tagInstIdx = ~0u; + } + + if (regInfo.valueInstIdx != ~0u) + { + kill(function, function.instructions[regInfo.valueInstIdx]); + regInfo.valueInstIdx = ~0u; + } + + regInfo.maybeGco = false; + } + } + void killTValueStore(StoreRegInfo& regInfo) { if (regInfo.tvalueInstIdx != ~0u) @@ -86,15 +117,23 @@ struct RemoveDeadStoreState if (function.cfg.captured.regs.test(reg)) return; - killTagStore(regInfo); - killValueStore(regInfo); + killTagAndValueStorePair(regInfo); killTValueStore(regInfo); + + // Opaque register definition removes the knowledge of the actual tag value + regInfo.knownTag = kUnknownTag; } - // When a register value is being used, we forget about the last store location to not kill them + // When a register value is being used (read), we forget about the last store location to not kill them void useReg(uint8_t reg) { - info[reg] = StoreRegInfo{}; + StoreRegInfo& regInfo = info[reg]; + + // Register read doesn't clear the known tag + regInfo.tagInstIdx = ~0u; + regInfo.valueInstIdx = ~0u; + regInfo.tvalueInstIdx = ~0u; + regInfo.maybeGco = false; } // When checking control flow, such as exit to fallback blocks: @@ -104,7 +143,7 @@ struct RemoveDeadStoreState { if (op.kind == IrOpKind::VmExit) { - clear(); + readAllRegs(); } else if (op.kind == IrOpKind::Block) { @@ -120,7 +159,7 @@ struct RemoveDeadStoreState } else { - clear(); + readAllRegs(); } } else if (op.kind == IrOpKind::Undef) @@ -147,7 +186,16 @@ struct RemoveDeadStoreState bool isOut = out.regs.test(i) || (out.varargSeq && i >= out.varargStart); if (!isOut) - defReg(i); + { + StoreRegInfo& regInfo = info[i]; + + // Stores to captured registers are not removed since we don't track their uses outside of function + if (!function.cfg.captured.regs.test(i)) + { + killTagAndValueStorePair(regInfo); + killTValueStore(regInfo); + } + } } } } @@ -217,10 +265,10 @@ struct RemoveDeadStoreState void capture(int reg) {} // Full clear of the tracked information - void clear() + void readAllRegs() { for (int i = 0; i <= maxReg; i++) - info[i] = StoreRegInfo(); + useReg(i); hasGcoToClear = false; } @@ -231,8 +279,19 @@ struct RemoveDeadStoreState { for (int i = 0; i <= maxReg; i++) { - if (info[i].maybeGco) - info[i] = StoreRegInfo(); + StoreRegInfo& regInfo = info[i]; + + if (regInfo.maybeGco) + { + // If we happen to know the exact tag, it has to be a GCO, otherwise 'maybeGCO' should be false + CODEGEN_ASSERT(regInfo.knownTag == kUnknownTag || isGCO(regInfo.knownTag)); + + // Indirect register read by GC doesn't clear the known tag + regInfo.tagInstIdx = ~0u; + regInfo.valueInstIdx = ~0u; + regInfo.tvalueInstIdx = ~0u; + regInfo.maybeGco = false; + } } hasGcoToClear = false; @@ -247,6 +306,105 @@ struct RemoveDeadStoreState bool hasGcoToClear = false; }; +static bool tryReplaceTagWithFullStore(RemoveDeadStoreState& state, IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIndex, + IrOp targetOp, IrOp tagOp, StoreRegInfo& regInfo) +{ + uint8_t tag = function.tagOp(tagOp); + + // If the tag+value pair is established, we can mark both as dead and use a single split TValue store + if (regInfo.tagInstIdx != ~0u && (regInfo.valueInstIdx != ~0u || regInfo.knownTag == LUA_TNIL)) + { + // If the 'nil' is stored, we keep 'STORE_TAG Rn, tnil' as it writes the 'full' TValue + // If a 'nil' tag is being replaced by something else, we also keep 'STORE_TAG Rn, tag', expecting a value store to follow + // And value store has to follow, as the pre-DSO code would not allow GC to observe an incomplete stack variable + if (tag != LUA_TNIL && regInfo.valueInstIdx != ~0u) + { + IrOp prevValueOp = function.instructions[regInfo.valueInstIdx].b; + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + } + + state.killTagStore(regInfo); + state.killValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + regInfo.maybeGco = isGCO(tag); + regInfo.knownTag = tag; + state.hasGcoToClear |= regInfo.maybeGco; + return true; + } + + // We can also replace a dead split TValue store with a new one, while keeping the value the same + if (regInfo.tvalueInstIdx != ~0u) + { + IrInst& prev = function.instructions[regInfo.tvalueInstIdx]; + + if (prev.cmd == IrCmd::STORE_SPLIT_TVALUE) + { + CODEGEN_ASSERT(prev.d.kind == IrOpKind::None); + + // If the 'nil' is stored, we keep 'STORE_TAG Rn, tnil' as it writes the 'full' TValue + if (tag != LUA_TNIL) + { + IrOp prevValueOp = prev.c; + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + } + + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + regInfo.maybeGco = isGCO(tag); + regInfo.knownTag = tag; + state.hasGcoToClear |= regInfo.maybeGco; + return true; + } + } + + return false; +} + +static bool tryReplaceValueWithFullStore(RemoveDeadStoreState& state, IrBuilder& build, IrFunction& function, IrBlock& block, uint32_t instIndex, + IrOp targetOp, IrOp valueOp, StoreRegInfo& regInfo) +{ + // If the tag+value pair is established, we can mark both as dead and use a single split TValue store + if (regInfo.tagInstIdx != ~0u && regInfo.valueInstIdx != ~0u) + { + IrOp prevTagOp = function.instructions[regInfo.tagInstIdx].b; + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, prevTagOp, valueOp}); + + state.killTagStore(regInfo); + state.killValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + + // We can also replace a dead split TValue store with a new one, while keeping the value the same + if (regInfo.tvalueInstIdx != ~0u) + { + IrInst& prev = function.instructions[regInfo.tvalueInstIdx]; + + if (prev.cmd == IrCmd::STORE_SPLIT_TVALUE) + { + IrOp prevTagOp = prev.b; + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + CODEGEN_ASSERT(prev.d.kind == IrOpKind::None); + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, prevTagOp, valueOp}); + + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + } + + return false; +} + static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index) { switch (inst.cmd) @@ -261,18 +419,14 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, StoreRegInfo& regInfo = state.info[reg]; - state.killTagStore(regInfo); + if (tryReplaceTagWithFullStore(state, build, function, block, index, inst.a, inst.b, regInfo)) + break; uint8_t tag = function.tagOp(inst.b); - // Storing 'nil' TValue doesn't update the value part because we don't care about that part of 'nil' - // This however prevents us from removing unused value store elimination and has an impact on GC - // To solve this issues, we invalidate the value part of a 'nil' store as well - if (tag == LUA_TNIL) - state.killValueStore(regInfo); - regInfo.tagInstIdx = index; regInfo.maybeGco = isGCO(tag); + regInfo.knownTag = tag; state.hasGcoToClear |= regInfo.maybeGco; } break; @@ -293,7 +447,16 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, StoreRegInfo& regInfo = state.info[reg]; - state.killValueStore(regInfo); + if (tryReplaceValueWithFullStore(state, build, function, block, index, inst.a, inst.b, regInfo)) + { + regInfo.maybeGco = true; + state.hasGcoToClear |= true; + break; + } + + // Partial value store can be removed by a new one if the tag is known + if (regInfo.knownTag != kUnknownTag) + state.killValueStore(regInfo); regInfo.valueInstIdx = index; regInfo.maybeGco = true; @@ -302,7 +465,6 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, break; case IrCmd::STORE_DOUBLE: case IrCmd::STORE_INT: - case IrCmd::STORE_VECTOR: if (inst.a.kind == IrOpKind::VmReg) { int reg = vmRegOp(inst.a); @@ -312,9 +474,22 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, StoreRegInfo& regInfo = state.info[reg]; - state.killValueStore(regInfo); + if (tryReplaceValueWithFullStore(state, build, function, block, index, inst.a, inst.b, regInfo)) + break; + + // Partial value store can be removed by a new one if the tag is known + if (regInfo.knownTag != kUnknownTag) + state.killValueStore(regInfo); regInfo.valueInstIdx = index; + regInfo.maybeGco = false; + } + break; + case IrCmd::STORE_VECTOR: + // Partial vector value store cannot be combined into a STORE_SPLIT_TVALUE, so we skip dead store optimization for it + if (inst.a.kind == IrOpKind::VmReg) + { + state.useReg(vmRegOp(inst.a)); } break; case IrCmd::STORE_TVALUE: @@ -327,13 +502,15 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, StoreRegInfo& regInfo = state.info[reg]; - state.killTagStore(regInfo); - state.killValueStore(regInfo); + state.killTagAndValueStorePair(regInfo); state.killTValueStore(regInfo); regInfo.tvalueInstIdx = index; regInfo.maybeGco = true; + // We do not use tag inference from the source instruction here as it doesn't provide useful opportunities for dead store removal + regInfo.knownTag = kUnknownTag; + // If the argument is a vector, it's not a GC object // Note that for known boolean/number/GCO, we already optimize into STORE_SPLIT_TVALUE form // TODO (CLI-101027): similar code is used in constant propagation optimization and should be shared in utilities @@ -359,12 +536,12 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, StoreRegInfo& regInfo = state.info[reg]; - state.killTagStore(regInfo); - state.killValueStore(regInfo); + state.killTagAndValueStorePair(regInfo); state.killTValueStore(regInfo); regInfo.tvalueInstIdx = index; regInfo.maybeGco = isGCO(function.tagOp(inst.b)); + regInfo.knownTag = function.tagOp(inst.b); state.hasGcoToClear |= regInfo.maybeGco; } break; @@ -372,6 +549,16 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, // Guard checks can jump to a block which might be using some or all the values we stored case IrCmd::CHECK_TAG: state.checkLiveIns(inst.c); + + // Tag guard establishes the tag value of the register in the current block + if (IrInst* load = function.asInstOp(inst.a); load && load->cmd == IrCmd::LOAD_TAG && load->a.kind == IrOpKind::VmReg) + { + int reg = vmRegOp(load->a); + + StoreRegInfo& regInfo = state.info[reg]; + + regInfo.knownTag = function.tagOp(inst.b); + } break; case IrCmd::TRY_NUM_TO_INDEX: state.checkLiveIns(inst.b); diff --git a/CodeGen/src/SharedCodeAllocator.cpp b/CodeGen/src/SharedCodeAllocator.cpp index 5ca5d8c5..c659a9b2 100644 --- a/CodeGen/src/SharedCodeAllocator.cpp +++ b/CodeGen/src/SharedCodeAllocator.cpp @@ -2,6 +2,7 @@ #include "Luau/SharedCodeAllocator.h" #include "Luau/CodeAllocator.h" +#include "Luau/CodeGenCommon.h" #include #include @@ -39,15 +40,15 @@ struct NativeProtoBytecodeIdLess } }; -NativeModule::NativeModule(SharedCodeAllocator* allocator, const ModuleId& moduleId, const uint8_t* moduleBaseAddress, +NativeModule::NativeModule(SharedCodeAllocator* allocator, const std::optional& moduleId, const uint8_t* moduleBaseAddress, std::vector nativeProtos) noexcept : allocator{allocator} , moduleId{moduleId} , moduleBaseAddress{moduleBaseAddress} , nativeProtos{std::move(nativeProtos)} { - LUAU_ASSERT(allocator != nullptr); - LUAU_ASSERT(moduleBaseAddress != nullptr); + CODEGEN_ASSERT(allocator != nullptr); + CODEGEN_ASSERT(moduleBaseAddress != nullptr); // Bind all of the NativeProtos to this module: for (const NativeProtoExecDataPtr& nativeProto : this->nativeProtos) @@ -60,12 +61,13 @@ NativeModule::NativeModule(SharedCodeAllocator* allocator, const ModuleId& modul std::sort(this->nativeProtos.begin(), this->nativeProtos.end(), NativeProtoBytecodeIdLess{}); // We should not have two NativeProtos for the same bytecode id: - LUAU_ASSERT(std::adjacent_find(this->nativeProtos.begin(), this->nativeProtos.end(), NativeProtoBytecodeIdEqual{}) == this->nativeProtos.end()); + CODEGEN_ASSERT( + std::adjacent_find(this->nativeProtos.begin(), this->nativeProtos.end(), NativeProtoBytecodeIdEqual{}) == this->nativeProtos.end()); } NativeModule::~NativeModule() noexcept { - LUAU_ASSERT(refcount == 0); + CODEGEN_ASSERT(refcount == 0); } size_t NativeModule::addRef() const noexcept @@ -84,7 +86,7 @@ size_t NativeModule::release() const noexcept if (newRefcount != 0) return newRefcount; - allocator->eraseNativeModuleIfUnreferenced(moduleId); + allocator->eraseNativeModuleIfUnreferenced(*this); // NOTE: *this may have been destroyed by the prior call, and must not be // accessed after this point. @@ -96,6 +98,11 @@ size_t NativeModule::release() const noexcept return refcount; } +[[nodiscard]] const std::optional& NativeModule::getModuleId() const noexcept +{ + return moduleId; +} + [[nodiscard]] const uint8_t* NativeModule::getModuleBaseAddress() const noexcept { return moduleBaseAddress; @@ -107,7 +114,7 @@ size_t NativeModule::release() const noexcept if (range.first == range.second) return nullptr; - LUAU_ASSERT(std::next(range.first) == range.second); + CODEGEN_ASSERT(std::next(range.first) == range.second); return range.first->get(); } @@ -118,7 +125,7 @@ size_t NativeModule::release() const noexcept } -NativeModuleRef::NativeModuleRef(NativeModule* nativeModule) noexcept +NativeModuleRef::NativeModuleRef(const NativeModule* nativeModule) noexcept : nativeModule{nativeModule} { if (nativeModule != nullptr) @@ -198,7 +205,8 @@ SharedCodeAllocator::~SharedCodeAllocator() noexcept { // The allocator should not be destroyed until all outstanding references // have been released and all allocated modules have been destroyed. - LUAU_ASSERT(nativeModules.empty()); + CODEGEN_ASSERT(identifiedModules.empty()); + CODEGEN_ASSERT(anonymousModuleCount == 0); } [[nodiscard]] NativeModuleRef SharedCodeAllocator::tryGetNativeModule(const ModuleId& moduleId) const noexcept @@ -224,33 +232,59 @@ std::pair SharedCodeAllocator::getOrInsertNativeModule(co return {}; } - std::unique_ptr& nativeModule = nativeModules[moduleId]; + std::unique_ptr& nativeModule = identifiedModules[moduleId]; nativeModule = std::make_unique(this, moduleId, codeStart, std::move(nativeProtos)); return {NativeModuleRef{nativeModule.get()}, true}; } -void SharedCodeAllocator::eraseNativeModuleIfUnreferenced(const ModuleId& moduleId) +NativeModuleRef SharedCodeAllocator::insertAnonymousNativeModule( + std::vector nativeProtos, const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize) { std::unique_lock lock{mutex}; - const auto it = nativeModules.find(moduleId); - if (it == nativeModules.end()) - return; + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* codeStart = nullptr; + if (!codeAllocator->allocate(data, int(dataSize), code, int(codeSize), nativeData, sizeNativeData, codeStart)) + { + return {}; + } + + NativeModuleRef nativeModuleRef{new NativeModule{this, std::nullopt, codeStart, std::move(nativeProtos)}}; + ++anonymousModuleCount; + + return nativeModuleRef; +} + +void SharedCodeAllocator::eraseNativeModuleIfUnreferenced(const NativeModule& nativeModule) +{ + std::unique_lock lock{mutex}; // It is possible that someone acquired a reference to the module between // the time that we called this function and the time that we acquired the // lock. If so, that's okay. - if (it->second->getRefcount() != 0) + if (nativeModule.getRefcount() != 0) return; - nativeModules.erase(it); + if (const std::optional& moduleId = nativeModule.getModuleId()) + { + const auto it = identifiedModules.find(*moduleId); + CODEGEN_ASSERT(it != identifiedModules.end()); + + identifiedModules.erase(it); + } + else + { + CODEGEN_ASSERT(anonymousModuleCount.fetch_sub(1) != 0); + delete &nativeModule; + } } [[nodiscard]] NativeModuleRef SharedCodeAllocator::tryGetNativeModuleWithLockHeld(const ModuleId& moduleId) const noexcept { - const auto it = nativeModules.find(moduleId); - if (it == nativeModules.end()) + const auto it = identifiedModules.find(moduleId); + if (it == identifiedModules.end()) return NativeModuleRef{}; return NativeModuleRef{it->second.get()}; diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 15fb9716..c3163265 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -14,7 +14,6 @@ inline bool isFlagExperimental(const char* flag) "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative - "LuauUpdatedRequireByStringSemantics", // requires some small fixes to fully implement some proposed changes // makes sure we always have at least one entry nullptr, }; diff --git a/Sources.cmake b/Sources.cmake index 90d053d5..6adbf283 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -422,6 +422,7 @@ if(TARGET Luau.UnitTest) tests/Fixture.h tests/Frontend.test.cpp tests/InsertionOrderedMap.test.cpp + tests/Instantiation2.test.cpp tests/IostreamOptional.h tests/IrBuilder.test.cpp tests/IrCallWrapperX64.test.cpp @@ -443,7 +444,6 @@ if(TARGET Luau.UnitTest) tests/ScopedFlags.h tests/Simplify.test.cpp tests/Set.test.cpp - tests/SharedCodeAllocator.test.cpp tests/StringUtils.test.cpp tests/Subtyping.test.cpp tests/Symbol.test.cpp @@ -496,6 +496,7 @@ if(TARGET Luau.Conformance) tests/RegisterCallbacks.cpp tests/Conformance.test.cpp tests/IrLowering.test.cpp + tests/SharedCodeAllocator.test.cpp tests/main.cpp) endif() diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index 77c1befc..dfc61e4d 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauDebugInfoDupArgLeftovers, false) - static lua_State* getthread(lua_State* L, int* arg) { if (lua_isthread(L, 1)) @@ -36,8 +34,7 @@ static int db_info(lua_State* L) // for 'f' option, we reserve one slot and we also record the stack top lua_rawcheckstack(L1, 1); - if (DFFlag::LuauDebugInfoDupArgLeftovers) - l1top = lua_gettop(L1); + l1top = lua_gettop(L1); } int level; @@ -70,7 +67,7 @@ static int db_info(lua_State* L) if (occurs[*it - 'a']) { // restore stack state of another thread as 'f' option might not have been visited yet - if (DFFlag::LuauDebugInfoDupArgLeftovers && L != L1) + if (L != L1) lua_settop(L1, l1top); luaL_argerror(L, arg + 2, "duplicate option"); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index bac7ccb8..2610ddf8 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -14,8 +14,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCodegenHeapSizeReport, false) - static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -826,15 +824,12 @@ static void enumproto(EnumContext* ctx, Proto* p) size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; - if (FFlag::LuauCodegenHeapSizeReport) + if (p->execdata && ctx->L->global->ecb.getmemorysize) { - if (p->execdata && ctx->L->global->ecb.getmemorysize) - { - size_t nativesize = ctx->L->global->ecb.getmemorysize(ctx->L, p); + size_t nativesize = ctx->L->global->ecb.getmemorysize(ctx->L, p); - ctx->node(ctx->context, p->execdata, uint8_t(LUA_TNONE), p->memcat, nativesize, NULL); - ctx->edge(ctx->context, enumtopointer(obj2gco(p)), p->execdata, "[native]"); - } + ctx->node(ctx->context, p->execdata, uint8_t(LUA_TNONE), p->memcat, nativesize, NULL); + ctx->edge(ctx->context, enumtopointer(obj2gco(p)), p->execdata, "[native]"); } enumnode(ctx, obj2gco(p), size, p->source ? getstr(p->source) : NULL); diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 60de435d..68fa48f9 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAG(LuauCodeGenOptVecA64) - using namespace Luau::CodeGen; using namespace Luau::CodeGen::A64; @@ -451,8 +449,6 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPCompare") TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPImm") { - ScopedFastFlag luauCodeGenOptVecA64{FFlag::LuauCodeGenOptVecA64, true}; - SINGLE_COMPARE(fmov(d0, 0), 0x2F00E400); SINGLE_COMPARE(fmov(d0, 0.125), 0x1E681000); SINGLE_COMPARE(fmov(d0, -0.125), 0x1E781000); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 1ab49c82..4bafc8fa 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -32,7 +32,6 @@ void luaC_validate(lua_State* L); LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_DYNAMIC_FASTFLAG(LuauDebugInfoDupArgLeftovers) LUAU_FASTFLAG(LuauCompileRepeatUntilSkippedLocals) LUAU_FASTFLAG(LuauCodegenInferNumTag) LUAU_FASTFLAG(LuauCodegenDetailedCompilationResult) @@ -639,8 +638,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag luauDebugInfoDupArgLeftovers{DFFlag::LuauDebugInfoDupArgLeftovers, true}; - runConformance("debug.lua"); } diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index c3a6a464..1b7fe842 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -658,4 +658,24 @@ TEST_CASE_FIXTURE(DataFlowGraphFixture, "insert_trivial_phi_nodes_inside_of_phi_ CHECK(t2phi->operands.at(0) == t1); } +TEST_CASE_FIXTURE(DataFlowGraphFixture, "dfg_function_definition_in_a_do_block") +{ + dfg(R"( + local f + do + function f() + end + end + f() + )"); + + DefId x1 = graph->getDef(query(module)->vars.data[0]); + DefId x2 = getDef(); // x = 5 + DefId x3 = getDef(); // print(x) + + CHECK(x1 != x2); + CHECK(x1 != x3); + CHECK(x2 == x3); +} + TEST_SUITE_END(); diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp index 10f38abe..677e3217 100644 --- a/tests/Error.test.cpp +++ b/tests/Error.test.cpp @@ -6,8 +6,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauStacklessTypeClone3); - TEST_SUITE_BEGIN("ErrorTests"); TEST_CASE("TypeError_code_should_return_nonzero_code") @@ -19,7 +17,7 @@ TEST_CASE("TypeError_code_should_return_nonzero_code") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_names_show_instead_of_tables") { frontend.options.retainFullTypeGraphs = false; - ScopedFastFlag sff{FFlag::LuauStacklessTypeClone3, true}; + CheckResult result = check(R"( --!strict local Account = {} diff --git a/tests/Instantiation2.test.cpp b/tests/Instantiation2.test.cpp new file mode 100644 index 00000000..ed9d7198 --- /dev/null +++ b/tests/Instantiation2.test.cpp @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Instantiation2.h" + +#include "Fixture.h" +#include "ClassFixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("Instantiation2Test"); + +TEST_CASE_FIXTURE(Fixture, "weird_cyclic_instantiation") +{ + TypeArena arena; + Scope scope(builtinTypes->anyTypePack); + + TypeId genericT = arena.addType(GenericType{"T"}); + + TypeId idTy = arena.addType(FunctionType{ + /* generics */ {genericT}, + /* genericPacks */ {}, + /* argTypes */ arena.addTypePack({genericT}), + /* retTypes */ arena.addTypePack({genericT}) + }); + + DenseHashMap genericSubstitutions{nullptr}; + DenseHashMap genericPackSubstitutions{nullptr}; + + TypeId freeTy = arena.freshType(&scope); + FreeType* ft = getMutable(freeTy); + REQUIRE(ft); + ft->lowerBound = idTy; + ft->upperBound = builtinTypes->unknownType; + + genericSubstitutions[genericT] = freeTy; + + CHECK("(T) -> T" == toString(idTy)); + + std::optional res = instantiate2(&arena, std::move(genericSubstitutions), std::move(genericPackSubstitutions), idTy); + + // Substitutions should not mutate the original type! + CHECK("(T) -> T" == toString(idTy)); + + // Weird looking because we haven't properly clipped the generic from the + // function type, but this is what we asked for. + REQUIRE(res); + CHECK("<(T) -> T>((T) -> T) -> (T) -> T" == toString(*res)); +} + +TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 00e9d312..7ec01e77 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -12,9 +12,10 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTFLAG(LuauCodegenInferNumTag) +LUAU_FASTFLAG(LuauCodegenLoadPropCheckRegLinkInTv) using namespace Luau::CodeGen; @@ -117,6 +118,7 @@ public: static const int tnumber = 3; static const int tstring = 5; static const int ttable = 6; + static const int tfunction = 7; }; TEST_SUITE_BEGIN("Optimization"); @@ -2539,7 +2541,7 @@ bb_0: ; useCount: 0 TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepInvalidation") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp block = build.block(IrBlockKind::Internal); IrOp followup = build.block(IrBlockKind::Internal); @@ -2580,7 +2582,7 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -2605,7 +2607,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "FastCallEffects2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -2964,7 +2966,7 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "ForgprepImplicitUse") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); IrOp direct = build.block(IrBlockKind::Internal); @@ -3468,25 +3470,83 @@ bb_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "TaggedValuePropagationIntoTvalueChecksRegisterVersion") +{ + ScopedFastFlag luauCodegenLoadPropCheckRegLinkInTv{FFlag::LuauCodegenLoadPropCheckRegLinkInTv, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + IrOp a1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp b1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(1)); + IrOp sum1 = build.inst(IrCmd::ADD_NUM, a1, b1); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), sum1); + build.inst(IrCmd::STORE_TAG, build.vmReg(7), build.constTag(tnumber)); + + IrOp a2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + IrOp b2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp sum2 = build.inst(IrCmd::ADD_NUM, a2, b2); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), sum2); + build.inst(IrCmd::STORE_TAG, build.vmReg(8), build.constTag(tnumber)); + + IrOp old7 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(7), build.constInt(0), build.constTag(tnumber)); + IrOp old8 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(8), build.constInt(0), build.constTag(tnumber)); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(8), old7); // Invalidate R8 + build.inst(IrCmd::STORE_TVALUE, build.vmReg(9), old8); // Old R8 cannot be substituted as it was invalidated + + build.inst(IrCmd::RETURN, build.vmReg(8), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0, R1, R2, R3 + %0 = LOAD_DOUBLE R0 + %1 = LOAD_DOUBLE R1 + %2 = ADD_NUM %0, %1 + STORE_DOUBLE R7, %2 + STORE_TAG R7, tnumber + %5 = LOAD_DOUBLE R2 + %6 = LOAD_DOUBLE R3 + %7 = ADD_NUM %5, %6 + STORE_DOUBLE R8, %7 + STORE_TAG R8, tnumber + %11 = LOAD_TVALUE R8, 0i, tnumber + STORE_SPLIT_TVALUE R8, tnumber, %2 + STORE_TVALUE R9, %11 + RETURN R8, 2i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("DeadStoreRemoval"); TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2.0)); // Should remove previous store build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.constDouble(1.0)); - build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); // Should remove previous store of different type + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tnumber)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tboolean)); // Should remove previous store of different type build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnil)); - build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); // Should remove previous store + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.constDouble(4.0)); build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.constTag(tnil)); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.constDouble(1.0)); @@ -3507,12 +3567,13 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDoubleStore") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: ; in regs: R0 - STORE_DOUBLE R1, 2 - STORE_INT R2, 4i + STORE_SPLIT_TVALUE R1, tnumber, 2 + STORE_SPLIT_TVALUE R2, tboolean, 4i STORE_TAG R3, tnumber + STORE_DOUBLE R3, 4 STORE_SPLIT_TVALUE R4, tnumber, 2 - %9 = LOAD_TVALUE R0 - STORE_TVALUE R5, %9 + %13 = LOAD_TVALUE R0 + STORE_TVALUE R5, %13 RETURN R1, 5i )"); @@ -3520,19 +3581,22 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturn") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); build.beginBlock(entry); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); - build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tboolean)); build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(4), build.constTag(tnumber), build.constDouble(2.0)); IrOp someTv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(5), someTv); + build.inst(IrCmd::STORE_TAG, build.vmReg(6), build.constTag(tnil)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); updateUseCounts(build.function); @@ -3548,9 +3612,39 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "UnusedAtReturnPartial") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1.0)); + build.inst(IrCmd::STORE_INT, build.vmReg(2), build.constInt(4)); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Partial stores cannot be removed, even if unused + // Existance of an unpaired partial store means that the other valid part is a block live in (even if not present is this test) + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; in regs: R0 + STORE_DOUBLE R1, 1 + STORE_INT R2, 4i + STORE_TAG R3, tnumber + RETURN R0, 1i + +)"); +} + TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse1") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3580,7 +3674,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3614,7 +3708,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3644,7 +3738,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "HiddenPointerUse4") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3678,7 +3772,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "PartialVsFullStoresWithRecombination") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3703,7 +3797,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3722,9 +3816,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IgnoreFastcallAdjustment") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: - STORE_TAG R1, tnumber ADJUST_STACK_TO_REG R1, 1i - STORE_DOUBLE R1, 1 + STORE_SPLIT_TVALUE R1, tnumber, 1 RETURN R1, 1i )"); @@ -3732,7 +3825,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "JumpImplicitLiveOut") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); IrOp next = build.block(IrBlockKind::Internal); @@ -3769,7 +3862,7 @@ bb_1: TEST_CASE_FIXTURE(IrBuilderFixture, "KeepCapturedRegisterStores") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3808,7 +3901,7 @@ bb_0: TEST_CASE_FIXTURE(IrBuilderFixture, "StoreCannotBeReplacedWithCheck") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; ScopedFastFlag debugLuauAbortingChecks{FFlag::DebugLuauAbortingChecks, true}; IrOp block = build.block(IrBlockKind::Internal); @@ -3875,9 +3968,363 @@ bb_2: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32))); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32))); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Even though R1 is not live in of the fallback, stack state cannot be left in a partial store state + // Either tag+pointer store should both remain before the guard, or they both have to be made after + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_SAFE_ENV bb_fallback_1 + %4 = NEW_TABLE 16u, 32u + STORE_SPLIT_TVALUE R1, ttable, %4 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_GC + STORE_SPLIT_TVALUE R1, tnumber, 1 + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks2") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); // Tag store unpaired to a visible value store + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2))); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // If table tag store at the start is removed, GC assists in the fallback can observe value with a wrong tag + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0, R2 +; out regs: R0, R1 + STORE_TAG R1, tnumber + CHECK_SAFE_ENV bb_fallback_1 + %2 = LOAD_TVALUE R2 + STORE_TVALUE R1, %2 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_GC + STORE_SPLIT_TVALUE R1, tnumber, 1 + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FullStoreHasToBeObservableFromFallbacks3") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(1)), build.constTag(tfunction), fallback); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::LOAD_POINTER, build.vmConst(10))); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(1)); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1.0)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + // Tag check establishes that at that point, the tag of the value IS a function (as an exit here has to be with well-formed stack) + // Later additional function pointer store can be removed, even if it observable from the GC in the fallback + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_fallback_1, bb_2 +; in regs: R0, R1 +; out regs: R0, R1 + %0 = LOAD_TAG R1 + CHECK_TAG %0, tfunction, bb_fallback_1 + CHECK_SAFE_ENV bb_fallback_1 + STORE_DOUBLE R1, 1 + STORE_TAG R1, tnumber + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0, bb_0 +; successors: bb_2 +; in regs: R0 +; out regs: R0, R1 + CHECK_GC + STORE_SPLIT_TVALUE R1, tnumber, 1 + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1)); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); // While R1 has to be observed in full by the fallback + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2)); // This partial store is safe to remove because number tag is established + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(3)); // And so is this + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(4)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // If table tag store at the start is removed, GC assists in the fallback can observe value with a wrong tag + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0, R1 + STORE_SPLIT_TVALUE R1, tnumber, 1 + CHECK_SAFE_ENV bb_fallback_1 + STORE_DOUBLE R1, 4 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1 +; out regs: R0, R1 + CHECK_GC + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "SafePartialValueStoresWithPreservedTag2") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp last = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(1)); + build.inst(IrCmd::CHECK_SAFE_ENV, fallback); // While R1 has to be observed in full by the fallback + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.constDouble(2)); // This partial store is safe to remove because tag is established + build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(1), build.constTag(tnumber), build.constDouble(4)); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(fallback); + build.inst(IrCmd::CHECK_GC); + build.inst(IrCmd::JUMP, last); + + build.beginBlock(last); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // If table tag store at the start is removed, GC assists in the fallback can observe value with a wrong tag + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0, R1 + STORE_SPLIT_TVALUE R1, tnumber, 1 + CHECK_SAFE_ENV bb_fallback_1 + STORE_SPLIT_TVALUE R1, tnumber, 4 + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1 +; out regs: R0, R1 + CHECK_GC + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0, R1 + RETURN R0, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DoNotReturnWithPartialStores") +{ + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp success = build.block(IrBlockKind::Internal); + IrOp fail = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), build.inst(IrCmd::NEW_TABLE, build.constUint(0), build.constUint(0))); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(ttable)); + IrOp toUint = build.inst(IrCmd::NUM_TO_UINT, build.constDouble(-1)); + IrOp bitAnd = build.inst(IrCmd::BITAND_UINT, toUint, build.constInt(4)); + build.inst(IrCmd::JUMP_CMP_INT, bitAnd, build.constInt(0), build.cond(IrCondition::Equal), success, fail); + + build.beginBlock(success); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(0)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(fail); + build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tboolean)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + constPropInBlockChains(build, true); + markDeadStoresInBlockChains(build); + + // Even though R1 is not live out at return, we stored table tag followed by an integer value + // Boolean tag store has to remain, even if unused, because all stack slots are visible to GC + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0 +; out regs: R0 + %0 = NEW_TABLE 0u, 0u + STORE_POINTER R1, %0 + STORE_TAG R1, ttable + %3 = NUM_TO_UINT -1 + %4 = BITAND_UINT %3, 4i + JUMP_CMP_INT %4, 0i, eq, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0 +; out regs: R0 + STORE_INT R1, 0i + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R0 +; out regs: R0 + STORE_INT R1, 1i + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R0 + STORE_TAG R1, tboolean + RETURN R0, 1i + +)"); +} + TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; IrOp entry = build.block(IrBlockKind::Internal); @@ -3891,7 +4338,9 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") build.inst(IrCmd::STORE_POINTER, build.vmReg(0), build.inst(IrCmd::NEW_TABLE, build.constUint(4), build.constUint(8))); build.inst(IrCmd::STORE_SPLIT_TVALUE, build.vmReg(0), build.constTag(tnumber), build.constDouble(1.0)); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tstring)); + IrOp newtable = build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32)); build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(ttable)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), newtable); build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); updateUseCounts(build.function); @@ -3900,8 +4349,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PartialOverFullValue") CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( bb_0: - STORE_SPLIT_TVALUE R0, tnumber, 1 - STORE_TAG R0, ttable + %11 = NEW_TABLE 16u, 32u + STORE_SPLIT_TVALUE R0, ttable, %11 RETURN R0, 1i )"); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index f4b85a33..fffafe4d 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -12,7 +12,7 @@ #include -LUAU_FASTFLAG(LuauCodegenRemoveDeadStores4) +LUAU_FASTFLAG(LuauCodegenRemoveDeadStores5) LUAU_FASTFLAG(LuauCodegenLoadTVTag) static std::string getCodegenAssembly(const char* source) @@ -89,7 +89,7 @@ bb_bytecode_1: TEST_CASE("VectorComponentRead") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function compsum(a: vector) @@ -168,7 +168,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -202,7 +202,7 @@ bb_bytecode_1: TEST_CASE("VectorSubMulDiv2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector) @@ -232,7 +232,7 @@ bb_bytecode_1: TEST_CASE("VectorMulDivMixed") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function vec3combo(a: vector, b: vector, c: vector, d: vector) @@ -274,7 +274,7 @@ bb_bytecode_1: TEST_CASE("ExtraMathMemoryOperands") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a: number, b: number, c: number, d: number, e: number) @@ -312,7 +312,7 @@ bb_bytecode_1: TEST_CASE("DseInitialStackState") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo() @@ -352,7 +352,7 @@ bb_5: TEST_CASE("DseInitialStackState2") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -373,7 +373,7 @@ bb_bytecode_0: TEST_CASE("DseInitialStackState3") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( local function foo(a) @@ -394,7 +394,7 @@ bb_bytecode_0: TEST_CASE("VectorConstantTag") { - ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores4, true}; + ScopedFastFlag luauCodegenRemoveDeadStores{FFlag::LuauCodegenRemoveDeadStores5, true}; ScopedFastFlag luauCodegenLoadTVTag{FFlag::LuauCodegenLoadTVTag, true}; CHECK_EQ("\n" + getCodegenAssembly(R"( diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 80d4507b..dd7538ae 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -2,8 +2,6 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/Module.h" -#include "Luau/Scope.h" -#include "Luau/RecursionCounter.h" #include "Luau/Parser.h" #include "Fixture.h" @@ -14,10 +12,8 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauStacklessTypeClone3) LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTINT(LuauTypeCloneIterationLimit); -LUAU_FASTINT(LuauTypeCloneRecursionLimit); TEST_SUITE_BEGIN("ModuleTests"); @@ -331,47 +327,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_self_property") CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") -{ -#if defined(_DEBUG) || defined(_NOOPT) - int limit = 250; -#else - int limit = 400; -#endif - - ScopedFastFlag sff{FFlag::LuauStacklessTypeClone3, false}; - ScopedFastInt luauTypeCloneRecursionLimit{FInt::LuauTypeCloneRecursionLimit, limit}; - - TypeArena src; - - TypeId table = src.addType(TableType{}); - TypeId nested = table; - - for (int i = 0; i < limit + 100; i++) - { - TableType* ttv = getMutable(nested); - - ttv->props["a"].setType(src.addType(TableType{})); - nested = ttv->props["a"].type(); - } - - TypeArena dest; - CloneState cloneState{builtinTypes}; - - CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); -} - TEST_CASE_FIXTURE(Fixture, "clone_iteration_limit") { - ScopedFastFlag sff{FFlag::LuauStacklessTypeClone3, true}; - ScopedFastInt sfi{FInt::LuauTypeCloneIterationLimit, 500}; + ScopedFastInt sfi{FInt::LuauTypeCloneIterationLimit, 2000}; TypeArena src; TypeId table = src.addType(TableType{}); TypeId nested = table; - for (int i = 0; i < 2500; i++) + int nesting = 2500; + for (int i = 0; i < nesting; i++) { TableType* ttv = getMutable(nested); ttv->props["a"].setType(src.addType(TableType{})); @@ -533,8 +499,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_table_bound_to_table_bound_to_table") TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_type_to_a_persistent_type") { - ScopedFastFlag sff{FFlag::LuauStacklessTypeClone3, true}; - TypeArena arena; TypeId boundTo = arena.addType(BoundType{builtinTypes->numberType}); @@ -549,8 +513,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_type_to_a_persistent_type") TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_typepack_to_a_persistent_typepack") { - ScopedFastFlag sff{FFlag::LuauStacklessTypeClone3, true}; - TypeArena arena; TypePackId boundTo = arena.addTypePack(BoundTypePack{builtinTypes->neverTypePack}); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 97d2dafe..61fa6391 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -11,7 +11,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAG(LuauTransitiveSubtyping); +LUAU_FASTFLAG(LuauFixNormalizeCaching); using namespace Luau; @@ -29,21 +29,6 @@ struct IsSubtypeFixture : Fixture return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); } - - bool isConsistentSubtype(TypeId a, TypeId b) - { - // any test that is testing isConsistentSubtype is testing the old solver exclusively! - ScopedFastFlag noDcr{FFlag::DebugLuauDeferredConstraintResolution, false}; - - Location location; - ModulePtr module = getMainModule(); - REQUIRE(module); - - if (!module->hasModuleScope()) - FAIL("isSubtype: module scope data is not available"); - - return ::Luau::isConsistentSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); - } }; } // namespace @@ -90,22 +75,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "functions") CHECK(isSubtype(a, d)); } -TEST_CASE_FIXTURE(IsSubtypeFixture, "functions_and_any") -{ - check(R"( - function a(n: number) return "string" end - function b(q: any) return 5 :: any end - )"); - - TypeId a = requireType("a"); - TypeId b = requireType("b"); - - // any makes things work even when it makes no sense. - - CHECK(isConsistentSubtype(b, a)); - CHECK(isConsistentSubtype(a, b)); -} - TEST_CASE_FIXTURE(IsSubtypeFixture, "variadic_functions_with_no_head") { check(R"( @@ -182,10 +151,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_union_prop") TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") { - ScopedFastFlag sffs[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - check(R"( local a: {x: number} local b: {x: any} @@ -199,7 +164,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "table_with_any_prop") else CHECK(isSubtype(a, b)); CHECK(!isSubtype(b, a)); - CHECK(isConsistentSubtype(b, a)); } TEST_CASE_FIXTURE(IsSubtypeFixture, "intersection") @@ -243,10 +207,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "union_and_intersection") TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") { - ScopedFastFlag sffs[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - check(R"( local a: {x: number} local b: {x: any} @@ -264,7 +224,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "tables") else CHECK(isSubtype(a, b)); CHECK(!isSubtype(b, a)); - CHECK(isConsistentSubtype(b, a)); CHECK(!isSubtype(c, a)); CHECK(!isSubtype(a, c)); @@ -398,10 +357,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "metatable" * doctest::expected_failures{1}) TEST_CASE_FIXTURE(IsSubtypeFixture, "any_is_unknown_union_error") { - ScopedFastFlag sffs[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - check(R"( local err = 5.nope.nope -- err is now an error type local a : any @@ -418,10 +373,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "any_is_unknown_union_error") TEST_CASE_FIXTURE(IsSubtypeFixture, "any_intersect_T_is_T") { - ScopedFastFlag sffs[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - check(R"( local a : (any & string) local b : string @@ -440,10 +391,6 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "any_intersect_T_is_T") TEST_CASE_FIXTURE(IsSubtypeFixture, "error_suppression") { - ScopedFastFlag sffs[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - check(""); TypeId any = builtinTypes->anyType; @@ -453,33 +400,21 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "error_suppression") CHECK(!isSubtype(any, err)); CHECK(isSubtype(err, any)); - CHECK(isConsistentSubtype(any, err)); - CHECK(isConsistentSubtype(err, any)); CHECK(!isSubtype(any, str)); CHECK(isSubtype(str, any)); - CHECK(isConsistentSubtype(any, str)); - CHECK(isConsistentSubtype(str, any)); CHECK(!isSubtype(any, unk)); CHECK(isSubtype(unk, any)); - CHECK(isConsistentSubtype(any, unk)); - CHECK(isConsistentSubtype(unk, any)); CHECK(!isSubtype(err, str)); CHECK(!isSubtype(str, err)); - CHECK(isConsistentSubtype(err, str)); - CHECK(isConsistentSubtype(str, err)); CHECK(!isSubtype(err, unk)); CHECK(!isSubtype(unk, err)); - CHECK(isConsistentSubtype(err, unk)); - CHECK(isConsistentSubtype(unk, err)); CHECK(isSubtype(str, unk)); CHECK(!isSubtype(unk, str)); - CHECK(isConsistentSubtype(str, unk)); - CHECK(!isConsistentSubtype(unk, str)); } TEST_SUITE_END(); @@ -490,13 +425,15 @@ struct NormalizeFixture : Fixture InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; Normalizer normalizer{&arena, builtinTypes, NotNull{&unifierState}}; + Scope globalScope{builtinTypes->anyTypePack}; + ScopedFastFlag fixNormalizeCaching{FFlag::LuauFixNormalizeCaching, true}; NormalizeFixture() { registerHiddenTypes(&frontend); } - const NormalizedType* toNormalizedType(const std::string& annotation) + std::shared_ptr toNormalizedType(const std::string& annotation) { normalizer.clearCaches(); CheckResult result = check("type _Res = " + annotation); @@ -524,7 +461,7 @@ struct NormalizeFixture : Fixture TypeId normal(const std::string& annotation) { - const NormalizedType* norm = toNormalizedType(annotation); + std::shared_ptr norm = toNormalizedType(annotation); REQUIRE(norm); return normalizer.typeFromNormal(*norm); } @@ -728,10 +665,10 @@ TEST_CASE_FIXTURE(NormalizeFixture, "trivial_intersection_inhabited") TypeId a = arena.addType(FunctionType{builtinTypes->emptyTypePack, builtinTypes->anyTypePack, std::nullopt, false}); TypeId c = arena.addType(IntersectionType{{a, a}}); - const NormalizedType* n = normalizer.normalize(c); + std::shared_ptr n = normalizer.normalize(c); REQUIRE(n); - CHECK(normalizer.isInhabited(n) == NormalizationResult::True); + CHECK(normalizer.isInhabited(n.get()) == NormalizationResult::True); } TEST_CASE_FIXTURE(NormalizeFixture, "bare_negated_boolean") @@ -841,7 +778,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "recurring_intersection") std::optional t = lookupType("B"); REQUIRE(t); - const NormalizedType* nt = normalizer.normalize(*t); + std::shared_ptr nt = normalizer.normalize(*t); REQUIRE(nt); CHECK("any" == toString(normalizer.typeFromNormal(*nt))); @@ -854,7 +791,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "cyclic_union") TypeId u = arena.addType(UnionType{{builtinTypes->numberType, t}}); asMutable(t)->ty.emplace(IntersectionType{{builtinTypes->anyType, u}}); - const NormalizedType* nt = normalizer.normalize(t); + std::shared_ptr nt = normalizer.normalize(t); REQUIRE(nt); CHECK("number" == toString(normalizer.typeFromNormal(*nt))); @@ -910,25 +847,25 @@ TEST_CASE_FIXTURE(NormalizeFixture, "normalize_blocked_types") { Type blocked{BlockedType{}}; - const NormalizedType* norm = normalizer.normalize(&blocked); + std::shared_ptr norm = normalizer.normalize(&blocked); CHECK_EQ(normalizer.typeFromNormal(*norm), &blocked); } TEST_CASE_FIXTURE(NormalizeFixture, "normalize_is_exactly_number") { - const NormalizedType* number = normalizer.normalize(builtinTypes->numberType); + std::shared_ptr number = normalizer.normalize(builtinTypes->numberType); // 1. all types for which Types::number say true for, NormalizedType::isExactlyNumber should say true as well CHECK(Luau::isNumber(builtinTypes->numberType) == number->isExactlyNumber()); // 2. isExactlyNumber should handle cases like `number & number` TypeId intersection = arena.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->numberType}}); - const NormalizedType* normIntersection = normalizer.normalize(intersection); + std::shared_ptr normIntersection = normalizer.normalize(intersection); CHECK(normIntersection->isExactlyNumber()); // 3. isExactlyNumber should reject things that are definitely not precisely numbers `number | any` TypeId yoonion = arena.addType(UnionType{{builtinTypes->anyType, builtinTypes->numberType}}); - const NormalizedType* unionIntersection = normalizer.normalize(yoonion); + std::shared_ptr unionIntersection = normalizer.normalize(yoonion); CHECK(!unionIntersection->isExactlyNumber()); } @@ -952,14 +889,34 @@ TEST_CASE_FIXTURE(NormalizeFixture, "read_only_props_2") { ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; - CHECK(R"({ x: never })" == toString(normal(R"({ x: "hello" } & { x: "world" })"), {true})); + CHECK(R"({ x: "hello" })" == toString(normal(R"({ x: "hello" } & { x: string })"), {true})); + CHECK(R"(never)" == toString(normal(R"({ x: "hello" } & { x: "world" })"), {true})); } TEST_CASE_FIXTURE(NormalizeFixture, "read_only_props_3") { ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; - CHECK("{ read x: never }" == toString(normal(R"({ read x: "hello" } & { read x: "world" })"), {true})); + CHECK(R"({ read x: "hello" })" == toString(normal(R"({ read x: "hello" } & { read x: string })"), {true})); + CHECK("never" == toString(normal(R"({ read x: "hello" } & { read x: "world" })"), {true})); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "final_types_are_cached") +{ + std::shared_ptr na1 = normalizer.normalize(builtinTypes->numberType); + std::shared_ptr na2 = normalizer.normalize(builtinTypes->numberType); + + CHECK(na1 == na2); +} + +TEST_CASE_FIXTURE(NormalizeFixture, "non_final_types_can_be_normalized_but_are_not_cached") +{ + TypeId a = arena.freshType(&globalScope); + + std::shared_ptr na1 = normalizer.normalize(a); + std::shared_ptr na2 = normalizer.normalize(a); + + CHECK(na1 != na2); } TEST_SUITE_END(); diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index 99e562a8..c4f1fb7c 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -53,8 +53,6 @@ std::optional getResourcePath() #endif #endif -LUAU_FASTFLAG(LuauUpdatedRequireByStringSemantics) - class ReplWithPathFixture { public: @@ -219,7 +217,6 @@ TEST_CASE("PathResolution") std::string prefix = "/"; #endif - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; CHECK(resolvePath(prefix + "Users/modules/module.luau", "") == prefix + "Users/modules/module.luau"); CHECK(resolvePath(prefix + "Users/modules/module.luau", "a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); CHECK(resolvePath(prefix + "Users/modules/module.luau", "./a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); @@ -245,7 +242,6 @@ TEST_CASE("PathNormalization") std::string prefix = "/"; #endif - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; // Relative path std::optional result = normalizePath("../../modules/module"); CHECK(result); @@ -275,7 +271,6 @@ TEST_CASE("PathNormalization") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireSimpleRelativePath") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from dependency"}); @@ -283,7 +278,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireSimpleRelativePath") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireRelativeToRequiringFile") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/module"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from dependency", "required into module"}); @@ -291,7 +285,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireRelativeToRequiringFile") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireLua") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua_dependency"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from lua_dependency"}); @@ -299,7 +292,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireLua") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLuau") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/luau"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from init.luau"}); @@ -307,7 +299,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLuau") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLua") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from init.lua"}); @@ -315,7 +306,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLua") TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLuau") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/module"; std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/module"; @@ -335,7 +325,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLuau") TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLua") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua_dependency"; std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/lua_dependency"; @@ -355,7 +344,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLua") TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLuau") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/luau"; std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/luau"; @@ -375,7 +363,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLuau") TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLua") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/lua"; std::string absolutePath = getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/lua"; @@ -395,14 +382,12 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLua") TEST_CASE_FIXTURE(ReplWithPathFixture, "LoadStringRelative") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; runCode(L, "return pcall(function() return loadstring(\"require('a/relative/path')\")() end)"); assertOutputContainsAll({"false", "require is not supported in this context"}); } TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAbsolutePath") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; #ifdef _WIN32 std::string absolutePath = "C:/an/absolute/path"; #else @@ -414,7 +399,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAbsolutePath") TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayRelativePath") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/requirer"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from library"}); @@ -422,7 +406,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayRelativePath") TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayExplicitlyRelativePath") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/fail_requirer"; runProtectedRequire(path); assertOutputContainsAll({"false", "error requiring module"}); @@ -430,7 +413,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayExplicitlyRelativePath") TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayFromParent") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/global_library_requirer"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from global_library"}); @@ -438,7 +420,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayFromParent") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from dependency"}); @@ -446,7 +427,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithParentAlias") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/parent_alias_requirer"; runProtectedRequire(path); assertOutputContainsAll({"true", "result from other_dependency"}); @@ -455,7 +435,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithParentAlias") TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAliasThatDoesNotExist") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string nonExistentAlias = "@this.alias.does.not.exist"; runProtectedRequire(nonExistentAlias); @@ -464,7 +443,6 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAliasThatDoesNotExist") TEST_CASE_FIXTURE(ReplWithPathFixture, "AliasHasIllegalFormat") { - ScopedFastFlag sff{FFlag::LuauUpdatedRequireByStringSemantics, true}; std::string illegalCharacter = "@@"; runProtectedRequire(illegalCharacter); diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp index 983b077c..833b7502 100644 --- a/tests/SharedCodeAllocator.test.cpp +++ b/tests/SharedCodeAllocator.test.cpp @@ -3,15 +3,21 @@ #include "Luau/CodeAllocator.h" +#include "luacode.h" #include "luacodegen.h" +#include "lualib.h" #include "doctest.h" +#include "ScopedFlags.h" // We explicitly test correctness of self-assignment for some types #ifdef __clang__ #pragma GCC diagnostic ignored "-Wself-assign-overloaded" #endif +LUAU_FASTFLAG(LuauCodegenContext) +LUAU_FASTFLAG(LuauCodegenDetailedCompilationResult) + using namespace Luau::CodeGen; @@ -27,6 +33,9 @@ TEST_CASE("NativeModuleRefRefcounting") if (!luau_codegen_supported()) return; + ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; + ScopedFastFlag luauCodegenDetailedCompilationResult{FFlag::LuauCodegenDetailedCompilationResult, true}; + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -243,6 +252,9 @@ TEST_CASE("NativeProtoRefcounting") if (!luau_codegen_supported()) return; + ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; + ScopedFastFlag luauCodegenDetailedCompilationResult{FFlag::LuauCodegenDetailedCompilationResult, true}; + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -294,6 +306,9 @@ TEST_CASE("NativeProtoState") if (!luau_codegen_supported()) return; + ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; + ScopedFastFlag luauCodegenDetailedCompilationResult{FFlag::LuauCodegenDetailedCompilationResult, true}; + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; SharedCodeAllocator allocator{&codeAllocator}; @@ -347,3 +362,104 @@ TEST_CASE("NativeProtoState") REQUIRE(modRefA->tryGetNativeProto(2) == nullptr); REQUIRE(modRefA->tryGetNativeProto(4) == nullptr); } + +TEST_CASE("AnonymousModuleLifetime") +{ + if (!luau_codegen_supported()) + return; + + ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; + ScopedFastFlag luauCodegenDetailedCompilationResult{FFlag::LuauCodegenDetailedCompilationResult, true}; + + CodeAllocator codeAllocator{kBlockSize, kMaxTotalSize}; + SharedCodeAllocator allocator{&codeAllocator}; + + const std::vector data(8); + const std::vector code(8); + + std::vector nativeProtos; + nativeProtos.reserve(1); + + { + NativeProtoExecDataPtr nativeProto = createNativeProtoExecData(2); + getNativeProtoExecDataHeader(nativeProto.get()).bytecodeId = 1; + getNativeProtoExecDataHeader(nativeProto.get()).entryOffsetOrAddress = reinterpret_cast(0x00); + nativeProto[0] = 0; + nativeProto[1] = 4; + + nativeProtos.push_back(std::move(nativeProto)); + } + + NativeModuleRef modRef = allocator.insertAnonymousNativeModule(std::move(nativeProtos), data.data(), data.size(), code.data(), code.size()); + REQUIRE(!modRef.empty()); + REQUIRE(modRef->getModuleBaseAddress() != nullptr); + REQUIRE(modRef->tryGetNativeProto(1) != nullptr); + REQUIRE(modRef->getRefcount() == 1); + + const NativeModule* mod = modRef.get(); + + // Acquire a reference (as if we are binding it to a Luau VM Proto): + modRef->addRef(); + REQUIRE(mod->getRefcount() == 2); + + // Release our "owning" reference: + modRef.reset(); + REQUIRE(mod->getRefcount() == 1); + + // Release our added reference (as if the Luau VM Proto is being GC'ed): + mod->release(); + + // When we return and the sharedCodeAllocator is destroyed it will verify + // that there are no outstanding anonymous NativeModules. +} + +TEST_CASE("SharedAllocation") +{ + if (!luau_codegen_supported()) + return; + + ScopedFastFlag luauCodegenContext{FFlag::LuauCodegenContext, true}; + ScopedFastFlag luauCodegenDetailedCompilationResult{FFlag::LuauCodegenDetailedCompilationResult, true}; + + UniqueSharedCodeGenContext sharedCodeGenContext = createSharedCodeGenContext(); + + std::unique_ptr L1{luaL_newstate(), lua_close}; + std::unique_ptr L2{luaL_newstate(), lua_close}; + + create(L1.get(), sharedCodeGenContext.get()); + create(L2.get(), sharedCodeGenContext.get()); + + std::string source = R"( + function add(x, y) return x + y end + function sub(x, y) return x - y end + )"; + + size_t bytecodeSize = 0; + std::unique_ptr bytecode{luau_compile(source.data(), source.size(), nullptr, &bytecodeSize), free}; + const int loadResult1 = luau_load(L1.get(), "=Functions", bytecode.get(), bytecodeSize, 0); + const int loadResult2 = luau_load(L2.get(), "=Functions", bytecode.get(), bytecodeSize, 0); + REQUIRE(loadResult1 == 0); + REQUIRE(loadResult2 == 0); + bytecode.reset(); + + const ModuleId moduleId = {0x01}; + + CompilationStats nativeStats1 = {}; + CompilationStats nativeStats2 = {}; + const CompilationResult codeGenResult1 = Luau::CodeGen::compile(moduleId, L1.get(), -1, CodeGen_ColdFunctions, &nativeStats1); + const CompilationResult codeGenResult2 = Luau::CodeGen::compile(moduleId, L2.get(), -1, CodeGen_ColdFunctions, &nativeStats2); + REQUIRE(codeGenResult1.result == CodeGenCompilationResult::Success); + REQUIRE(codeGenResult2.result == CodeGenCompilationResult::Success); + + // We should have identified all three functions both times through: + REQUIRE(nativeStats1.functionsTotal == 3); + REQUIRE(nativeStats2.functionsTotal == 3); + + // We should have compiled the three functions only the first time: + REQUIRE(nativeStats1.functionsCompiled == 3); + REQUIRE(nativeStats2.functionsCompiled == 0); + + // We should have bound all three functions both times through: + REQUIRE(nativeStats1.functionsBound == 3); + REQUIRE(nativeStats2.functionsBound == 3); +} diff --git a/tests/Simplify.test.cpp b/tests/Simplify.test.cpp index 83f1d87d..ddddbe67 100644 --- a/tests/Simplify.test.cpp +++ b/tests/Simplify.test.cpp @@ -571,4 +571,15 @@ TEST_CASE_FIXTURE(SimplifyFixture, "bound_intersected_by_itself_should_be_itself CHECK(toString(blocked) == intersectStr(blocked, blocked)); } +TEST_CASE_FIXTURE(SimplifyFixture, "cyclic_never_union_and_string") +{ + // t1 where t1 = never | t1 + TypeId leftType = arena->addType(UnionType{{builtinTypes->neverType, builtinTypes->neverType}}); + UnionType* leftUnion = getMutable(leftType); + REQUIRE(leftUnion); + leftUnion->options[0] = leftType; + + CHECK(builtinTypes->stringType == union_(leftType, builtinTypes->stringType)); +} + TEST_SUITE_END(); diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 07ef9f1e..c5b3e053 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -24,7 +24,7 @@ struct FamilyFixture : Fixture { swapFamily = TypeFamily{/* name */ "Swap", /* reducer */ - [](TypeId instance, std::vector tys, std::vector tps, + [](TypeId instance, NotNull queue, const std::vector& tys, const std::vector& tps, NotNull ctx) -> TypeFamilyReductionResult { LUAU_ASSERT(tys.size() == 1); TypeId param = follow(tys.at(0)); @@ -218,6 +218,62 @@ TEST_CASE_FIXTURE(Fixture, "add_family_at_work") CHECK(toString(result.errors[1]) == "Type family instance Add is uninhabited"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cyclic_add_family_at_work") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( + type T = add + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mul_family_with_union_of_multiplicatives") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + loadDefinition(R"( + declare class Vec2 + function __mul(self, rhs: number): Vec2 + end + + declare class Vec3 + function __mul(self, rhs: number): Vec3 + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "Vec2 | Vec3"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "mul_family_with_union_of_multiplicatives_2") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + loadDefinition(R"( + declare class Vec3 + function __mul(self, rhs: number): Vec3 + function __mul(self, rhs: Vec3): Vec3 + end + )"); + + CheckResult result = check(R"( + type T = mul + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(toString(requireTypeAlias("T")) == "Vec3"); +} + TEST_CASE_FIXTURE(Fixture, "internal_families_raise_errors") { if (!FFlag::DebugLuauDeferredConstraintResolution) diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 24ccd296..c65d6ec5 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -1082,4 +1082,57 @@ type t0 = (t0) )"); LUAU_REQUIRE_ERRORS(result); } + + +TEST_CASE_FIXTURE(Fixture, "recursive_type_alias_warns") +{ + CheckResult result = check(R"( +type Foo = Foo +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto occursCheckError = get(result.errors[0]); + REQUIRE(occursCheckError); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_type_alias_bad_pack_use_warns") +{ + if (!FFlag::DebugLuauDeferredConstraintResolution) + return; + + CheckResult result = check(R"( +type Foo = Foo +)"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + auto occursCheckFailed = get(result.errors[1]); + REQUIRE(occursCheckFailed); + + auto swappedGeneric = get(result.errors[2]); + REQUIRE(swappedGeneric); + CHECK(swappedGeneric->name == "T"); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_aliases") +{ + CheckResult result = check(R"( +type Foo = Bar +type Bar = Foo +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + REQUIRE(err); +} + +TEST_CASE_FIXTURE(Fixture, "should_also_occurs_check") +{ + CheckResult result = check(R"( +type Foo = Foo | string +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto err = get(result.errors[0]); + REQUIRE(err); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.cfa.test.cpp b/tests/TypeInfer.cfa.test.cpp index 99d3008a..b701e960 100644 --- a/tests/TypeInfer.cfa.test.cpp +++ b/tests/TypeInfer.cfa.test.cpp @@ -1,13 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Fixture.h" - -#include "Luau/Symbol.h" #include "doctest.h" using namespace Luau; LUAU_FASTFLAG(LuauTinyControlFlowAnalysis); -LUAU_FASTFLAG(LuauLoopControlFlowAnalysis); TEST_SUITE_BEGIN("ControlFlowAnalysis"); @@ -31,7 +28,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -51,7 +48,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -93,7 +90,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_return") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -118,7 +115,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_break") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -143,7 +140,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_continue") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -168,7 +165,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_break") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -217,7 +214,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_rand_return_elif_not_y_ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -244,7 +241,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_br TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_rand_continue_elif_not_y_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -295,7 +292,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_rand_return_elif_no TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_fallthrough") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -322,7 +319,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_rand_break_elif_not_y_fa TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_rand_continue_elif_not_y_fallthrough") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -375,7 +372,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_fallthrough_elif_not_z_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) @@ -405,7 +402,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_elif_not_y_fallthrough_elif_n TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_fallthrough_elif_not_z_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) @@ -435,7 +432,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_fallthrough_eli TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_throw_elif_not_z_fallthrough") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) @@ -465,7 +462,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_elif_not_y_throw_elif_not_ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_elif_not_y_fallthrough_elif_not_z_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}, z: {{value: string?}}) @@ -515,7 +512,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_if_not_x_return") TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -537,7 +534,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_break") TEST_CASE_FIXTURE(BuiltinsFixture, "for_record_do_if_not_x_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -688,7 +685,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_return_if_not_y_return") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_break") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -715,7 +712,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_break") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -742,7 +739,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_continue") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_throw") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -769,7 +766,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_continue_if_not_y_throw") TEST_CASE_FIXTURE(BuiltinsFixture, "if_not_x_break_if_not_y_continue") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}, y: {{value: string?}}) @@ -819,7 +816,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out") TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_breaking") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -844,7 +841,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_breaking") TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_does_not_leak_out_continuing") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -895,7 +892,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope_breaking") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -920,7 +917,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_ TEST_CASE_FIXTURE(BuiltinsFixture, "prototyping_and_visiting_alias_has_the_same_scope_continuing") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( local function f(x: {{value: string?}}) @@ -980,7 +977,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions") TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_breaking") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( type Ok = { tag: "ok", value: T } @@ -1013,7 +1010,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_breaking") TEST_CASE_FIXTURE(BuiltinsFixture, "tagged_unions_continuing") { - ScopedFastFlag flags[] = {{FFlag::LuauTinyControlFlowAnalysis, true}, {FFlag::LuauLoopControlFlowAnalysis, true}}; + ScopedFastFlag sff{FFlag::LuauTinyControlFlowAnalysis, true}; CheckResult result = check(R"( type Ok = { tag: "ok", value: T } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 99df053f..29f70b30 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -22,6 +22,22 @@ LUAU_FASTINT(LuauTarjanChildLimit); TEST_SUITE_BEGIN("TypeInferFunctions"); +TEST_CASE_FIXTURE(Fixture, "general_case_table_literal_blocks") +{ + CheckResult result = check(R"( +--!strict +function f(x : {[any]: number}) + return x +end + +local Foo = {bar = "$$$"} + +f({[Foo.bar] = 0}) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "overload_resolution") { CheckResult result = check(R"( @@ -2476,4 +2492,47 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "function_that_could_return_anything_is_compa LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "self_application_does_not_segfault") +{ + (void)check(R"( + function f(a) + f(f) + return f(), a + end + )"); + + // We only care that type checking completes without tripping a crash or an assertion. +} + +TEST_CASE_FIXTURE(Fixture, "function_definition_in_a_do_block") +{ + CheckResult result = check(R"( + local f + do + function f() + end + end + f() + )"); + + // We are predominantly interested in this test not crashing. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "function_definition_in_a_do_block_with_global") +{ + CheckResult result = check(R"( + function f() print("a") end + do + function f() + print("b") + end + end + f() + )"); + + // We are predominantly interested in this test not crashing. + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 667a1ebe..8234a4fb 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -2,8 +2,6 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Scope.h" -#include "Luau/TypeInfer.h" #include "Luau/Type.h" #include "Luau/VisitType.h" @@ -15,7 +13,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauStacklessTypeClone3); TEST_SUITE_BEGIN("TypeInferOOP"); @@ -416,7 +413,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "promise_type_error_too_complex" * doctest::t // TODO: LTI changes to function call resolution have rendered this test impossibly slow // shared self should fix it, but there may be other mitigations possible as well REQUIRE(!FFlag::DebugLuauDeferredConstraintResolution); - ScopedFastFlag sff{FFlag::LuauStacklessTypeClone3, true}; frontend.options.retainFullTypeGraphs = false; diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 64c263c3..8e81b0cc 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -12,7 +12,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauSharedSelf); -LUAU_FASTFLAG(LuauTransitiveSubtyping); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauTarjanChildLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); @@ -507,10 +506,6 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - TypeArena arena; TypeId nilType = builtinTypes->nilType; @@ -916,10 +911,6 @@ TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - TypeArena arena; TypeId nilType = builtinTypes->nilType; @@ -1071,7 +1062,6 @@ tbl:f3() TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_limit_in_unify_with_any") { ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, {FFlag::DebugLuauDeferredConstraintResolution, true}, }; @@ -1218,4 +1208,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "update_phonemes_minimized") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cached") +{ + TypeArena arena; + Scope globalScope(builtinTypes->anyTypePack); + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + + TypeId tableTy = arena.addType(TableType{}); + TableType* table = getMutable(tableTy); + REQUIRE(table); + + TypeId freeTy = arena.freshType(&globalScope); + + table->props["foo"] = Property::rw(freeTy); + + std::shared_ptr n1 = normalizer.normalize(tableTy); + std::shared_ptr n2 = normalizer.normalize(tableTy); + + // This should not hold + CHECK(n1 == n2); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 49bca197..485a18c6 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -233,6 +233,18 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") )"); LUAU_REQUIRE_ERRORS(result); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CannotAssignToNever* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK(builtinTypes->stringType == tm->rhsType); + CHECK(CannotAssignToNever::Reason::PropertyNarrowed == tm->reason); + REQUIRE(tm->cause.size() == 2); + CHECK("\"Dog\"" == toString(tm->cause[0])); + CHECK("\"Cat\"" == toString(tm->cause[1])); + } } TEST_CASE_FIXTURE(Fixture, "table_has_a_boolean") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 5d5d6042..e10aea39 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -4283,6 +4283,30 @@ TEST_CASE_FIXTURE(Fixture, "parameter_was_set_an_indexer_and_bounded_by_another_ CHECK_EQ("({number}, unknown) -> ()", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "write_to_union_property_not_all_present") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + type Animal = {tag: "Cat", meow: boolean} | {tag: "Dog", woof: boolean} + function f(t: Animal) + t.tag = "Dog" + end + )"); + + // this should fail because `t` may be a `Cat` variant, and `"Dog"` is not a subtype of `"Cat"`. + LUAU_REQUIRE_ERRORS(result); + + CannotAssignToNever* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK(builtinTypes->stringType == tm->rhsType); + CHECK(CannotAssignToNever::Reason::PropertyNarrowed == tm->reason); + REQUIRE(tm->cause.size() == 2); + CHECK("\"Cat\"" == toString(tm->cause[0])); + CHECK("\"Dog\"" == toString(tm->cause[1])); +} + TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug") { CheckResult result = check(R"( @@ -4322,4 +4346,17 @@ TEST_CASE_FIXTURE(Fixture, "mymovie_read_write_tables_bug_2") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "setindexer_always_transmute") +{ + ScopedFastFlag sff{FFlag::DebugLuauDeferredConstraintResolution, true}; + + CheckResult result = check(R"( + function f(x) + (5)[5] = x + end + )"); + + CHECK_EQ("(*error-type*) -> ()", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 8d252ddb..fc804265 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -19,7 +19,6 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTFLAG(LuauTransitiveSubtyping); LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); @@ -980,6 +979,41 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } +/* + * We had a bug where we'd improperly cache the normalization of types that are + * not fully solved yet. This eventually caused a crash elsewhere in the type + * solver. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "fuzzer_found_this_2") +{ + (void) check(R"( + local _ + if _ then + _ = _ + while _() do + _ = # _ + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_a_cyclic_intersection_does_not_crash") +{ + (void) check(R"( + local _ + if _ then + while nil do + _ = _ + end + end + if _[if _ then ""] then + while nil do + _ = if _ then "" + end + end + )"); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_crash") { CheckResult result = check(R"( @@ -1272,9 +1306,6 @@ TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_ { ScopedFastFlag sff[] = { {FFlag::DebugLuauDeferredConstraintResolution, true}, - // If we run this with error-suppression, it triggers an assertion. - // FATAL ERROR: Assertion failed: !"Internal error: Trying to normalize a BlockedType" - {FFlag::LuauTransitiveSubtyping, false}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index b2bb1ccc..58ccea89 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -13,7 +13,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls); -LUAU_FASTFLAG(LuauTransitiveSubtyping); struct TryUnifyFixture : Fixture { @@ -32,10 +31,6 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; Type numberTwo = numberOne; @@ -47,10 +42,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - Type functionOne{ TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; @@ -68,10 +59,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionOne{ TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; @@ -94,10 +81,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - Type tableOne{TypeVariant{ TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; @@ -120,10 +103,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - Type tableOne{TypeVariant{ TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, @@ -352,10 +331,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - TableType::Props freeProps{ {"foo", {builtinTypes->numberType}}, }; diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index dfa88934..4b2f029d 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAG(LuauTransitiveSubtyping); TEST_SUITE_BEGIN("UnionTypes"); @@ -867,10 +866,6 @@ TEST_CASE_FIXTURE(Fixture, "optional_any") TEST_CASE_FIXTURE(Fixture, "generic_function_with_optional_arg") { - ScopedFastFlag sff[] = { - {FFlag::LuauTransitiveSubtyping, true}, - }; - CheckResult result = check(R"( function f(x : T?) : {T} local result = {} diff --git a/tools/faillist.txt b/tools/faillist.txt index ce91bcc4..ea609649 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -32,7 +32,6 @@ BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument2 BuiltinTests.table_freeze_is_generic BuiltinTests.tonumber_returns_optional_number_type -ControlFlowAnalysis.tagged_unions DefinitionTests.class_definition_overload_metamethods Differ.metatable_metamissing_left Differ.metatable_metamissing_right @@ -75,7 +74,6 @@ GenericsTests.no_stack_overflow_from_quantifying GenericsTests.properties_can_be_instantiated_polytypes GenericsTests.quantify_functions_even_if_they_have_an_explicit_generic GenericsTests.self_recursive_instantiated_param -GenericsTests.type_parameters_can_be_polytypes IntersectionTypes.CLI-44817 IntersectionTypes.error_detailed_intersection_all IntersectionTypes.error_detailed_intersection_part @@ -84,7 +82,6 @@ IntersectionTypes.intersect_false_and_bool_and_false IntersectionTypes.intersect_metatables IntersectionTypes.intersect_saturate_overloaded_functions IntersectionTypes.intersection_of_tables -IntersectionTypes.intersection_of_tables_with_never_properties IntersectionTypes.intersection_of_tables_with_top_properties IntersectionTypes.less_greedy_unification_with_intersection_types IntersectionTypes.overloaded_functions_mentioning_generic @@ -139,9 +136,6 @@ RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.dataflow_analysis_can_tell_refinements_when_its_appropriate_to_refine_into_nil_or_never RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x -RefinementTest.discriminate_tag -RefinementTest.discriminate_tag_with_implicit_else -RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.function_call_with_colon_after_refining_not_to_be_nil RefinementTest.globals_can_be_narrowed_too RefinementTest.isa_type_refinement_must_be_known_ahead_of_time @@ -199,7 +193,6 @@ TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys -TableTests.nil_assign_doesnt_hit_indexer TableTests.ok_to_provide_a_subtype_during_construction TableTests.ok_to_set_nil_even_on_non_lvalue_base_expr TableTests.okay_to_add_property_to_unsealed_tables_by_assignment @@ -236,6 +229,7 @@ TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type +TableTests.wrong_assign_does_hit_indexer ToDot.function ToString.exhaustive_toString_of_cyclic_table ToString.free_types @@ -265,6 +259,7 @@ TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeFamilyTests.add_family_at_work TypeFamilyTests.family_as_fn_arg TypeFamilyTests.internal_families_raise_errors +TypeFamilyTests.mul_family_with_union_of_multiplicatives_2 TypeFamilyTests.unsolvable_family TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload TypeInfer.check_type_infer_recursion_count @@ -320,6 +315,7 @@ TypeInferFunctions.function_exprs_are_generalized_at_signature_scope_not_enclosi TypeInferFunctions.function_is_supertype_of_concrete_functions TypeInferFunctions.function_statement_sealed_table_assignment_through_indexer TypeInferFunctions.generic_packs_are_not_variadic +TypeInferFunctions.higher_order_function_2 TypeInferFunctions.higher_order_function_3 TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict @@ -415,7 +411,6 @@ TypeSingletons.error_detailed_tagged_union_mismatch_string TypeSingletons.overloaded_function_call_with_singletons_mismatch TypeSingletons.return_type_of_f_is_not_widened TypeSingletons.table_properties_type_error_escapes -TypeSingletons.tagged_unions_immutable_tag TypeSingletons.widen_the_supertype_if_it_is_free_and_subtype_has_singleton TypeStatesTest.prototyped_recursive_functions_but_has_future_assignments TypeStatesTest.typestates_preserve_error_suppression_properties