diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index b0c8fd17..7d5ce892 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -4,6 +4,7 @@ #include #include "Luau/TypeArena.h" #include "Luau/Type.h" +#include "Luau/Scope.h" #include @@ -26,13 +27,17 @@ struct CloneState * while `clone` will make a deep copy of the entire type and its every component. * * Be mindful about which behavior you actually _want_. + * + * Persistent types are not cloned as an optimization. + * If a type is cloned in order to mutate it, 'ignorePersistent' has to be set */ -TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState); -TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState); +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false); +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false); TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); +Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState); } // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 272ee52a..dc443777 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -7,6 +7,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" +#include "Luau/Set.h" #include "Luau/TypeCheckLimits.h" #include "Luau/Variant.h" #include "Luau/AnyTypeSummary.h" @@ -56,13 +57,32 @@ struct SourceNode return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; } + bool hasInvalidModuleDependency(bool forAutocomplete) const + { + return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency; + } + + void setInvalidModuleDependency(bool value, bool forAutocomplete) + { + if (forAutocomplete) + invalidModuleDependencyForAutocomplete = value; + else + invalidModuleDependency = value; + } + ModuleName name; std::string humanReadableName; DenseHashSet requireSet{{}}; std::vector> requireLocations; + Set dependents{{}}; + bool dirtySourceModule = true; bool dirtyModule = true; bool dirtyModuleForAutocomplete = true; + + bool invalidModuleDependency = true; + bool invalidModuleDependencyForAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; @@ -117,7 +137,7 @@ struct FrontendModuleResolver : ModuleResolver std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; - void setModule(const ModuleName& moduleName, ModulePtr module); + bool setModule(const ModuleName& moduleName, ModulePtr module); void clearModules(); private: @@ -151,9 +171,13 @@ struct Frontend // Parse and typecheck module graph CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); + void traverseDependents(const ModuleName& name, std::function processSubtree); + /** Borrow a pointer into the SourceModule cache. * * Returns nullptr if we don't have it. This could mean that the script diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h index 5b363e96..aab37876 100644 --- a/Analysis/include/Luau/Simplify.h +++ b/Analysis/include/Luau/Simplify.h @@ -19,10 +19,10 @@ struct SimplifyResult DenseHashSet blockedTypes; }; -SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right); SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); -SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right); enum class Relation { diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index 558a5110..5c268f67 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -69,12 +69,16 @@ using Name = std::string; // A free type is one whose exact shape has yet to be fully determined. struct FreeType { + // New constructors + explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound); + // This one got promoted to explicit + explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); + explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound); + // Old constructors explicit FreeType(TypeLevel level); explicit FreeType(Scope* scope); FreeType(Scope* scope, TypeLevel level); - FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); - int index; TypeLevel level; Scope* scope = nullptr; diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 4f8aea87..ebefa41f 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -32,9 +32,13 @@ struct TypeArena TypeId addTV(Type&& tv); - TypeId freshType(TypeLevel level); - TypeId freshType(Scope* scope); - TypeId freshType(Scope* scope, TypeLevel level); + TypeId freshType(NotNull builtins, TypeLevel level); + TypeId freshType(NotNull builtins, Scope* scope); + TypeId freshType(NotNull builtins, Scope* scope, TypeLevel level); + + TypeId freshType_DEPRECATED(TypeLevel level); + TypeId freshType_DEPRECATED(Scope* scope); + TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level); TypePackId freshTypePack(Scope* scope); diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index dadad721..1c97550f 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -241,6 +241,9 @@ struct BuiltinTypeFunctions TypeFunction indexFunc; TypeFunction rawgetFunc; + TypeFunction setmetatableFunc; + TypeFunction getmetatableFunc; + void addToScope(NotNull arena, NotNull scope) const; }; diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index b1fd18ac..dbc1b5d8 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -1161,6 +1161,19 @@ struct AstJsonEncoder : public AstVisitor ); } + bool visit(class AstTypeGroup* node) override + { + writeNode( + node, + "AstTypeGroup", + [&]() + { + write("type", node->type); + } + ); + return false; + } + bool visit(class AstTypeSingletonBool* node) override { writeNode( diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 01c97547..7aee25ce 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -29,12 +29,11 @@ */ LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) -LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression) LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) -LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType2) +LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent) namespace Luau { @@ -449,10 +448,9 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ttv->props["foreachi"].deprecated = true; attachMagicFunction(ttv->props["pack"].type(), std::make_shared()); - if (FFlag::LuauTableCloneClonesType2) + if (FFlag::LuauTableCloneClonesType3) attachMagicFunction(ttv->props["clone"].type(), std::make_shared()); - if (FFlag::LuauTypestateBuiltins2) - attachMagicFunction(ttv->props["freeze"].type(), std::make_shared()); + attachMagicFunction(ttv->props["freeze"].type(), std::make_shared()); } if (FFlag::AutocompleteRequirePathSuggestions2) @@ -613,10 +611,7 @@ bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context) if (!fmt) { - if (FFlag::LuauStringFormatArityFix) - context.typechecker->reportError( - CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location - ); + context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location); return true; } @@ -1401,7 +1396,7 @@ std::optional> MagicClone::handleOldSolver( WithPredicate withPredicate ) { - LUAU_ASSERT(FFlag::LuauTableCloneClonesType2); + LUAU_ASSERT(FFlag::LuauTableCloneClonesType3); auto [paramPack, _predicates] = withPredicate; @@ -1416,6 +1411,9 @@ std::optional> MagicClone::handleOldSolver( TypeId inputType = follow(paramTypes[0]); + if (!get(inputType)) + return std::nullopt; + CloneState cloneState{typechecker.builtinTypes}; TypeId resultType = shallowClone(inputType, arena, cloneState); @@ -1425,7 +1423,7 @@ std::optional> MagicClone::handleOldSolver( bool MagicClone::infer(const MagicFunctionCallContext& context) { - LUAU_ASSERT(FFlag::LuauTableCloneClonesType2); + LUAU_ASSERT(FFlag::LuauTableCloneClonesType3); TypeArena* arena = context.solver->arena; @@ -1438,8 +1436,11 @@ bool MagicClone::infer(const MagicFunctionCallContext& context) TypeId inputType = follow(paramTypes[0]); + if (!get(inputType)) + return false; + CloneState cloneState{context.solver->builtinTypes}; - TypeId resultType = shallowClone(inputType, *arena, cloneState); + TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent); if (auto tableType = getMutable(resultType)) { @@ -1475,7 +1476,7 @@ static std::optional freezeTable(TypeId inputType, const MagicFunctionCa { // Clone the input type, this will become our final result type after we mutate it. CloneState cloneState{context.solver->builtinTypes}; - TypeId resultType = shallowClone(inputType, *arena, cloneState); + TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent); auto tableTy = getMutable(resultType); // `clone` should not break this. LUAU_ASSERT(tableTy); @@ -1507,8 +1508,6 @@ std::optional> MagicFreeze::handleOldSolver(struct Typ bool MagicFreeze::infer(const MagicFunctionCallContext& context) { - LUAU_ASSERT(FFlag::LuauTypestateBuiltins2); - TypeArena* arena = context.solver->arena; const DataFlowGraph* dfg = context.solver->dfg.get(); Scope* scope = context.constraint->scope.get(); diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 98397fa3..6309fa7c 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -7,6 +7,7 @@ #include "Luau/Unifiable.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFreezeIgnorePersistent) // For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) @@ -38,14 +39,26 @@ class TypeCloner NotNull types; NotNull packs; + TypeId forceTy = nullptr; + TypePackId forceTp = nullptr; + int steps = 0; public: - TypeCloner(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + TypeCloner( + NotNull arena, + NotNull builtinTypes, + NotNull types, + NotNull packs, + TypeId forceTy, + TypePackId forceTp + ) : arena(arena) , builtinTypes(builtinTypes) , types(types) , packs(packs) + , forceTy(forceTy) + , forceTp(forceTp) { } @@ -112,7 +125,7 @@ private: ty = follow(ty, FollowOption::DisableLazyTypeThunks); if (auto it = types->find(ty); it != types->end()) return it->second; - else if (ty->persistent) + else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy)) return ty; return std::nullopt; } @@ -122,7 +135,7 @@ private: tp = follow(tp); if (auto it = packs->find(tp); it != packs->end()) return it->second; - else if (tp->persistent) + else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp)) return tp; return std::nullopt; } @@ -148,7 +161,7 @@ public: if (auto clone = find(ty)) return *clone; - else if (ty->persistent) + else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy)) return ty; TypeId target = arena->addType(ty->ty); @@ -174,7 +187,7 @@ public: if (auto clone = find(tp)) return *clone; - else if (tp->persistent) + else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp)) return tp; TypePackId target = arena->addTypePack(tp->ty); @@ -458,21 +471,37 @@ private: } // namespace -TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState) +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent) { - if (tp->persistent) + if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent)) return tp; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{ + NotNull{&dest}, + cloneState.builtinTypes, + NotNull{&cloneState.seenTypes}, + NotNull{&cloneState.seenTypePacks}, + nullptr, + FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? tp : nullptr + }; + return cloner.shallowClone(tp); } -TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState) +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent) { - if (typeId->persistent) + if (typeId->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent)) return typeId; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{ + NotNull{&dest}, + cloneState.builtinTypes, + NotNull{&cloneState.seenTypes}, + NotNull{&cloneState.seenTypePacks}, + FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? typeId : nullptr, + nullptr + }; + return cloner.shallowClone(typeId); } @@ -481,7 +510,7 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) if (tp->persistent) return tp; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; return cloner.clone(tp); } @@ -490,13 +519,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; return cloner.clone(typeId); } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; TypeFun copy = typeFun; @@ -521,4 +550,18 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) return copy; } +Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState) +{ + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; + + Binding b; + b.deprecated = binding.deprecated; + b.deprecatedSuggestion = binding.deprecatedSuggestion; + b.documentationSymbol = binding.documentationSymbol; + b.location = binding.location; + b.typeId = cloner.clone(binding.typeId); + + return b; +} + } // namespace Luau diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 90feb1a6..f77d7944 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -31,13 +31,14 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(LuauTypestateBuiltins2) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAGVARIABLE(InferGlobalTypes) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -1088,18 +1089,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); else if (const AstExprCall* call = value->as()) { - if (FFlag::LuauTypestateBuiltins2) - { - if (matchSetMetatable(*call)) - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } - else - { - if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") - { - addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } - } + if (matchSetMetatable(*call)) + addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); } } @@ -2068,7 +2059,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; } - if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) + if (shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) { AstExpr* targetExpr = call->args.data[0]; auto resultTy = arena->addType(BlockedType{}); @@ -2217,7 +2208,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin if (forceSingleton) return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - FreeType ft = FreeType{scope.get()}; + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); ft.upperBound = builtinTypes->stringType; const TypeId freeTy = arena->addType(ft); @@ -2231,7 +2223,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* if (forceSingleton) return Inference{singletonType}; - FreeType ft = FreeType{scope.get()}; + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; ft.lowerBound = singletonType; ft.upperBound = builtinTypes->booleanType; const TypeId freeTy = arena->addType(ft); @@ -3427,6 +3420,12 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } else if (auto unionAnnotation = ty->as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (unionAnnotation->types.size == 1) + return resolveType(scope, unionAnnotation->types.data[0], inTypeArguments); + } + std::vector parts; for (AstType* part : unionAnnotation->types) { @@ -3437,6 +3436,12 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } else if (auto intersectionAnnotation = ty->as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (intersectionAnnotation->types.size == 1) + return resolveType(scope, intersectionAnnotation->types.data[0], inTypeArguments); + } + std::vector parts; for (AstType* part : intersectionAnnotation->types) { @@ -3445,6 +3450,10 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool result = arena->addType(IntersectionType{parts}); } + else if (auto typeGroupAnnotation = ty->as()) + { + result = resolveType(scope, typeGroupAnnotation->type, inTypeArguments); + } else if (auto boolAnnotation = ty->as()) { if (boolAnnotation->value) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 5b0c73e7..cb2f6bbf 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -31,7 +31,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies) LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer) @@ -1161,22 +1160,8 @@ void ConstraintSolver::fillInDiscriminantTypes( continue; } - if (FFlag::LuauRemoveNotAnyHack) - { - // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. - emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); - } - else - { - // We use `any` here because the discriminant type may be pointed at by both branches, - // where the discriminant type is not negated, and the other where it is negated, i.e. - // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` - // v.s. - // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` - // - // In practice, users cannot negate `any`, so this is an implementation detail we can always change. - emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); - } + // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. + emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); } } @@ -1313,22 +1298,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(follow(*ty)), builtinTypes->noRefineType); - } - else - { - // We use `any` here because the discriminant type may be pointed at by both branches, - // where the discriminant type is not negated, and the other where it is negated, i.e. - // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` - // v.s. - // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` - // - // In practice, users cannot negate `any`, so this is an implementation detail we can always change. - emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); - } + // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. + emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); } } diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 3f724f2c..cff87858 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauTypestateBuiltins2) namespace Luau { @@ -879,7 +878,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) { visitExpr(c->func); - if (FFlag::LuauTypestateBuiltins2 && shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) + if (shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) { AstExpr* firstArg = *c->args.begin(); @@ -1170,6 +1169,8 @@ void DataFlowGraphBuilder::visitType(AstType* t) return; // ok else if (auto s = t->as()) return; // ok + else if (auto g = t->as()) + return visitType(g->type); else handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); } diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index cde8125a..bc82d750 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -6,6 +6,7 @@ #include "Luau/Autocomplete.h" #include "Luau/Common.h" #include "Luau/EqSatSimplification.h" +#include "Luau/ModuleResolver.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" @@ -19,16 +20,21 @@ #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" - +#include "Luau/Clone.h" #include "AutocompleteCore.h" - LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteBugfixes) +LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver) +LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf) +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) +LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule) + namespace { template @@ -49,6 +55,96 @@ void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap +void cloneModuleMap(TypeArena& destArena, CloneState& cloneState, const Luau::DenseHashMap& source, Luau::DenseHashMap& dest) +{ + for (auto [k, v] : source) + { + dest[k] = Luau::clone(v, destArena, cloneState); + } +} + +struct MixedModeIncrementalTCDefFinder : public AstVisitor +{ + bool visit(AstExprLocal* local) override + { + referencedLocalDefs.emplace_back(local->local, local); + return true; + } + + bool visit(AstTypeTypeof* node) override + { + // We need to traverse typeof expressions because they may refer to locals that we need + // to populate the local environment for fragment typechecking. For example, `typeof(m)` + // requires that we find the local/global `m` and place it in the environment. + // The default behaviour here is to return false, and have individual visitors override + // the specific behaviour they need. + return FFlag::LuauMixedModeDefFinderTraversesTypeOf; + } + + // ast defs is just a mapping from expr -> def in general + // will get built up by the dfg builder + + // localDefs, we need to copy over + std::vector> referencedLocalDefs; +}; + +void cloneAndSquashScopes( + CloneState& cloneState, + const Scope* staleScope, + const ModulePtr& staleModule, + NotNull destArena, + NotNull dfg, + AstStatBlock* program, + Scope* destScope +) +{ + std::vector scopes; + for (const Scope* current = staleScope; current; current = current->parent.get()) + { + scopes.emplace_back(current); + } + + // in reverse order (we need to clone the parents and override defs as we go down the list) + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + const Scope* curr = *it; + // Clone the lvalue types + for (const auto& [def, ty] : curr->lvalueTypes) + destScope->lvalueTypes[def] = Luau::clone(ty, *destArena, cloneState); + // Clone the rvalueRefinements + for (const auto& [def, ty] : curr->rvalueRefinements) + destScope->rvalueRefinements[def] = Luau::clone(ty, *destArena, cloneState); + for (const auto& [n, m] : curr->importedTypeBindings) + { + std::unordered_map importedBindingTypes; + for (const auto& [v, tf] : m) + importedBindingTypes[v] = Luau::clone(tf, *destArena, cloneState); + destScope->importedTypeBindings[n] = m; + } + + // Finally, clone up the bindings + for (const auto& [s, b] : curr->bindings) + { + destScope->bindings[s] = Luau::clone(b, *destArena, cloneState); + } + } + + // The above code associates defs with TypeId's in the scope + // so that lookup to locals will succeed. + MixedModeIncrementalTCDefFinder finder; + program->visit(&finder); + std::vector> locals = std::move(finder.referencedLocalDefs); + for (auto [loc, expr] : locals) + { + if (std::optional binding = staleScope->linearSearchForBinding(loc->name.value, true)) + { + destScope->lvalueTypes[dfg->getDef(expr)] = Luau::clone(binding->typeId, *destArena, cloneState); + } + } + return; +} + static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional options) { if (FFlag::LuauSolverV2 || !options) @@ -265,13 +361,35 @@ std::optional parseFragment( return fragmentResult; } +ModulePtr cloneModule(CloneState& cloneState, const ModulePtr& source, std::unique_ptr alloc) +{ + freeze(source->internalTypes); + freeze(source->interfaceTypes); + ModulePtr incremental = std::make_shared(); + incremental->name = source->name; + incremental->humanReadableName = source->humanReadableName; + incremental->allocator = std::move(alloc); + // Clone types + cloneModuleMap(incremental->internalTypes, cloneState, source->astTypes, incremental->astTypes); + cloneModuleMap(incremental->internalTypes, cloneState, source->astTypePacks, incremental->astTypePacks); + cloneModuleMap(incremental->internalTypes, cloneState, source->astExpectedTypes, incremental->astExpectedTypes); + + cloneModuleMap(incremental->internalTypes, cloneState, source->astOverloadResolvedTypes, incremental->astOverloadResolvedTypes); + + cloneModuleMap(incremental->internalTypes, cloneState, source->astForInNextTypes, incremental->astForInNextTypes); + + copyModuleMap(incremental->astScopes, source->astScopes); + + return incremental; +} + ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) { - freeze(result->internalTypes); - freeze(result->interfaceTypes); ModulePtr incrementalModule = std::make_shared(); incrementalModule->name = result->name; - incrementalModule->humanReadableName = result->humanReadableName; + incrementalModule->humanReadableName = "Incremental$" + result->humanReadableName; + incrementalModule->internalTypes.owningModule = incrementalModule.get(); + incrementalModule->interfaceTypes.owningModule = incrementalModule.get(); incrementalModule->allocator = std::move(alloc); // Don't need to keep this alive (it's already on the source module) copyModuleVec(incrementalModule->scopes, result->scopes); @@ -290,21 +408,6 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) return incrementalModule; } -struct MixedModeIncrementalTCDefFinder : public AstVisitor -{ - bool visit(AstExprLocal* local) override - { - referencedLocalDefs.push_back({local->local, local}); - return true; - } - - // ast defs is just a mapping from expr -> def in general - // will get built up by the dfg builder - - // localDefs, we need to copy over - std::vector> referencedLocalDefs; -}; - void mixedModeCompatibility( const ScopePtr& bottomScopeStale, const ScopePtr& myFakeScope, @@ -343,7 +446,9 @@ FragmentTypeCheckResult typecheckFragment_( { freeze(stale->internalTypes); freeze(stale->interfaceTypes); - ModulePtr incrementalModule = copyModule(stale, std::move(astAllocator)); + CloneState cloneState{frontend.builtinTypes}; + ModulePtr incrementalModule = + FFlag::LuauCloneIncrementalModule ? cloneModule(cloneState, stale, std::move(astAllocator)) : copyModule(stale, std::move(astAllocator)); incrementalModule->checkedInNewSolver = true; unfreeze(incrementalModule->internalTypes); unfreeze(incrementalModule->interfaceTypes); @@ -391,25 +496,34 @@ FragmentTypeCheckResult typecheckFragment_( NotNull{&dfg}, {} }; + std::shared_ptr freshChildOfNearestScope = nullptr; + if (FFlag::LuauCloneIncrementalModule) + { + freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); + cg.rootScope = freshChildOfNearestScope.get(); - cg.rootScope = stale->getModuleScope().get(); - // Any additions to the scope must occur in a fresh scope - auto freshChildOfNearestScope = std::make_shared(closestScope); - incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); - - // Update freshChildOfNearestScope with the appropriate lvalueTypes - mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root); - - // closest Scope -> children = { ...., freshChildOfNearestScope} - // We need to trim nearestChild from the scope hierarcy - closestScope->children.push_back(NotNull{freshChildOfNearestScope.get()}); - // Visit just the root - we know the scope it should be in - cg.visitFragmentRoot(freshChildOfNearestScope, root); - // Trim nearestChild from the closestScope - Scope* back = closestScope->children.back().get(); - LUAU_ASSERT(back == freshChildOfNearestScope.get()); - closestScope->children.pop_back(); - + cloneAndSquashScopes( + cloneState, closestScope.get(), stale, NotNull{&incrementalModule->internalTypes}, NotNull{&dfg}, root, freshChildOfNearestScope.get() + ); + cg.visitFragmentRoot(freshChildOfNearestScope, root); + } + else + { + // Any additions to the scope must occur in a fresh scope + cg.rootScope = stale->getModuleScope().get(); + freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); + mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root); + // closest Scope -> children = { ...., freshChildOfNearestScope} + // We need to trim nearestChild from the scope hierarcy + closestScope->children.emplace_back(freshChildOfNearestScope.get()); + cg.visitFragmentRoot(freshChildOfNearestScope, root); + // Trim nearestChild from the closestScope + Scope* back = closestScope->children.back().get(); + LUAU_ASSERT(back == freshChildOfNearestScope.get()); + closestScope->children.pop_back(); + } /// Initialize the constraint solver and run it ConstraintSolver cs{ @@ -458,6 +572,13 @@ std::pair typecheckFragment( std::optional fragmentEndPosition ) { + + if (FFlag::LuauBetterReverseDependencyTracking) + { + if (!frontend.allModuleDependenciesValid(moduleName, opts && opts->forAutocomplete)) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + } + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) { @@ -473,6 +594,14 @@ std::pair typecheckFragment( return {}; } + if (FFlag::LuauIncrementalAutocompleteBugfixes && FFlag::LuauReferenceAllocatorInNewSolver) + { + if (sourceModule->allocator.get() != module->allocator.get()) + { + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + } + } + auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); if (!tryParse) diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1ce35d76..0292726b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -47,6 +47,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) +LUAU_FASTFLAGVARIABLE(LuauBetterReverseDependencyTracking) + LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule) @@ -820,6 +822,16 @@ bool Frontend::parseGraph( topseen = Permanent; buildQueue.push_back(top->name); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + // at this point we know all valid dependencies are processed into SourceNodes + for (const ModuleName& dep : top->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) + it->second->dependents.insert(top->name); + } + } } else { @@ -1107,15 +1119,49 @@ void Frontend::recordItemResult(const BuildQueueItem& item) if (item.exception) std::rethrow_exception(item.exception); - if (item.options.forAutocomplete) + if (FFlag::LuauBetterReverseDependencyTracking) { - moduleResolverForAutocomplete.setModule(item.name, item.module); - item.sourceNode->dirtyModuleForAutocomplete = false; + bool replacedModule = false; + if (item.options.forAutocomplete) + { + replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + replacedModule = moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } + + if (replacedModule) + { + LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str()); + traverseDependents( + item.name, + [forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode) + { + bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete); + sourceNode.setInvalidModuleDependency(true, forAutocomplete); + return traverseSubtree; + } + ); + } + + item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete); } else { - moduleResolver.setModule(item.name, item.module); - item.sourceNode->dirtyModule = false; + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } } stats.timeCheck += item.stats.timeCheck; @@ -1152,6 +1198,13 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } +bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const +{ + LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking); + auto it = sourceNodes.find(name); + return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete); +} + bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); @@ -1166,16 +1219,80 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { + LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + traverseDependents( + name, + [markedDirty](SourceNode& sourceNode) + { + if (markedDirty) + markedDirty->push_back(sourceNode.name); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + return false; + + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + + return true; + } + ); + } + else + { + if (sourceNodes.count(name) == 0) + return; + + std::unordered_map> reverseDeps; + for (const auto& module : sourceNodes) + { + for (const auto& dep : module.second->requireSet) + reverseDeps[dep].push_back(module.first); + } + + std::vector queue{name}; + + while (!queue.empty()) + { + ModuleName next = std::move(queue.back()); + queue.pop_back(); + + LUAU_ASSERT(sourceNodes.count(next) > 0); + SourceNode& sourceNode = *sourceNodes[next]; + + if (markedDirty) + markedDirty->push_back(next); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; + + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + + if (0 == reverseDeps.count(next)) + continue; + + sourceModules.erase(next); + + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } + } +} + +void Frontend::traverseDependents(const ModuleName& name, std::function processSubtree) +{ + LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking); + LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend"); + if (sourceNodes.count(name) == 0) return; - std::unordered_map> reverseDeps; - for (const auto& module : sourceNodes) - { - for (const auto& dep : module.second->requireSet) - reverseDeps[dep].push_back(module.first); - } - std::vector queue{name}; while (!queue.empty()) @@ -1186,22 +1303,10 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked LUAU_ASSERT(sourceNodes.count(next) > 0); SourceNode& sourceNode = *sourceNodes[next]; - if (markedDirty) - markedDirty->push_back(next); - - if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + if (!processSubtree(sourceNode)) continue; - sourceNode.dirtySourceModule = true; - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - - if (0 == reverseDeps.count(next)) - continue; - - sourceModules.erase(next); - - const std::vector& dependents = reverseDeps[next]; + const Set& dependents = sourceNode.dependents; queue.insert(queue.end(), dependents.begin(), dependents.end()); } } @@ -1643,6 +1748,17 @@ std::pair Frontend::getSourceNode(const ModuleName& sourceNode->name = sourceModule->name; sourceNode->humanReadableName = sourceModule->humanReadableName; + + if (FFlag::LuauBetterReverseDependencyTracking) + { + // clear all prior dependents. we will re-add them after parsing the rest of the graph + for (const auto& [moduleName, _] : sourceNode->requireLocations) + { + if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end()) + depIt->second->dependents.erase(sourceNode->name); + } + } + sourceNode->requireSet.clear(); sourceNode->requireLocations.clear(); sourceNode->dirtySourceModule = false; @@ -1764,11 +1880,21 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& return frontend->fileResolver->getHumanReadableModuleName(moduleName); } -void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) +bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) { std::scoped_lock lock(moduleMutex); - modules[moduleName] = std::move(module); + if (FFlag::LuauBetterReverseDependencyTracking) + { + bool replaced = modules.count(moduleName) > 0; + modules[moduleName] = std::move(module); + return replaced; + } + else + { + modules[moduleName] = std::move(module); + return false; + } } void FrontendModuleResolver::clearModules() diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index d9d0d3b0..79b7f03e 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -163,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty) } else { - return addType(FreeType{scope, level}); + return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level}); } } diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index f830d126..0645e4e2 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -20,6 +20,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -211,7 +212,7 @@ struct NonStrictTypeChecker return *fst; else if (auto ftp = get(pack)) { - TypeId result = arena->addType(FreeType{ftp->scope}); + TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope}); TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); TypePack* resultPack = emplaceTypePack(asMutable(pack)); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 80398bf7..9aa6fb97 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -18,10 +18,10 @@ LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant) LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000) -LUAU_FASTFLAG(LuauSolverV2) - LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) -LUAU_FASTFLAGVARIABLE(LuauNormalizationTracksCyclicPairsThroughInhabitance) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization) +LUAU_FASTFLAGVARIABLE(LuauFixNormalizedIntersectionOfNegatedClass) namespace Luau { @@ -2284,9 +2284,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th else if (isSubclass(there, hereTy)) { TypeIds negations = std::move(hereNegations); + bool emptyIntersectWithNegation = false; for (auto nIt = negations.begin(); nIt != negations.end();) { + if (FFlag::LuauFixNormalizedIntersectionOfNegatedClass && isSubclass(there, *nIt)) + { + // Hitting this block means that the incoming class is a + // subclass of this type, _and_ one of its negations is a + // superclass of this type, e.g.: + // + // Dog & ~Animal + // + // Clearly this intersects to never, so we mark this class as + // being removed from the normalized class type. + emptyIntersectWithNegation = true; + break; + } + if (!isSubclass(*nIt, there)) { nIt = negations.erase(nIt); @@ -2299,7 +2314,8 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th it = heres.ordering.erase(it); heres.classes.erase(hereTy); - heres.pushPair(there, std::move(negations)); + if (!emptyIntersectWithNegation) + heres.pushPair(there, std::move(negations)); break; } // If the incoming class is a superclass of the current class, we don't @@ -2584,11 +2600,31 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { if (tprop.readTy.has_value()) { - // if the intersection of the read types of a property is uninhabited, the whole table is `never`. - // We've seen these table prop elements before and we're about to ask if their intersection - // is inhabited - if (FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance) + if (FFlag::LuauFixInfiniteRecursionInNormalization) { + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + + // If any property is going to get mapped to `never`, we can just call the entire table `never`. + // Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties. + // Prior versions of this code attempted to do this semantically using the normalization machinery, but this + // mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach + // will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach + // once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize + // when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that + // construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization` + if (get(ty)) + return {builtinTypes->neverType}; + + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } + else + { + // if the intersection of the read types of a property is uninhabited, the whole table is `never`. + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited + auto pair1 = std::pair{*hprop.readTy, *tprop.readTy}; auto pair2 = std::pair{*tprop.readTy, *hprop.readTy}; if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2)) @@ -2603,6 +2639,8 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there seenTablePropPairs.insert(pair2); } + // FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this + // fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here. Set seenSet{nullptr}; NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet); @@ -2616,34 +2654,6 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there hereSubThere &= (ty == hprop.readTy); thereSubHere &= (ty == tprop.readTy); } - else - { - - if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) - { - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - return {builtinTypes->neverType}; - } - else - { - seenSet.insert(*hprop.readTy); - seenSet.insert(*tprop.readTy); - } - - NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); - - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - - if (NormalizationResult::True != res) - return {builtinTypes->neverType}; - - TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; - prop.readTy = ty; - hereSubThere &= (ty == hprop.readTy); - thereSubHere &= (ty == tprop.readTy); - } } else { diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 6cb511eb..8a0483e6 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -31,16 +31,16 @@ struct TypeSimplifier int recursionDepth = 0; - TypeId mkNegation(TypeId ty); + TypeId mkNegation(TypeId ty) const; TypeId intersectFromParts(std::set parts); - TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnionWithType(TypeId left, TypeId right); TypeId intersectUnions(TypeId left, TypeId right); - TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + TypeId intersectNegatedUnion(TypeId left, TypeId right); - TypeId intersectTypeWithNegation(TypeId a, TypeId b); - TypeId intersectNegations(TypeId a, TypeId b); + TypeId intersectTypeWithNegation(TypeId left, TypeId right); + TypeId intersectNegations(TypeId left, TypeId right); TypeId intersectIntersectionWithType(TypeId left, TypeId right); @@ -48,8 +48,8 @@ struct TypeSimplifier // unions, intersections, or negations. std::optional basicIntersect(TypeId left, TypeId right); - TypeId intersect(TypeId ty, TypeId discriminant); - TypeId union_(TypeId ty, TypeId discriminant); + TypeId intersect(TypeId left, TypeId right); + TypeId union_(TypeId left, TypeId right); TypeId simplify(TypeId ty); TypeId simplify(TypeId ty, DenseHashSet& seen); @@ -573,7 +573,7 @@ Relation relate(TypeId left, TypeId right) return relate(left, right, seen); } -TypeId TypeSimplifier::mkNegation(TypeId ty) +TypeId TypeSimplifier::mkNegation(TypeId ty) const { TypeId result = nullptr; diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index 40132500..e4985a02 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -22,7 +22,6 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity) -LUAU_FASTFLAGVARIABLE(LuauRetrySubtypingWithoutHiddenPack) namespace Luau { @@ -1474,7 +1473,7 @@ SubtypingResult Subtyping::isCovariantWith( // If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it. // This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent. - if (FFlag::LuauRetrySubtypingWithoutHiddenPack && !result.isSubtype) + if (!result.isSubtype) { auto [arguments, tail] = flatten(superFunction->argTypes); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index a42882ed..eee41f24 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,9 @@ #include #include +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauAstTypeGroup) namespace { @@ -45,11 +48,13 @@ struct Writer virtual void space() = 0; virtual void maybeSpace(const Position& newPos, int reserve) = 0; virtual void write(std::string_view) = 0; + virtual void writeMultiline(std::string_view) = 0; virtual void identifier(std::string_view name) = 0; virtual void keyword(std::string_view) = 0; virtual void symbol(std::string_view) = 0; virtual void literal(std::string_view) = 0; virtual void string(std::string_view) = 0; + virtual void sourceString(std::string_view, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) = 0; }; struct StringWriter : Writer @@ -93,6 +98,32 @@ struct StringWriter : Writer lastChar = ' '; } + void writeMultiline(std::string_view s) override + { + if (s.empty()) + return; + + ss.append(s.data(), s.size()); + lastChar = s[s.size() - 1]; + + size_t index = 0; + size_t numLines = 0; + while (true) + { + auto newlinePos = s.find('\n', index); + if (newlinePos == std::string::npos) + break; + numLines++; + index = newlinePos + 1; + } + + pos.line += unsigned(numLines); + if (numLines > 0) + pos.column = unsigned(s.size()) - unsigned(index); + else + pos.column += unsigned(s.size()); + } + void write(std::string_view s) override { if (s.empty()) @@ -134,10 +165,17 @@ struct StringWriter : Writer void symbol(std::string_view s) override { - if (isDigit(lastChar) && s[0] == '.') - space(); + if (FFlag::LuauStoreCSTData) + { + write(s); + } + else + { + if (isDigit(lastChar) && s[0] == '.') + space(); - write(s); + write(s); + } } void literal(std::string_view s) override @@ -161,14 +199,54 @@ struct StringWriter : Writer write(escape(s)); write(quote); } + + void sourceString(std::string_view s, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) override + { + if (quoteStyle == CstExprConstantString::QuotedRaw) + { + auto blocks = std::string(blockDepth, '='); + write('['); + write(blocks); + write('['); + writeMultiline(s); + write(']'); + write(blocks); + write(']'); + } + else + { + LUAU_ASSERT(blockDepth == 0); + + char quote = '"'; + switch (quoteStyle) + { + case CstExprConstantString::QuotedDouble: + quote = '"'; + break; + case CstExprConstantString::QuotedSingle: + quote = '\''; + break; + case CstExprConstantString::QuotedInterp: + quote = '`'; + break; + default: + LUAU_ASSERT(!"Unhandled quote type"); + } + + write(quote); + writeMultiline(s); + write(quote); + } + } }; class CommaSeparatorInserter { public: - CommaSeparatorInserter(Writer& w) + explicit CommaSeparatorInserter(Writer& w, const Position* commaPosition = nullptr) : first(true) , writer(w) + , commaPosition(commaPosition) { } void operator()() @@ -176,17 +254,25 @@ public: if (first) first = !first; else + { + if (FFlag::LuauStoreCSTData && commaPosition) + { + writer.advance(*commaPosition); + commaPosition++; + } writer.symbol(","); + } } private: bool first; Writer& writer; + const Position* commaPosition; }; -struct Printer +struct Printer_DEPRECATED { - explicit Printer(Writer& writer) + explicit Printer_DEPRECATED(Writer& writer) : writer(writer) { } @@ -242,7 +328,8 @@ struct Printer } else if (typeCount == 1) { - if (unconditionallyParenthesize) + bool shouldParenthesize = unconditionallyParenthesize && (list.types.size == 0 || !list.types.data[0]->is()); + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) writer.symbol("("); // Only variadic tail @@ -255,7 +342,7 @@ struct Printer visualizeTypeAnnotation(*list.types.data[0]); } - if (unconditionallyParenthesize) + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) writer.symbol(")"); } else @@ -433,6 +520,7 @@ struct Printer visualize(*item.value); } + // Decrement endPos column so that we advance to before the closing `}` brace before writing, rather than after it Position endPos = expr.location.end; if (endPos.column > 0) --endPos.column; @@ -1164,6 +1252,12 @@ struct Printer writer.symbol(")"); } } + else if (const auto& a = typeAnnotation.as()) + { + writer.symbol("("); + visualizeTypeAnnotation(*a->type); + writer.symbol(")"); + } else if (const auto& a = typeAnnotation.as()) { writer.keyword(a->value ? "true" : "false"); @@ -1183,20 +1277,1349 @@ struct Printer } }; +struct Printer +{ + explicit Printer(Writer& writer, CstNodeMap cstNodeMap) + : writer(writer) + , cstNodeMap(std::move(cstNodeMap)) + { + } + + bool writeTypes = false; + Writer& writer; + CstNodeMap cstNodeMap; + + template + T* lookupCstNode(AstNode* astNode) + { + if (const auto cstNode = cstNodeMap[astNode]) + return cstNode->as(); + return nullptr; + } + + void visualize(const AstLocal& local) + { + advance(local.location.begin); + + writer.identifier(local.name.value); + if (writeTypes && local.annotation) + { + // TODO: handle spacing for type annotation + writer.symbol(":"); + visualizeTypeAnnotation(*local.annotation); + } + } + + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) + { + advance(annotation.location.begin); + if (const AstTypePackVariadic* variadicTp = annotation.as()) + { + if (!forVarArg) + writer.symbol("..."); + + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + LUAU_ASSERT(!forVarArg); + visualizeTypeList(explicitTp->typeList, true); + } + else + { + LUAU_ASSERT(!"Unknown TypePackAnnotation kind"); + } + } + + void visualizeTypeList(const AstTypeList& list, bool unconditionallyParenthesize) + { + size_t typeCount = list.types.size + (list.tailType != nullptr ? 1 : 0); + if (typeCount == 0) + { + writer.symbol("("); + writer.symbol(")"); + } + else if (typeCount == 1) + { + bool shouldParenthesize = unconditionallyParenthesize && (list.types.size == 0 || !list.types.data[0]->is()); + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) + writer.symbol("("); + + // Only variadic tail + if (list.types.size == 0) + { + visualizeTypePackAnnotation(*list.tailType, false); + } + else + { + visualizeTypeAnnotation(*list.types.data[0]); + } + + if (FFlag::LuauAstTypeGroup ? shouldParenthesize : unconditionallyParenthesize) + writer.symbol(")"); + } + else + { + writer.symbol("("); + + bool first = true; + for (const auto& el : list.types) + { + if (first) + first = false; + else + writer.symbol(","); + + visualizeTypeAnnotation(*el); + } + + if (list.tailType) + { + writer.symbol(","); + visualizeTypePackAnnotation(*list.tailType, false); + } + + writer.symbol(")"); + } + } + + bool isIntegerish(double d) + { + if (d <= std::numeric_limits::max() && d >= std::numeric_limits::min()) + return double(int(d)) == d && !(d == 0.0 && signbit(d)); + else + return false; + } + + void visualize(AstExpr& expr) + { + advance(expr.location.begin); + + if (const auto& a = expr.as()) + { + writer.symbol("("); + visualize(*a->expr); + advance(Position{a->location.end.line, a->location.end.column - 1}); + writer.symbol(")"); + } + else if (expr.is()) + { + writer.keyword("nil"); + } + else if (const auto& a = expr.as()) + { + if (a->value) + writer.keyword("true"); + else + writer.keyword("false"); + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.literal(std::string_view(cstNode->value.data, cstNode->value.size)); + } + else + { + if (isinf(a->value)) + { + if (a->value > 0) + writer.literal("1e500"); + else + writer.literal("-1e500"); + } + else if (isnan(a->value)) + writer.literal("0/0"); + else + { + if (isIntegerish(a->value)) + writer.literal(std::to_string(int(a->value))); + else + { + char buffer[100]; + size_t len = snprintf(buffer, sizeof(buffer), "%.17g", a->value); + writer.literal(std::string_view{buffer, len}); + } + } + } + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.sourceString( + std::string_view(cstNode->sourceString.data, cstNode->sourceString.size), cstNode->quoteStyle, cstNode->blockDepth + ); + } + else + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->local->name.value); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->name.value); + } + else if (expr.is()) + { + writer.symbol("..."); + } + else if (const auto& a = expr.as()) + { + visualize(*a->func); + + const auto cstNode = lookupCstNode(a); + + if (cstNode) + { + if (cstNode->openParens) + { + advance(*cstNode->openParens); + writer.symbol("("); + } + } + else + { + writer.symbol("("); + } + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& arg : a->args) + { + comma(); + visualize(*arg); + } + + if (cstNode) + { + if (cstNode->closeParens) + { + advance(*cstNode->closeParens); + writer.symbol(")"); + } + } + else + { + writer.symbol(")"); + } + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + advance(a->opPosition); + writer.symbol(std::string(1, a->op)); + advance(a->indexLocation.begin); + writer.write(a->index.value); + } + else if (const auto& a = expr.as()) + { + const auto cstNode = lookupCstNode(a); + visualize(*a->expr); + if (cstNode) + advance(cstNode->openBracketPosition); + writer.symbol("["); + visualize(*a->index); + if (cstNode) + advance(cstNode->closeBracketPosition); + writer.symbol("]"); + } + else if (const auto& a = expr.as()) + { + writer.keyword("function"); + visualizeFunctionBody(*a); + } + else if (const auto& a = expr.as()) + { + writer.symbol("{"); + + const CstExprTable::Item* cstItem = nullptr; + if (const auto cstNode = lookupCstNode(a)) + { + LUAU_ASSERT(cstNode->items.size == a->items.size); + cstItem = cstNode->items.begin(); + } + + bool first = true; + + for (const auto& item : a->items) + { + if (!cstItem) + { + if (first) + first = false; + else + writer.symbol(","); + } + + switch (item.kind) + { + case AstExprTable::Item::List: + break; + + case AstExprTable::Item::Record: + { + const auto& value = item.key->as()->value; + advance(item.key->location.begin); + writer.identifier(std::string_view(value.data, value.size)); + if (cstItem) + advance(*cstItem->equalsPosition); + else + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + case AstExprTable::Item::General: + { + if (cstItem) + advance(*cstItem->indexerOpenPosition); + writer.symbol("["); + visualize(*item.key); + if (cstItem) + advance(*cstItem->indexerClosePosition); + writer.symbol("]"); + if (cstItem) + advance(*cstItem->equalsPosition); + else + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + default: + LUAU_ASSERT(!"Unknown table item kind"); + } + + advance(item.value->location.begin); + visualize(*item.value); + + if (cstItem) + { + if (cstItem->separator) + { + LUAU_ASSERT(cstItem->separatorPosition); + advance(*cstItem->separatorPosition); + if (cstItem->separator == CstExprTable::Comma) + writer.symbol(","); + else if (cstItem->separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + cstItem++; + } + } + + Position endPos = expr.location.end; + if (endPos.column > 0) + --endPos.column; + + advance(endPos); + + writer.symbol("}"); + advance(expr.location.end); + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + + switch (a->op) + { + case AstExprUnary::Not: + writer.keyword("not"); + break; + case AstExprUnary::Minus: + writer.symbol("-"); + break; + case AstExprUnary::Len: + writer.symbol("#"); + break; + } + visualize(*a->expr); + } + else if (const auto& a = expr.as()) + { + visualize(*a->left); + + if (const auto cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + else + { + switch (a->op) + { + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::FloorDiv: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + writer.maybeSpace(a->right->location.begin, 2); + break; + case AstExprBinary::Concat: + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGe: + case AstExprBinary::Or: + writer.maybeSpace(a->right->location.begin, 3); + break; + case AstExprBinary::And: + writer.maybeSpace(a->right->location.begin, 4); + break; + default: + LUAU_ASSERT(!"Unknown Op"); + } + } + + writer.symbol(toString(a->op)); + + visualize(*a->right); + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + + if (writeTypes) + { + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualizeElseIfExpr(*a); + } + else if (const auto& a = expr.as()) + { + const auto* cstNode = lookupCstNode(a); + + writer.symbol("`"); + + size_t index = 0; + + for (const auto& string : a->strings) + { + if (cstNode) + { + if (index > 0) + { + advance(cstNode->stringPositions.data[index]); + writer.symbol("}"); + } + const AstArray sourceString = cstNode->sourceStrings.data[index]; + writer.writeMultiline(std::string_view(sourceString.data, sourceString.size)); + } + else + { + writer.write(escape(std::string_view(string.data, string.size), /* escapeForInterpString = */ true)); + } + + if (index < a->expressions.size) + { + writer.symbol("{"); + visualize(*a->expressions.data[index]); + if (!cstNode) + writer.symbol("}"); + } + + index++; + } + + writer.symbol("`"); + } + else if (const auto& a = expr.as()) + { + writer.symbol("(error-expr"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstExpr"); + } + } + + void writeEnd(const Location& loc) + { + Position endPos = loc.end; + if (endPos.column >= 3) + endPos.column -= 3; + advance(endPos); + writer.keyword("end"); + } + + void advance(const Position& newPos) + { + writer.advance(newPos); + } + + void visualize(AstStat& program) + { + advance(program.location.begin); + + if (const auto& block = program.as()) + { + writer.keyword("do"); + for (const auto& s : block->body) + visualize(*s); + if (const auto cstNode = lookupCstNode(block)) + { + advance(cstNode->endPosition); + writer.keyword("end"); + } + else + { + writer.advance(block->location.end); + writeEnd(program.location); + } + } + else if (const auto& a = program.as()) + { + writer.keyword("if"); + visualizeElseIf(*a); + } + else if (const auto& a = program.as()) + { + writer.keyword("while"); + visualize(*a->condition); + // TODO: what if 'hasDo = false'? + advance(a->doLocation.begin); + writer.keyword("do"); + visualizeBlock(*a->body); + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + writer.keyword("repeat"); + visualizeBlock(*a->body); + if (const auto cstNode = lookupCstNode(a)) + writer.advance(cstNode->untilPosition); + else if (a->condition->location.begin.column > 5) + writer.advance(Position{a->condition->location.begin.line, a->condition->location.begin.column - 6}); + writer.keyword("until"); + visualize(*a->condition); + } + else if (program.is()) + writer.keyword("break"); + else if (program.is()) + writer.keyword("continue"); + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("return"); + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& expr : a->list) + { + comma(); + visualize(*expr); + } + } + else if (const auto& a = program.as()) + { + visualize(*a->expr); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("local"); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& local : a->vars) + { + varComma(); + visualize(*local); + } + + if (a->equalsSignLocation) + { + advance(a->equalsSignLocation->begin); + writer.symbol("="); + } + + + CommaSeparatorInserter valueComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + for (const auto& value : a->values) + { + valueComma(); + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("for"); + + visualize(*a->var); + if (cstNode) + advance(cstNode->equalsPosition); + writer.symbol("="); + visualize(*a->from); + if (cstNode) + advance(cstNode->endCommaPosition); + writer.symbol(","); + visualize(*a->to); + if (a->step) + { + if (cstNode && cstNode->stepCommaPosition) + advance(*cstNode->stepCommaPosition); + writer.symbol(","); + visualize(*a->step); + } + advance(a->doLocation.begin); + writer.keyword("do"); + visualizeBlock(*a->body); + + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("for"); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& var : a->vars) + { + varComma(); + visualize(*var); + } + + advance(a->inLocation.begin); + writer.keyword("in"); + + CommaSeparatorInserter valComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + + for (const auto& val : a->values) + { + valComma(); + visualize(*val); + } + + advance(a->doLocation.begin); + writer.keyword("do"); + + visualizeBlock(*a->body); + + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& var : a->vars) + { + varComma(); + visualize(*var); + } + + if (cstNode) + advance(cstNode->equalsPosition); + else + writer.space(); + writer.symbol("="); + + CommaSeparatorInserter valueComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + for (const auto& value : a->values) + { + valueComma(); + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + visualize(*a->var); + + if (cstNode) + advance(cstNode->opPosition); + + switch (a->op) + { + case AstExprBinary::Add: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("+="); + break; + case AstExprBinary::Sub: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("-="); + break; + case AstExprBinary::Mul: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("*="); + break; + case AstExprBinary::Div: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("/="); + break; + case AstExprBinary::FloorDiv: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 3); + writer.symbol("//="); + break; + case AstExprBinary::Mod: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("%="); + break; + case AstExprBinary::Pow: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("^="); + break; + case AstExprBinary::Concat: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 3); + writer.symbol("..="); + break; + default: + LUAU_ASSERT(!"Unexpected compound assignment op"); + } + + visualize(*a->value); + } + else if (const auto& a = program.as()) + { + writer.keyword("function"); + visualize(*a->name); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("local"); + + if (cstNode) + advance(cstNode->functionKeywordPosition); + else + writer.space(); + + writer.keyword("function"); + advance(a->name->location.begin); + writer.identifier(a->name->name.value); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + if (writeTypes) + { + if (a->exported) + writer.keyword("export"); + + writer.keyword("type"); + writer.identifier(a->name.value); + if (a->generics.size > 0 || a->genericPacks.size > 0) + { + writer.symbol("<"); + CommaSeparatorInserter comma(writer); + + for (auto o : a->generics) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } + } + + for (auto o : a->genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + + writer.symbol(">"); + } + writer.maybeSpace(a->type->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*a->type); + } + } + else if (const auto& t = program.as()) + { + if (writeTypes) + { + writer.keyword("type function"); + writer.identifier(t->name.value); + visualizeFunctionBody(*t->body); + } + } + else if (const auto& a = program.as()) + { + writer.symbol("(error-stat"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + for (size_t i = 0; i < a->statements.size; i++) + { + writer.symbol(i == 0 && a->expressions.size == 0 ? ": " : ", "); + visualize(*a->statements.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstStat"); + } + + if (program.hasSemicolon) + { + if (FFlag::LuauStoreCSTData) + advance(Position{program.location.end.line, program.location.end.column - 1}); + writer.symbol(";"); + } + } + + void visualizeFunctionBody(AstExprFunction& func) + { + if (func.generics.size > 0 || func.genericPacks.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : func.generics) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + } + for (const auto& o : func.genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + writer.symbol("("); + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < func.args.size; ++i) + { + AstLocal* local = func.args.data[i]; + + comma(); + + advance(local->location.begin); + writer.identifier(local->name.value); + if (writeTypes && local->annotation) + { + writer.symbol(":"); + visualizeTypeAnnotation(*local->annotation); + } + } + + if (func.vararg) + { + comma(); + advance(func.varargLocation.begin); + writer.symbol("..."); + + if (func.varargAnnotation) + { + writer.symbol(":"); + visualizeTypePackAnnotation(*func.varargAnnotation, true); + } + } + + writer.symbol(")"); + + if (writeTypes && func.returnAnnotation) + { + writer.symbol(":"); + writer.space(); + + visualizeTypeList(*func.returnAnnotation, false); + } + + visualizeBlock(*func.body); + advance(func.body->location.end); + writer.keyword("end"); + } + + void visualizeBlock(AstStatBlock& block) + { + for (const auto& s : block.body) + visualize(*s); + writer.advance(block.location.end); + } + + void visualizeBlock(AstStat& stat) + { + if (AstStatBlock* block = stat.as()) + visualizeBlock(*block); + else + LUAU_ASSERT(!"visualizeBlock was expecting an AstStatBlock"); + } + + void visualizeElseIf(AstStatIf& elseif) + { + visualize(*elseif.condition); + if (elseif.thenLocation) + advance(elseif.thenLocation->begin); + writer.keyword("then"); + visualizeBlock(*elseif.thenbody); + + if (elseif.elsebody == nullptr) + { + advance(elseif.thenbody->location.end); + writer.keyword("end"); + } + else if (auto elseifelseif = elseif.elsebody->as()) + { + if (elseif.elseLocation) + advance(elseif.elseLocation->begin); + writer.keyword("elseif"); + visualizeElseIf(*elseifelseif); + } + else + { + if (elseif.elseLocation) + advance(elseif.elseLocation->begin); + writer.keyword("else"); + + visualizeBlock(*elseif.elsebody); + advance(elseif.elsebody->location.end); + writer.keyword("end"); + } + } + + void visualizeElseIfExpr(AstExprIfElse& elseif) + { + const auto cstNode = lookupCstNode(&elseif); + + visualize(*elseif.condition); + if (cstNode) + advance(cstNode->thenPosition); + writer.keyword("then"); + visualize(*elseif.trueExpr); + + if (elseif.falseExpr) + { + if (cstNode) + advance(cstNode->elsePosition); + if (auto elseifelseif = elseif.falseExpr->as(); elseifelseif && (!cstNode || cstNode->isElseIf)) + { + writer.keyword("elseif"); + visualizeElseIfExpr(*elseifelseif); + } + else + { + writer.keyword("else"); + visualize(*elseif.falseExpr); + } + } + } + + void visualizeTypeAnnotation(AstType& typeAnnotation) + { + advance(typeAnnotation.location.begin); + if (const auto& a = typeAnnotation.as()) + { + const auto cstNode = lookupCstNode(a); + + if (a->prefix) + { + writer.write(a->prefix->value); + if (cstNode) + advance(*cstNode->prefixPointPosition); + writer.symbol("."); + } + + advance(a->nameLocation.begin); + writer.write(a->name.value); + if (a->parameters.size > 0 || a->hasParameterList) + { + CommaSeparatorInserter comma(writer, cstNode ? cstNode->parametersCommaPositions.begin() : nullptr); + if (cstNode) + advance(cstNode->openParametersPosition); + writer.symbol("<"); + for (auto o : a->parameters) + { + comma(); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack, false); + } + if (cstNode) + advance(cstNode->closeParametersPosition); + writer.symbol(">"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->generics.size > 0 || a->genericPacks.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : a->generics) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + } + for (const auto& o : a->genericPacks) + { + comma(); + + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + { + visualizeTypeList(a->argTypes, true); + } + + writer.symbol("->"); + visualizeTypeList(a->returnTypes, true); + } + else if (const auto& a = typeAnnotation.as()) + { + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; + + writer.symbol("{"); + + const auto cstNode = lookupCstNode(a); + if (cstNode) + { + if (cstNode->isArray) + { + LUAU_ASSERT(a->props.size == 0 && indexType && indexType->name == "number"); + if (a->indexer->accessLocation) + { + LUAU_ASSERT(a->indexer->access != AstTableAccess::ReadWrite); + advance(a->indexer->accessLocation->begin); + writer.keyword(a->indexer->access == AstTableAccess::Read ? "read" : "write"); + } + visualizeTypeAnnotation(*a->indexer->resultType); + } + else + { + const AstTableProp* prop = a->props.begin(); + + for (size_t i = 0; i < cstNode->items.size; ++i) + { + CstTypeTable::Item item = cstNode->items.data[i]; + // we store indexer as part of items to preserve property ordering + if (item.kind == CstTypeTable::Item::Kind::Indexer) + { + LUAU_ASSERT(a->indexer); + + if (a->indexer->accessLocation) + { + LUAU_ASSERT(a->indexer->access != AstTableAccess::ReadWrite); + advance(a->indexer->accessLocation->begin); + writer.keyword(a->indexer->access == AstTableAccess::Read ? "read" : "write"); + } + + advance(item.indexerOpenPosition); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + advance(item.indexerClosePosition); + writer.symbol("]"); + advance(item.colonPosition); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + + if (item.separator) + { + LUAU_ASSERT(item.separatorPosition); + advance(*item.separatorPosition); + if (item.separator == CstExprTable::Comma) + writer.symbol(","); + else if (item.separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + } + else + { + if (prop->accessLocation) + { + LUAU_ASSERT(prop->access != AstTableAccess::ReadWrite); + advance(prop->accessLocation->begin); + writer.keyword(prop->access == AstTableAccess::Read ? "read" : "write"); + } + + if (item.kind == CstTypeTable::Item::Kind::StringProperty) + { + advance(item.indexerOpenPosition); + writer.symbol("["); + writer.sourceString( + std::string_view(item.stringInfo->sourceString.data, item.stringInfo->sourceString.size), + item.stringInfo->quoteStyle, + item.stringInfo->blockDepth + ); + advance(item.indexerClosePosition); + writer.symbol("]"); + } + else + { + advance(prop->location.begin); + writer.identifier(prop->name.value); + } + + advance(item.colonPosition); + writer.symbol(":"); + visualizeTypeAnnotation(*prop->type); + + if (item.separator) + { + LUAU_ASSERT(item.separatorPosition); + advance(*item.separatorPosition); + if (item.separator == CstExprTable::Comma) + writer.symbol(","); + else if (item.separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + + ++prop; + } + } + } + } + else + { + if (a->props.size == 0 && indexType && indexType->name == "number") + { + visualizeTypeAnnotation(*a->indexer->resultType); + } + else + { + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + } + } + + Position endPos = a->location.end; + if (endPos.column > 0) + --endPos.column; + advance(endPos); + + writer.symbol("}"); + } + else if (auto a = typeAnnotation.as()) + { + const auto cstNode = lookupCstNode(a); + writer.keyword("typeof"); + if (cstNode) + advance(cstNode->openPosition); + writer.symbol("("); + visualize(*a->expr); + if (cstNode) + advance(cstNode->closePosition); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->types.size == 2) + { + AstType* l = a->types.data[0]; + AstType* r = a->types.data[1]; + + auto lta = l->as(); + if (lta && lta->name == "nil") + std::swap(l, r); + + // it's still possible that we had a (T | U) or (T | nil) and not (nil | T) + auto rta = r->as(); + if (rta && rta->name == "nil") + { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + + writer.symbol("?"); + return; + } + } + + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("|"); + } + + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("&"); + } + + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + writer.symbol("("); + visualizeTypeAnnotation(*a->type); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + writer.keyword(a->value ? "true" : "false"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.sourceString( + std::string_view(cstNode->sourceString.data, cstNode->sourceString.size), cstNode->quoteStyle, cstNode->blockDepth + ); + } + else + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (typeAnnotation.is()) + { + writer.symbol("%error-type%"); + } + else + { + LUAU_ASSERT(!"Unknown AstType"); + } + } +}; + std::string toString(AstNode* node) { StringWriter writer; writer.pos = node->location.begin; - Printer printer(writer); - printer.writeTypes = true; + if (FFlag::LuauStoreCSTData) + { + Printer printer(writer, CstNodeMap{nullptr}); + printer.writeTypes = true; - if (auto statNode = node->asStat()) - printer.visualize(*statNode); - else if (auto exprNode = node->asExpr()) - printer.visualize(*exprNode); - else if (auto typeNode = node->asType()) - printer.visualizeTypeAnnotation(*typeNode); + if (auto statNode = node->asStat()) + printer.visualize(*statNode); + else if (auto exprNode = node->asExpr()) + printer.visualize(*exprNode); + else if (auto typeNode = node->asType()) + printer.visualizeTypeAnnotation(*typeNode); + } + else + { + Printer_DEPRECATED printer(writer); + printer.writeTypes = true; + + if (auto statNode = node->asStat()) + printer.visualize(*statNode); + else if (auto exprNode = node->asExpr()) + printer.visualize(*exprNode); + else if (auto typeNode = node->asType()) + printer.visualizeTypeAnnotation(*typeNode); + } return writer.str(); } @@ -1206,24 +2629,48 @@ void dump(AstNode* node) printf("%s\n", toString(node).c_str()); } -std::string transpile(AstStatBlock& block) +std::string transpile(AstStatBlock& block, const CstNodeMap& cstNodeMap) { StringWriter writer; - Printer(writer).visualizeBlock(block); + if (FFlag::LuauStoreCSTData) + { + Printer(writer, cstNodeMap).visualizeBlock(block); + } + else + { + Printer_DEPRECATED(writer).visualizeBlock(block); + } + return writer.str(); +} + +std::string transpileWithTypes(AstStatBlock& block, const CstNodeMap& cstNodeMap) +{ + StringWriter writer; + if (FFlag::LuauStoreCSTData) + { + Printer printer(writer, cstNodeMap); + printer.writeTypes = true; + printer.visualizeBlock(block); + } + else + { + Printer_DEPRECATED printer(writer); + printer.writeTypes = true; + printer.visualizeBlock(block); + } return writer.str(); } std::string transpileWithTypes(AstStatBlock& block) { - StringWriter writer; - Printer printer(writer); - printer.writeTypes = true; - printer.visualizeBlock(block); - return writer.str(); + // TODO: remove this interface? + return transpileWithTypes(block, CstNodeMap{nullptr}); } TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { + options.storeCstData = true; + auto allocator = Allocator{}; auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(source.data(), source.size(), names, allocator, options); @@ -1241,9 +2688,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options, bool wi return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; if (withTypes) - return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpileWithTypes(*parseResult.root, parseResult.cstNodeMap)}; - return TranspileResult{transpile(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root, parseResult.cstNodeMap)}; } } // namespace Luau diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index aee91ec3..bb08856c 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FreeType::FreeType(TypeLevel level) +// New constructors +FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound) : index(Unifiable::freshIndex()) , level(level) - , scope(nullptr) -{ -} - -FreeType::FreeType(Scope* scope) - : index(Unifiable::freshIndex()) - , level{} - , scope(scope) -{ -} - -FreeType::FreeType(Scope* scope, TypeLevel level) - : index(Unifiable::freshIndex()) - , level(level) - , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) { } @@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) { } +FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) +{ +} + +// Old constructors +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + GenericType::GenericType() : index(Unifiable::freshIndex()) , name("g" + std::to_string(index)) diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 617bd305..e4e9e293 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -3,6 +3,7 @@ #include "Luau/TypeArena.h" LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv) return allocated; } -TypeId TypeArena::freshType(TypeLevel level) +TypeId TypeArena::freshType(NotNull builtins, TypeLevel level) +{ + TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(NotNull builtins, Scope* scope) +{ + TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(NotNull builtins, Scope* scope, TypeLevel level) +{ + TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType_DEPRECATED(TypeLevel level) { TypeId allocated = types.allocate(FreeType{level}); @@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level) return allocated; } -TypeId TypeArena::freshType(Scope* scope) +TypeId TypeArena::freshType_DEPRECATED(Scope* scope) { TypeId allocated = types.allocate(FreeType{scope}); @@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope) return allocated; } -TypeId TypeArena::freshType(Scope* scope, TypeLevel level) +TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level) { TypeId allocated = types.allocate(FreeType{scope, level}); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index fb312da9..32c0f4db 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -31,6 +31,7 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -2105,7 +2106,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey) } else { - expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); + expectedRets = module->internalTypes.addTypePack( + {FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{}) + : module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})} + ); } TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); @@ -2357,7 +2361,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack) return *fst; else if (auto ftp = get(pack)) { - TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); + TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope) + : module->internalTypes.addType(FreeType{ftp->scope}); TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); TypePack* resultPack = emplaceTypePack(asMutable(pack)); @@ -2419,6 +2424,8 @@ void TypeChecker2::visit(AstType* ty) return visit(t); else if (auto t = ty->as()) return visit(t); + else if (auto t = ty->as()) + return visit(t->type); } void TypeChecker2::visit(AstTypeReference* ty) diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 9af1599d..a5f69460 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -47,7 +47,9 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) LUAU_FASTFLAG(DebugLuauEqSatSimplification) -LUAU_FASTFLAG(LuauRemoveNotAnyHack) +LUAU_FASTFLAGVARIABLE(LuauMetatableTypeFunctions) +LUAU_FASTFLAGVARIABLE(LuauClipNestedAndRecursiveUnion) +LUAU_FASTFLAGVARIABLE(LuauDoNotGeneralizeInTypeFunctions) namespace Luau { @@ -825,7 +827,7 @@ TypeFunctionReductionResult lenTypeFunction( return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); if (!maybeGeneralized) @@ -917,7 +919,7 @@ TypeFunctionReductionResult unmTypeFunction( return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); if (!maybeGeneralized) @@ -1030,7 +1032,7 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc AstStat* stmtArray[] = {&stmtReturn}; AstArray stmts{stmtArray, 1}; AstStatBlock exec{Location{}, stmts}; - ParseResult parseResult{&exec, 1}; + ParseResult parseResult{&exec, 1, {}, {}, {}, CstNodeMap{nullptr}}; BytecodeBuilder builder; try @@ -1160,7 +1162,7 @@ TypeFunctionReductionResult numericBinopTypeFunction( return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); @@ -1397,7 +1399,7 @@ TypeFunctionReductionResult concatTypeFunction( return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); @@ -1512,7 +1514,7 @@ TypeFunctionReductionResult andTypeFunction( return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); @@ -1567,7 +1569,7 @@ TypeFunctionReductionResult orTypeFunction( return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); @@ -1653,7 +1655,7 @@ static TypeFunctionReductionResult comparisonTypeFunction( rhsTy = follow(rhsTy); // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); @@ -1791,7 +1793,7 @@ TypeFunctionReductionResult eqTypeFunction( return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); @@ -1936,7 +1938,7 @@ TypeFunctionReductionResult refineTypeFunction( auto stepRefine = [&ctx](TypeId target, TypeId discriminant) -> std::pair> { std::vector toBlock; - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, target); std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminant); @@ -1988,16 +1990,8 @@ TypeFunctionReductionResult refineTypeFunction( */ if (auto nt = get(discriminant)) { - if (FFlag::LuauRemoveNotAnyHack) - { - if (get(follow(nt->ty))) - return {target, {}}; - } - else - { - if (get(follow(nt->ty))) - return {target, {}}; - } + if (get(follow(nt->ty))) + return {target, {}}; } // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the @@ -2070,7 +2064,7 @@ TypeFunctionReductionResult singletonTypeFunction( return {std::nullopt, Reduction::MaybeOk, {type}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); if (!maybeGeneralized) @@ -2091,6 +2085,43 @@ TypeFunctionReductionResult singletonTypeFunction( return {ctx->builtins->unknownType, Reduction::MaybeOk, {}, {}}; } +struct CollectUnionTypeOptions : TypeOnceVisitor +{ + NotNull ctx; + DenseHashSet options{nullptr}; + DenseHashSet blockingTypes{nullptr}; + + explicit CollectUnionTypeOptions(NotNull ctx) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , ctx(ctx) + { + } + + bool visit(TypeId ty) override + { + options.insert(ty); + if (isPending(ty, ctx->solver)) + blockingTypes.insert(ty); + return false; + } + + bool visit(TypePackId tp) override + { + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override + { + if (tfit.function->name != builtinTypeFunctions().unionFunc.name) + { + options.insert(ty); + blockingTypes.insert(ty); + return false; + } + return true; + } +}; + TypeFunctionReductionResult unionTypeFunction( TypeId instance, const std::vector& typeParams, @@ -2108,6 +2139,35 @@ TypeFunctionReductionResult unionTypeFunction( if (typeParams.size() == 1) return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; + if (FFlag::LuauClipNestedAndRecursiveUnion) + { + + CollectUnionTypeOptions collector{ctx}; + collector.traverse(instance); + + if (!collector.blockingTypes.empty()) + { + std::vector blockingTypes{collector.blockingTypes.begin(), collector.blockingTypes.end()}; + return {std::nullopt, Reduction::MaybeOk, std::move(blockingTypes), {}}; + } + + TypeId resultTy = ctx->builtins->neverType; + for (auto ty : collector.options) + { + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); + // This condition might fire if one of the arguments to this type + // function is a free type somewhere deep in a nested union or + // intersection type, even though we ran a pass above to capture + // some blocked types. + if (!result.blockedTypes.empty()) + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + return {resultTy, Reduction::MaybeOk, {}, {}}; + } + // we need to follow all of the type parameters. std::vector types; types.reserve(typeParams.size()); @@ -2179,14 +2239,11 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : typeParams) types.emplace_back(follow(ty)); - if (FFlag::LuauRemoveNotAnyHack) - { - // if we only have two parameters and one is `*no-refine*`, we're all done. - if (types.size() == 2 && get(types[1])) - return {types[0], Reduction::MaybeOk, {}, {}}; - else if (types.size() == 2 && get(types[0])) - return {types[1], Reduction::MaybeOk, {}, {}}; - } + // if we only have two parameters and one is `*no-refine*`, we're all done. + if (types.size() == 2 && get(types[1])) + return {types[0], Reduction::MaybeOk, {}, {}}; + else if (types.size() == 2 && get(types[0])) + return {types[1], Reduction::MaybeOk, {}, {}}; // check to see if the operand types are resolved enough, and wait to reduce if not // if any of them are `never`, the intersection will always be `never`, so we can reduce directly. @@ -2203,7 +2260,7 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : types) { // skip any `*no-refine*` types. - if (FFlag::LuauRemoveNotAnyHack && get(ty)) + if (get(ty)) continue; SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); @@ -2722,6 +2779,215 @@ TypeFunctionReductionResult rawgetTypeFunction( return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); } +TypeFunctionReductionResult setmetatableTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("setmetatable type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + TypeId targetTy = follow(typeParams.at(0)); + TypeId metatableTy = follow(typeParams.at(1)); + + std::shared_ptr targetNorm = ctx->normalizer->normalize(targetTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!targetNorm) + return {std::nullopt, Reduction::MaybeOk, {}, {}}; + + // cannot setmetatable on something without table parts. + if (!targetNorm->hasTables()) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // we're trying to reject any type that has not normalized to a table or a union/intersection of tables. + if (targetNorm->hasTops() || targetNorm->hasBooleans() || targetNorm->hasErrors() || targetNorm->hasNils() || + targetNorm->hasNumbers() || targetNorm->hasStrings() || targetNorm->hasThreads() || targetNorm->hasBuffers() || + targetNorm->hasFunctions() || targetNorm->hasTyvars() || targetNorm->hasClasses()) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // if the supposed metatable is not a table, we will fail to reduce. + if (!get(metatableTy) && !get(metatableTy)) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + if (targetNorm->tables.size() == 1) + { + TypeId table = *targetNorm->tables.begin(); + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, table, "__metatable", location); + + // if the `__metatable` metamethod is present, then the table is locked and we cannot `setmetatable` on it. + if (metatableMetamethod) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + TypeId withMetatable = ctx->arena->addType(MetatableType{table, metatableTy}); + + return {withMetatable, Reduction::MaybeOk, {}, {}}; + } + + TypeId result = ctx->builtins->neverType; + + for (auto componentTy : targetNorm->tables) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, componentTy, "__metatable", location); + + // if the `__metatable` metamethod is present, then the table is locked and we cannot `setmetatable` on it. + if (metatableMetamethod) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + TypeId withMetatable = ctx->arena->addType(MetatableType{componentTy, metatableTy}); + SimplifyResult simplified = simplifyUnion(ctx->builtins, ctx->arena, result, withMetatable); + + if (!simplified.blockedTypes.empty()) + { + std::vector blockedTypes{}; + blockedTypes.reserve(simplified.blockedTypes.size()); + for (auto ty : simplified.blockedTypes) + blockedTypes.push_back(ty); + return {std::nullopt, Reduction::MaybeOk, blockedTypes, {}}; + } + + result = simplified.result; + } + + return {result, Reduction::MaybeOk, {}, {}}; +} + +static TypeFunctionReductionResult getmetatableHelper( + TypeId targetTy, + const Location& location, + NotNull ctx +) +{ + targetTy = follow(targetTy); + + std::optional metatable = std::nullopt; + bool erroneous = true; + + if (auto table = get(targetTy)) + erroneous = false; + + if (auto mt = get(targetTy)) + { + metatable = mt->metatable; + erroneous = false; + } + + if (auto clazz = get(targetTy)) + { + metatable = clazz->metatable; + erroneous = false; + } + + if (auto primitive = get(targetTy)) + { + metatable = primitive->metatable; + erroneous = false; + } + + if (auto singleton = get(targetTy)) + { + if (get(singleton)) + { + auto primitiveString = get(ctx->builtins->stringType); + metatable = primitiveString->metatable; + } + erroneous = false; + } + + if (erroneous) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, targetTy, "__metatable", location); + + if (metatableMetamethod) + return {metatableMetamethod, Reduction::MaybeOk, {}, {}}; + + if (metatable) + return {metatable, Reduction::MaybeOk, {}, {}}; + + return {ctx->builtins->nilType, Reduction::MaybeOk, {}, {}}; +} + +TypeFunctionReductionResult getmetatableTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("getmetatable type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + TypeId targetTy = follow(typeParams.at(0)); + + if (isPending(targetTy, ctx->solver)) + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; + + if (auto ut = get(targetTy)) + { + std::vector options{}; + options.reserve(ut->options.size()); + + for (auto option : ut->options) + { + TypeFunctionReductionResult result = getmetatableHelper(option, location, ctx); + + if (!result.result) + return result; + + options.push_back(*result.result); + } + + return {ctx->arena->addType(UnionType{std::move(options)}), Reduction::MaybeOk, {}, {}}; + } + + if (auto it = get(targetTy)) + { + std::vector parts{}; + parts.reserve(it->parts.size()); + + for (auto part : it->parts) + { + TypeFunctionReductionResult result = getmetatableHelper(part, location, ctx); + + if (!result.result) + return result; + + parts.push_back(*result.result); + } + + return {ctx->arena->addType(IntersectionType{std::move(parts)}), Reduction::MaybeOk, {}, {}}; + } + + return getmetatableHelper(targetTy, location, ctx); +} + + BuiltinTypeFunctions::BuiltinTypeFunctions() : userFunc{"user", userDefinedTypeFunction} , notFunc{"not", notTypeFunction} @@ -2748,6 +3014,8 @@ BuiltinTypeFunctions::BuiltinTypeFunctions() , rawkeyofFunc{"rawkeyof", rawkeyofTypeFunction} , indexFunc{"index", indexTypeFunction} , rawgetFunc{"rawget", rawgetTypeFunction} + , setmetatableFunc{"setmetatable", setmetatableTypeFunction} + , getmetatableFunc{"getmetatable", getmetatableTypeFunction} { } @@ -2794,6 +3062,12 @@ void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull s scope->exportedTypeBindings[indexFunc.name] = mkBinaryTypeFunction(&indexFunc); scope->exportedTypeBindings[rawgetFunc.name] = mkBinaryTypeFunction(&rawgetFunc); + + if (FFlag::LuauMetatableTypeFunctions) + { + scope->exportedTypeBindings[setmetatableFunc.name] = mkBinaryTypeFunction(&setmetatableFunc); + scope->exportedTypeBindings[getmetatableFunc.name] = mkUnaryTypeFunction(&getmetatableFunc); + } } const BuiltinTypeFunctions& builtinTypeFunctions() diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp index de1302fc..c5c54477 100644 --- a/Analysis/src/TypeFunctionRuntime.cpp +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -15,7 +15,6 @@ LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixInner) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixNoReadWrite) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunGenerics) LUAU_FASTFLAGVARIABLE(LuauUserTypeFunCloneTail) @@ -683,10 +682,8 @@ static int readTableProp(lua_State* L) auto prop = tftt->props.at(tfsst->value); if (prop.readTy) allocTypeUserData(L, (*prop.readTy)->type); - else if (FFlag::LuauUserTypeFunFixNoReadWrite) - lua_pushnil(L); else - luaL_error(L, "type.readproperty: property %s is write-only, and therefore does not have a read type.", tfsst->value.c_str()); + lua_pushnil(L); return 1; } @@ -723,10 +720,8 @@ static int writeTableProp(lua_State* L) auto prop = tftt->props.at(tfsst->value); if (prop.writeTy) allocTypeUserData(L, (*prop.writeTy)->type); - else if (FFlag::LuauUserTypeFunFixNoReadWrite) - lua_pushnil(L); else - luaL_error(L, "type.writeproperty: property %s is read-only, and therefore does not have a write type.", tfsst->value.c_str()); + lua_pushnil(L); return 1; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4a243856..25d1f5c2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,6 +33,8 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -761,8 +763,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state struct Demoter : Substitution { - Demoter(TypeArena* arena) + TypeArena* arena = nullptr; + NotNull builtins; + Demoter(TypeArena* arena, NotNull builtins) : Substitution(TxnLog::empty(), arena) + , arena(arena) + , builtins(builtins) { } @@ -788,7 +794,8 @@ struct Demoter : Substitution { auto ftv = get(ty); LUAU_ASSERT(ftv); - return addType(FreeType{demotedLevel(ftv->level)}); + return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level)) + : addType(FreeType{demotedLevel(ftv->level)}); } TypePackId clean(TypePackId tp) override @@ -835,7 +842,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur } } - Demoter demoter{¤tModule->internalTypes}; + Demoter demoter{¤tModule->internalTypes, builtinTypes}; demoter.demote(expectedTypes); TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; @@ -4408,7 +4415,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } - Demoter demoter{¤tModule->internalTypes}; + Demoter demoter{¤tModule->internalTypes, builtinTypes}; demoter.demote(expectedTypes); return expectedTypes; @@ -5273,7 +5280,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.addType(Type(FreeType(level))); + return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level) + : currentModule->internalTypes.addType(Type(FreeType(level))); } TypeId TypeChecker::singletonType(bool value) @@ -5718,6 +5726,12 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& un = annotation.as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (un->types.size == 1) + return resolveType(scope, *un->types.data[0]); + } + std::vector types; for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); @@ -5726,12 +5740,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& un = annotation.as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (un->types.size == 1) + return resolveType(scope, *un->types.data[0]); + } + std::vector types; for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); return addType(IntersectionType{types}); } + else if (const auto& g = annotation.as()) + { + return resolveType(scope, *g->type); + } else if (const auto& tsb = annotation.as()) { return singletonType(tsb->value); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 6a562a3a..bb68503f 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,6 +5,7 @@ #include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeInfer.h" #include @@ -12,6 +13,7 @@ LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete); LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -323,7 +325,7 @@ TypePack extendTypePack( trackInteriorFreeType(ftp->scope, t); } else - t = arena.freshType(ftp->scope); + t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope); } newPack.head.push_back(t); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 5d71d5cb..926245ea 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering) LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -1648,7 +1649,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (FFlag::LuauSolverV2) return freshType(NotNull{types}, builtinTypes, scope); else - return types->freshType(scope, level); + return FFlag::LuauFreeTypesMustHaveBounds ? types->freshType(builtinTypes, scope, level) : types->freshType_DEPRECATED(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 736f24a2..d4764656 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -1204,6 +1204,18 @@ public: const AstArray value; }; +class AstTypeGroup : public AstType +{ +public: + LUAU_RTTI(AstTypeGroup) + + explicit AstTypeGroup(const Location& location, AstType* type); + + void visit(AstVisitor* visitor) override; + + AstType* type; +}; + class AstTypePack : public AstNode { public: @@ -1470,6 +1482,10 @@ public: { return visit(static_cast(node)); } + virtual bool visit(class AstTypeGroup* node) + { + return visit(static_cast(node)); + } virtual bool visit(class AstTypeError* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Cst.h b/Ast/include/Luau/Cst.h new file mode 100644 index 00000000..bea3df90 --- /dev/null +++ b/Ast/include/Luau/Cst.h @@ -0,0 +1,334 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" + +#include + +namespace Luau +{ + +extern int gCstRttiIndex; + +template +struct CstRtti +{ + static const int value; +}; + +template +const int CstRtti::value = ++gCstRttiIndex; + +#define LUAU_CST_RTTI(Class) \ + static int CstClassIndex() \ + { \ + return CstRtti::value; \ + } + +class CstNode +{ +public: + explicit CstNode(int classIndex) + : classIndex(classIndex) + { + } + + template + bool is() const + { + return classIndex == T::CstClassIndex(); + } + template + T* as() + { + return classIndex == T::CstClassIndex() ? static_cast(this) : nullptr; + } + template + const T* as() const + { + return classIndex == T::CstClassIndex() ? static_cast(this) : nullptr; + } + + const int classIndex; +}; + +class CstExprConstantNumber : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprConstantNumber) + + explicit CstExprConstantNumber(const AstArray& value); + + AstArray value; +}; + +class CstExprConstantString : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprConstantNumber) + + enum QuoteStyle + { + QuotedSingle, + QuotedDouble, + QuotedRaw, + QuotedInterp, + }; + + CstExprConstantString(AstArray sourceString, QuoteStyle quoteStyle, unsigned int blockDepth); + + AstArray sourceString; + QuoteStyle quoteStyle; + unsigned int blockDepth; +}; + +class CstExprCall : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprCall) + + CstExprCall(std::optional openParens, std::optional closeParens, AstArray commaPositions); + + std::optional openParens; + std::optional closeParens; + AstArray commaPositions; +}; + +class CstExprIndexExpr : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprIndexExpr) + + CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition); + + Position openBracketPosition; + Position closeBracketPosition; +}; + +class CstExprTable : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprTable) + + enum Separator + { + Comma, + Semicolon, + }; + + struct Item + { + std::optional indexerOpenPosition; // '[', only if Kind == General + std::optional indexerClosePosition; // ']', only if Kind == General + std::optional equalsPosition; // only if Kind != List + std::optional separator; // may be missing for last Item + std::optional separatorPosition; + }; + + explicit CstExprTable(const AstArray& items); + + AstArray items; +}; + +// TODO: Shared between unary and binary, should we split? +class CstExprOp : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprOp) + + explicit CstExprOp(Position opPosition); + + Position opPosition; +}; + +class CstExprIfElse : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprIfElse) + + CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf); + + Position thenPosition; + Position elsePosition; + bool isElseIf; +}; + +class CstExprInterpString : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprInterpString) + + explicit CstExprInterpString(AstArray> sourceStrings, AstArray stringPositions); + + AstArray> sourceStrings; + AstArray stringPositions; +}; + +class CstStatDo : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatDo) + + explicit CstStatDo(Position endPosition); + + Position endPosition; +}; + +class CstStatRepeat : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatRepeat) + + explicit CstStatRepeat(Position untilPosition); + + Position untilPosition; +}; + +class CstStatReturn : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatReturn) + + explicit CstStatReturn(AstArray commaPositions); + + AstArray commaPositions; +}; + +class CstStatLocal : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatLocal) + + CstStatLocal(AstArray varsCommaPositions, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + AstArray valuesCommaPositions; +}; + +class CstStatFor : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatFor) + + CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional stepCommaPosition); + + Position equalsPosition; + Position endCommaPosition; + std::optional stepCommaPosition; +}; + +class CstStatForIn : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatForIn) + + CstStatForIn(AstArray varsCommaPositions, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + AstArray valuesCommaPositions; +}; + +class CstStatAssign : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatAssign) + + CstStatAssign(AstArray varsCommaPositions, Position equalsPosition, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + Position equalsPosition; + AstArray valuesCommaPositions; +}; + +class CstStatCompoundAssign : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatCompoundAssign) + + explicit CstStatCompoundAssign(Position opPosition); + + Position opPosition; +}; + +class CstStatLocalFunction : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatLocalFunction) + + explicit CstStatLocalFunction(Position functionKeywordPosition); + + Position functionKeywordPosition; +}; + +class CstTypeReference : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeReference) + + CstTypeReference( + std::optional prefixPointPosition, + Position openParametersPosition, + AstArray parametersCommaPositions, + Position closeParametersPosition + ); + + std::optional prefixPointPosition; + Position openParametersPosition; + AstArray parametersCommaPositions; + Position closeParametersPosition; +}; + +class CstTypeTable : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeTable) + + struct Item + { + enum struct Kind + { + Indexer, + Property, + StringProperty, + }; + + Kind kind; + Position indexerOpenPosition; // '[', only if Kind != Property + Position indexerClosePosition; // ']' only if Kind != Property + Position colonPosition; + std::optional separator; // may be missing for last Item + std::optional separatorPosition; + + CstExprConstantString* stringInfo = nullptr; // only if Kind == StringProperty + }; + + CstTypeTable(AstArray items, bool isArray); + + AstArray items; + bool isArray = false; +}; + +class CstTypeTypeof : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeTypeof) + + CstTypeTypeof(Position openPosition, Position closePosition); + + Position openPosition; + Position closePosition; +}; + +class CstTypeSingletonString : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeSingletonString) + + CstTypeSingletonString(AstArray sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth); + + AstArray sourceString; + CstExprConstantString::QuoteStyle quoteStyle; + unsigned int blockDepth; +}; + +} // namespace Luau \ No newline at end of file diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 3d93cf75..20814860 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -87,6 +87,12 @@ struct Lexeme Reserved_END }; + enum struct QuoteStyle + { + Single, + Double, + }; + Type type; Location location; @@ -111,6 +117,8 @@ public: Lexeme(const Location& location, Type type, const char* name); unsigned int getLength() const; + unsigned int getBlockDepth() const; + QuoteStyle getQuoteStyle() const; std::string toString() const; }; diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index ff727a0b..ac8e9348 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -29,6 +29,8 @@ struct ParseOptions bool allowDeclarationSyntax = false; bool captureComments = false; std::optional parseFragment = std::nullopt; + bool storeCstData = false; + bool noErrorLimit = false; }; } // namespace Luau diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h index 9c0a9527..1ad9c5e9 100644 --- a/Ast/include/Luau/ParseResult.h +++ b/Ast/include/Luau/ParseResult.h @@ -10,6 +10,7 @@ namespace Luau { class AstStatBlock; +class CstNode; class ParseError : public std::exception { @@ -55,6 +56,8 @@ struct Comment Location location; }; +using CstNodeMap = DenseHashMap; + struct ParseResult { AstStatBlock* root; @@ -64,6 +67,8 @@ struct ParseResult std::vector errors; std::vector commentLocations; + + CstNodeMap cstNodeMap{nullptr}; }; static constexpr const char* kParseNameError = "%error-id%"; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index ce98f58e..584782ee 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -8,6 +8,7 @@ #include "Luau/StringUtils.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include "Luau/Cst.h" #include #include @@ -173,14 +174,18 @@ private: ); // explist ::= {exp `,'} exp - void parseExprList(TempVector& result); + void parseExprList(TempVector& result, TempVector* commaPositions = nullptr); // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. - std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); + std::tuple parseBindingList( + TempVector& result, + bool allowDot3 = false, + TempVector* commaPositions = nullptr + ); AstType* parseOptionalType(); @@ -201,7 +206,17 @@ private: std::optional parseOptionalReturnType(); std::pair parseReturnType(); - AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); + struct TableIndexerResult + { + AstTableIndexer* node; + Position indexerOpenPosition; + Position indexerClosePosition; + Position colonPosition; + }; + + TableIndexerResult parseTableIndexer(AstTableAccess access, std::optional accessLocation); + // Remove with FFlagLuauStoreCSTData + AstTableIndexer* parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional accessLocation); AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); AstType* parseFunctionTypeTail( @@ -259,6 +274,8 @@ private: // args ::= `(' [explist] `)' | tableconstructor | String AstExpr* parseFunctionArgs(AstExpr* func, bool self); + std::optional tableSeparator(); + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -280,9 +297,13 @@ private: std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' Type[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams( + Position* openingPosition = nullptr, + TempVector* commaPositions = nullptr, + Position* closingPosition = nullptr + ); - std::optional> parseCharArray(); + std::optional> parseCharArray(AstArray* originalString = nullptr); AstExpr* parseString(); AstExpr* parseNumber(); @@ -292,6 +313,9 @@ private: void restoreLocals(unsigned int offset); + /// Returns string quote style and block depth + std::pair extractStringDetails(); + // check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure bool expectAndConsume(char value, const char* context = nullptr); bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); @@ -435,6 +459,7 @@ private: std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; + std::vector> scratchString2; std::vector scratchExpr; std::vector scratchExprAux; std::vector scratchName; @@ -442,15 +467,20 @@ private: std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; + std::vector scratchCstTableTypeProps; std::vector scratchType; std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; + std::vector scratchCstItem; std::vector scratchArgName; std::vector scratchGenericTypes; std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; + std::vector scratchPosition; std::string scratchData; + + CstNodeMap cstNodeMap; }; } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 8e5befad..5fa63149 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -1091,6 +1091,18 @@ void AstTypeSingletonString::visit(AstVisitor* visitor) visitor->visit(this); } +AstTypeGroup::AstTypeGroup(const Location& location, AstType* type) + : AstType(ClassIndex(), location) + , type(type) +{ +} + +void AstTypeGroup::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + type->visit(visitor); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) diff --git a/Ast/src/Cst.cpp b/Ast/src/Cst.cpp new file mode 100644 index 00000000..e2faf6e7 --- /dev/null +++ b/Ast/src/Cst.cpp @@ -0,0 +1,169 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Ast.h" +#include "Luau/Cst.h" +#include "Luau/Common.h" + +namespace Luau +{ + +int gCstRttiIndex = 0; + +CstExprConstantNumber::CstExprConstantNumber(const AstArray& value) + : CstNode(CstClassIndex()) + , value(value) +{ +} + +CstExprConstantString::CstExprConstantString(AstArray sourceString, QuoteStyle quoteStyle, unsigned int blockDepth) + : CstNode(CstClassIndex()) + , sourceString(sourceString) + , quoteStyle(quoteStyle) + , blockDepth(blockDepth) +{ + LUAU_ASSERT(blockDepth == 0 || quoteStyle == QuoteStyle::QuotedRaw); +} + +CstExprCall::CstExprCall(std::optional openParens, std::optional closeParens, AstArray commaPositions) + : CstNode(CstClassIndex()) + , openParens(openParens) + , closeParens(closeParens) + , commaPositions(commaPositions) +{ +} + +CstExprIndexExpr::CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition) + : CstNode(CstClassIndex()) + , openBracketPosition(openBracketPosition) + , closeBracketPosition(closeBracketPosition) +{ +} + +CstExprTable::CstExprTable(const AstArray& items) + : CstNode(CstClassIndex()) + , items(items) +{ +} + +CstExprOp::CstExprOp(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstExprIfElse::CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf) + : CstNode(CstClassIndex()) + , thenPosition(thenPosition) + , elsePosition(elsePosition) + , isElseIf(isElseIf) +{ +} + +CstExprInterpString::CstExprInterpString(AstArray> sourceStrings, AstArray stringPositions) + : CstNode(CstClassIndex()) + , sourceStrings(sourceStrings) + , stringPositions(stringPositions) +{ +} + +CstStatDo::CstStatDo(Position endPosition) + : CstNode(CstClassIndex()) + , endPosition(endPosition) +{ +} + +CstStatRepeat::CstStatRepeat(Position untilPosition) + : CstNode(CstClassIndex()) + , untilPosition(untilPosition) +{ +} + +CstStatReturn::CstStatReturn(AstArray commaPositions) + : CstNode(CstClassIndex()) + , commaPositions(commaPositions) +{ +} + +CstStatLocal::CstStatLocal(AstArray varsCommaPositions, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatFor::CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional stepCommaPosition) + : CstNode(CstClassIndex()) + , equalsPosition(equalsPosition) + , endCommaPosition(endCommaPosition) + , stepCommaPosition(stepCommaPosition) +{ +} + +CstStatForIn::CstStatForIn(AstArray varsCommaPositions, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatAssign::CstStatAssign( + AstArray varsCommaPositions, + Position equalsPosition, + AstArray valuesCommaPositions +) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , equalsPosition(equalsPosition) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatCompoundAssign::CstStatCompoundAssign(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstStatLocalFunction::CstStatLocalFunction(Position functionKeywordPosition) + : CstNode(CstClassIndex()) + , functionKeywordPosition(functionKeywordPosition) +{ +} + +CstTypeReference::CstTypeReference( + std::optional prefixPointPosition, + Position openParametersPosition, + AstArray parametersCommaPositions, + Position closeParametersPosition +) + : CstNode(CstClassIndex()) + , prefixPointPosition(prefixPointPosition) + , openParametersPosition(openParametersPosition) + , parametersCommaPositions(parametersCommaPositions) + , closeParametersPosition(closeParametersPosition) +{ +} + +CstTypeTable::CstTypeTable(AstArray items, bool isArray) + : CstNode(CstClassIndex()) + , items(items) + , isArray(isArray) +{ +} + +CstTypeTypeof::CstTypeTypeof(Position openPosition, Position closePosition) + : CstNode(CstClassIndex()) + , openPosition(openPosition) + , closePosition(closePosition) +{ +} + +CstTypeSingletonString::CstTypeSingletonString(AstArray sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) + : CstNode(CstClassIndex()) + , sourceString(sourceString) + , quoteStyle(quoteStyle) + , blockDepth(blockDepth) +{ + LUAU_ASSERT(quoteStyle != CstExprConstantString::QuotedInterp); +} + +} // namespace Luau diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 9aea4968..557295e0 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -306,6 +306,38 @@ static char unescape(char ch) } } +unsigned int Lexeme::getBlockDepth() const +{ + LUAU_ASSERT(type == Lexeme::RawString || type == Lexeme::BlockComment); + + // If we have a well-formed string, we are guaranteed to see 2 `]` characters after the end of the string contents + LUAU_ASSERT(*(data + length) == ']'); + unsigned int depth = 0; + do + { + depth++; + } while (*(data + length + depth) != ']'); + + return depth - 1; +} + +Lexeme::QuoteStyle Lexeme::getQuoteStyle() const +{ + LUAU_ASSERT(type == Lexeme::QuotedString); + + // If we have a well-formed string, we are guaranteed to see a closing delimiter after the string + LUAU_ASSERT(data); + + char quote = *(data + length); + if (quote == '\'') + return Lexeme::QuoteStyle::Single; + else if (quote == '"') + return Lexeme::QuoteStyle::Double; + + LUAU_ASSERT(!"Unknown quote style"); + return Lexeme::QuoteStyle::Double; // unreachable, but required due to compiler warning +} + Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition) : buffer(buffer) , bufferSize(bufferSize) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 0d91f5d5..3fa0ccc9 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -24,6 +24,10 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition) LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAGVARIABLE(LuauStoreCSTData) +LUAU_FASTFLAGVARIABLE(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAGVARIABLE(LuauAstTypeGroup) +LUAU_FASTFLAGVARIABLE(ParserNoErrorLimit) namespace Luau { @@ -166,14 +170,14 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n AstStatBlock* root = p.parseChunk(); size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations), std::move(p.cstNodeMap)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, 0, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors, {}, std::move(p.cstNodeMap)}; } } @@ -184,6 +188,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc , recursionCounter(0) , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof)) , localMap(AstName()) + , cstNodeMap(nullptr) { Function top; top.vararg = true; @@ -496,6 +501,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; + Position untilPosition = lexer.current().location.begin; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); body->hasEnd = hasUntil; @@ -503,7 +509,17 @@ AstStat* Parser::parseRepeat() restoreLocals(localsBegin); - return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + if (FFlag::LuauStoreCSTData) + { + AstStatRepeat* node = allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(untilPosition); + return node; + } + else + { + return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + } } // do block end @@ -518,8 +534,12 @@ AstStat* Parser::parseDo() body->location.begin = start.begin; + Position endPosition = lexer.current().location.begin; body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[body] = allocator.alloc(endPosition); + return body; } @@ -559,18 +579,22 @@ AstStat* Parser::parseFor() if (lexer.current().type == '=') { + Position equalsPosition = lexer.current().location.begin; nextLexeme(); AstExpr* from = parseExpr(); + Position endCommaPosition = lexer.current().location.begin; expectAndConsume(',', "index range"); AstExpr* to = parseExpr(); + std::optional stepCommaPosition = std::nullopt; AstExpr* step = nullptr; if (lexer.current().type == ',') { + stepCommaPosition = lexer.current().location.begin; nextLexeme(); step = parseExpr(); @@ -596,25 +620,46 @@ AstStat* Parser::parseFor() bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + if (FFlag::LuauStoreCSTData) + { + AstStatFor* node = allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(equalsPosition, endCommaPosition, stepCommaPosition); + return node; + } + else + { + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + } } else { TempVector names(scratchBinding); + TempVector varsCommaPosition(scratchPosition); names.push_back(varname); if (lexer.current().type == ',') { - nextLexeme(); + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + varsCommaPosition.push_back(lexer.current().location.begin); + nextLexeme(); + parseBindingList(names, false, &varsCommaPosition); + } + else + { + nextLexeme(); - parseBindingList(names); + parseBindingList(names); + } } Location inLocation = lexer.current().location; bool hasIn = expectAndConsume(Lexeme::ReservedIn, "for loop"); TempVector values(scratchExpr); - parseExprList(values); + TempVector valuesCommaPositions(scratchPosition); + parseExprList(values, (FFlag::LuauStoreCSTData && options.storeCstData) ? &valuesCommaPositions : nullptr); Lexeme matchDo = lexer.current(); bool hasDo = expectAndConsume(Lexeme::ReservedDo, "for loop"); @@ -639,7 +684,18 @@ AstStat* Parser::parseFor() bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + if (FFlag::LuauStoreCSTData) + { + AstStatForIn* node = + allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(varsCommaPosition), copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + } } } @@ -835,6 +891,7 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Lexeme matchFunction = lexer.current(); nextLexeme(); + Position functionKeywordPosition = matchFunction.location.begin; // matchFunction is only used for diagnostics; to make it suitable for detecting missed indentation between // `local function` and `end`, we patch the token to begin at the column where `local` starts if (matchFunction.location.begin.line == start.begin.line) @@ -850,7 +907,17 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Location location{start.begin, body->location.end}; - return allocator.alloc(location, var, body); + if (FFlag::LuauStoreCSTData) + { + AstStatLocalFunction* node = allocator.alloc(location, var, body); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(functionKeywordPosition); + return node; + } + else + { + return allocator.alloc(location, var, body); + } } else { @@ -868,13 +935,18 @@ AstStat* Parser::parseLocal(const AstArray& attributes) matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); - parseBindingList(names); + TempVector varsCommaPositions(scratchPosition); + if (FFlag::LuauStoreCSTData && options.storeCstData) + parseBindingList(names, false, &varsCommaPositions); + else + parseBindingList(names); matchRecoveryStopOnToken['=']--; TempVector vars(scratchLocal); TempVector values(scratchExpr); + TempVector valuesCommaPositions(scratchPosition); std::optional equalsSignLocation; @@ -884,7 +956,7 @@ AstStat* Parser::parseLocal(const AstArray& attributes) nextLexeme(); - parseExprList(values); + parseExprList(values, (FFlag::LuauStoreCSTData && options.storeCstData) ? &valuesCommaPositions : nullptr); } for (size_t i = 0; i < names.size(); ++i) @@ -892,7 +964,17 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Location end = values.empty() ? lexer.previousLocation() : values.back()->location; - return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + if (FFlag::LuauStoreCSTData) + { + AstStatLocal* node = allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(varsCommaPositions), copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + } } } @@ -904,13 +986,24 @@ AstStat* Parser::parseReturn() nextLexeme(); TempVector list(scratchExpr); + TempVector commaPositions(scratchPosition); if (!blockFollow(lexer.current()) && lexer.current().type != ';') - parseExprList(list); + parseExprList(list, (FFlag::LuauStoreCSTData && options.storeCstData) ? &commaPositions : nullptr); Location end = list.empty() ? start : list.back()->location; - return allocator.alloc(Location(start, end), copy(list)); + if (FFlag::LuauStoreCSTData) + { + AstStatReturn* node = allocator.alloc(Location(start, end), copy(list)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(commaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(list)); + } } // type Name [`<' varlist `>'] `=' Type @@ -1151,14 +1244,21 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArraylocation, "Cannot have more than one class indexer"); } else { - indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + if (FFlag::LuauStoreCSTData) + indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt).node; + else + indexer = parseTableIndexer_DEPRECATED(AstTableAccess::ReadWrite, std::nullopt); } } else @@ -1220,10 +1320,13 @@ AstStat* Parser::parseAssignment(AstExpr* initial) initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); TempVector vars(scratchExpr); + TempVector varsCommaPositions(scratchPosition); vars.push_back(initial); while (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && options.storeCstData) + varsCommaPositions.push_back(lexer.current().location.begin); nextLexeme(); AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); @@ -1234,12 +1337,23 @@ AstStat* Parser::parseAssignment(AstExpr* initial) vars.push_back(expr); } + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "assignment"); TempVector values(scratchExprAux); - parseExprList(values); + TempVector valuesCommaPositions(scratchPosition); + parseExprList(values, FFlag::LuauStoreCSTData && options.storeCstData ? &valuesCommaPositions : nullptr); - return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + if (FFlag::LuauStoreCSTData) + { + AstStatAssign* node = allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + cstNodeMap[node] = allocator.alloc(copy(varsCommaPositions), equalsPosition, copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + } } // var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp @@ -1250,11 +1364,22 @@ AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); } + Position opPosition = lexer.current().location.begin; nextLexeme(); AstExpr* value = parseExpr(); - return allocator.alloc(Location(initial->location, value->location), op, initial, value); + if (FFlag::LuauStoreCSTData) + { + AstStatCompoundAssign* node = allocator.alloc(Location(initial->location, value->location), op, initial, value); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(opPosition); + return node; + } + else + { + return allocator.alloc(Location(initial->location, value->location), op, initial, value); + } } std::pair> Parser::prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args) @@ -1370,12 +1495,14 @@ std::pair Parser::parseFunctionBody( } // explist ::= {exp `,'} exp -void Parser::parseExprList(TempVector& result) +void Parser::parseExprList(TempVector& result, TempVector* commaPositions) { result.push_back(parseExpr()); while (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); if (lexer.current().type == ')') @@ -1402,7 +1529,7 @@ Parser::Binding Parser::parseBinding() } // bindinglist ::= (binding | `...') [`,' bindinglist] -std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3) +std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3, TempVector* commaPositions) { while (true) { @@ -1425,6 +1552,8 @@ std::tuple Parser::parseBindingList(TempVectorpush_back(lexer.current().location.begin); nextLexeme(); } @@ -1559,15 +1688,31 @@ std::pair Parser::parseReturnType() if (lexer.current().type != Lexeme::SkinnyArrow && resultNames.empty()) { // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. - if (result.size() == 1) + if (FFlag::LuauAstTypeGroup) { - AstType* returnType = parseTypeSuffix(result[0], innerBegin); + if (result.size() == 1 && varargAnnotation == nullptr) + { + AstType* returnType = parseTypeSuffix(allocator.alloc(location, result[0]), begin.location); - // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack - // type to successfully parse. We need the span of the whole annotation. - Position endPos = result.size() == 1 ? location.end : returnType->location.end; + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; - return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } + } + else + { + if (result.size() == 1) + { + AstType* returnType = parseTypeSuffix(result[0], innerBegin); + + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; + + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } } return {location, AstTypeList{copy(result), varargAnnotation}}; @@ -1578,8 +1723,61 @@ std::pair Parser::parseReturnType() return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } +std::pair Parser::extractStringDetails() +{ + LUAU_ASSERT(FFlag::LuauStoreCSTData); + + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + + switch (lexer.current().type) + { + case Lexeme::QuotedString: + style = lexer.current().getQuoteStyle() == Lexeme::QuoteStyle::Double ? CstExprConstantString::QuotedDouble + : CstExprConstantString::QuotedSingle; + break; + case Lexeme::InterpStringSimple: + style = CstExprConstantString::QuotedInterp; + break; + case Lexeme::RawString: + { + style = CstExprConstantString::QuotedRaw; + blockDepth = lexer.current().getBlockDepth(); + break; + } + default: + LUAU_ASSERT(false && "Invalid string type"); + } + + return {style, blockDepth}; +} + // TableIndexer ::= `[' Type `]' `:' Type -AstTableIndexer* Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) +Parser::TableIndexerResult Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) +{ + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + + AstType* index = parseType(); + + Position indexerClosePosition = lexer.current().location.begin; + expectMatchAndConsume(']', begin); + + Position colonPosition = lexer.current().location.begin; + expectAndConsume(':', "table field"); + + AstType* result = parseType(); + + return { + allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location), access, accessLocation}), + begin.location.begin, + indexerClosePosition, + colonPosition, + }; +} + +// Remove with FFlagLuauStoreCSTData +AstTableIndexer* Parser::parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional accessLocation) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -1604,6 +1802,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) incrementRecursionCounter("type annotation"); TempVector props(scratchTableTypeProps); + TempVector cstItems(scratchCstTableTypeProps); AstTableIndexer* indexer = nullptr; Location start = lexer.current().location; @@ -1611,6 +1810,8 @@ AstType* Parser::parseTableType(bool inDeclarationContext) MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table type"); + bool isArray = false; + while (lexer.current().type != '}') { AstTableAccess access = AstTableAccess::ReadWrite; @@ -1636,9 +1837,18 @@ AstType* Parser::parseTableType(bool inDeclarationContext) { const Lexeme begin = lexer.current(); nextLexeme(); // [ - std::optional> chars = parseCharArray(); + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + if (FFlag::LuauStoreCSTData && options.storeCstData) + std::tie(style, blockDepth) = extractStringDetails(); + + AstArray sourceString; + std::optional> chars = parseCharArray(options.storeCstData ? &sourceString : nullptr); + + Position indexerClosePosition = lexer.current().location.begin; expectMatchAndConsume(']', begin); + Position colonPosition = lexer.current().location.begin; expectAndConsume(':', "table field"); AstType* type = parseType(); @@ -1647,7 +1857,19 @@ AstType* Parser::parseTableType(bool inDeclarationContext) bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) + { props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::StringProperty, + begin.location.begin, + indexerClosePosition, + colonPosition, + tableSeparator(), + lexer.current().location.begin, + allocator.alloc(sourceString, style, blockDepth) + }); + } else report(begin.location, "String literal contains malformed escape sequence or \\0"); } @@ -1657,14 +1879,35 @@ AstType* Parser::parseTableType(bool inDeclarationContext) { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexer(access, accessLocation); + AstTableIndexer* badIndexer; + if (FFlag::LuauStoreCSTData) + badIndexer = parseTableIndexer(access, accessLocation).node; + else + badIndexer = parseTableIndexer_DEPRECATED(access, accessLocation); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexer(access, accessLocation); + if (FFlag::LuauStoreCSTData) + { + auto tableIndexerResult = parseTableIndexer(access, accessLocation); + indexer = tableIndexerResult.node; + if (options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::Indexer, + tableIndexerResult.indexerOpenPosition, + tableIndexerResult.indexerClosePosition, + tableIndexerResult.colonPosition, + tableSeparator(), + lexer.current().location.begin, + }); + } + else + { + indexer = parseTableIndexer_DEPRECATED(access, accessLocation); + } } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) @@ -1672,6 +1915,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} + isArray = true; AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber, std::nullopt, type->location); indexer = allocator.alloc(AstTableIndexer{index, type, type->location, access, accessLocation}); @@ -1684,11 +1928,21 @@ AstType* Parser::parseTableType(bool inDeclarationContext) if (!name) break; + Position colonPosition = lexer.current().location.begin; expectAndConsume(':', "table field"); AstType* type = parseType(inDeclarationContext); props.push_back(AstTableProp{name->name, name->location, type, access, accessLocation}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::Property, + Position{0, 0}, + Position{0, 0}, + colonPosition, + tableSeparator(), + lexer.current().location.begin + }); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -1707,7 +1961,17 @@ AstType* Parser::parseTableType(bool inDeclarationContext) if (!expectMatchAndConsume('}', matchBrace, /* searchForMissing = */ FFlag::LuauErrorRecoveryForTableTypes)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(props), indexer); + if (FFlag::LuauStoreCSTData) + { + AstTypeTable* node = allocator.alloc(Location(start, end), copy(props), indexer); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(cstItems), isArray); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(props), indexer); + } } // ReturnType ::= Type | `(' TypeList `)' @@ -1752,7 +2016,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else - return {params[0], {}}; + { + if (FFlag::LuauAstTypeGroup) + return {allocator.alloc(Location(parameterStart.location, params[0]->location), params[0]), {}}; + else + return {params[0], {}}; + } } if (!forceFunctionType && !returnTypeIntroducer && allowPack) @@ -1874,8 +2143,16 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile"); } - if (parts.size() == 1) - return parts[0]; + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (parts.size() == 1 && !isUnion && !isIntersection) + return parts[0]; + } + else + { + if (parts.size() == 1) + return parts[0]; + } if (isUnion && isIntersection) { @@ -1979,13 +2256,35 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { - if (std::optional> value = parseCharArray()) + if (FFlag::LuauStoreCSTData) { - AstArray svalue = *value; - return {allocator.alloc(start, svalue)}; + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + if (options.storeCstData) + std::tie(style, blockDepth) = extractStringDetails(); + + AstArray originalString; + if (std::optional> value = parseCharArray(options.storeCstData ? &originalString : nullptr)) + { + AstArray svalue = *value; + auto node = allocator.alloc(start, svalue); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(originalString, style, blockDepth); + return {node}; + } + else + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else - return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(start, svalue)}; + } + else + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; + } } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2001,17 +2300,30 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) else if (lexer.current().type == Lexeme::Name) { std::optional prefix; + std::optional prefixPointPosition; std::optional prefixLocation; Name name = parseName("type name"); if (lexer.current().type == '.') { - Position pointPosition = lexer.current().location.begin; - nextLexeme(); + if (FFlag::LuauStoreCSTData) + { + prefixPointPosition = lexer.current().location.begin; + nextLexeme(); - prefix = name.name; - prefixLocation = name.location; - name = parseIndexName("field name", pointPosition); + prefix = name.name; + prefixLocation = name.location; + name = parseIndexName("field name", *prefixPointPosition); + } + else + { + Position pointPosition = lexer.current().location.begin; + nextLexeme(); + + prefix = name.name; + prefixLocation = name.location; + name = parseIndexName("field name", pointPosition); + } } else if (lexer.current().type == Lexeme::Dot3) { @@ -2029,23 +2341,53 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) expectMatchAndConsume(')', typeofBegin); - return {allocator.alloc(Location(start, end), expr), {}}; + if (FFlag::LuauStoreCSTData) + { + AstTypeTypeof* node = allocator.alloc(Location(start, end), expr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(typeofBegin.location.begin, end.begin); + return {node, {}}; + } + else + { + return {allocator.alloc(Location(start, end), expr), {}}; + } } bool hasParameters = false; AstArray parameters{}; + Position parametersOpeningPosition{0, 0}; + TempVector parametersCommaPositions(scratchPosition); + Position parametersClosingPosition{0, 0}; if (lexer.current().type == '<') { hasParameters = true; - parameters = parseTypeParams(); + if (FFlag::LuauStoreCSTData && options.storeCstData) + parameters = parseTypeParams(¶metersOpeningPosition, ¶metersCommaPositions, ¶metersClosingPosition); + else + parameters = parseTypeParams(); } Location end = lexer.previousLocation(); - return { - allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {} - }; + if (FFlag::LuauStoreCSTData) + { + AstTypeReference* node = + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc( + prefixPointPosition, parametersOpeningPosition, copy(parametersCommaPositions), parametersClosingPosition + ); + return {node, {}}; + } + else + { + return { + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), + {} + }; + } } else if (lexer.current().type == '{') { @@ -2296,11 +2638,14 @@ AstExpr* Parser::parseExpr(unsigned int limit) if (uop) { + Position opPosition = lexer.current().location.begin; nextLexeme(); AstExpr* subexpr = parseExpr(unaryPriority); expr = allocator.alloc(Location(start, subexpr->location), *uop, subexpr); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(opPosition); } else { @@ -2315,12 +2660,15 @@ AstExpr* Parser::parseExpr(unsigned int limit) while (op && binaryPriority[*op].left > limit) { + Position opPosition = lexer.current().location.begin; nextLexeme(); // read sub-expression with higher priority AstExpr* next = parseExpr(binaryPriority[*op].right); expr = allocator.alloc(Location(start, next->location), *op, expr, next); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(opPosition); op = parseBinaryOp(lexer.current()); if (!op) @@ -2420,11 +2768,14 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* index = parseExpr(); + Position closeBracketPosition = lexer.current().location.begin; Position end = lexer.current().location.end; expectMatchAndConsume(']', matchBracket); expr = allocator.alloc(Location(start, end), expr, index); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(matchBracket.position, closeBracketPosition); } else if (lexer.current().type == ':') { @@ -2652,16 +3003,29 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) nextLexeme(); TempVector args(scratchExpr); + TempVector commaPositions(scratchPosition); if (lexer.current().type != ')') - parseExprList(args); + parseExprList(args, (FFlag::LuauStoreCSTData && options.storeCstData) ? &commaPositions : nullptr); Location end = lexer.current().location; Position argEnd = end.end; expectMatchAndConsume(')', matchParen); - return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc( + matchParen.position, lexer.previousLocation().begin, copy(commaPositions) + ); + return node; + } + else + { + return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + } } else if (lexer.current().type == '{') { @@ -2669,14 +3033,35 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) AstExpr* expr = parseTableConstructor(); Position argEnd = lexer.previousLocation().end; - return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = + allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); + return node; + } + else + { + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + } } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { Location argLocation = lexer.current().location; AstExpr* expr = parseString(); - return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); + return node; + } + else + { + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + } } else { @@ -2710,6 +3095,17 @@ LUAU_NOINLINE void Parser::reportAmbiguousCallError() ); } +std::optional Parser::tableSeparator() +{ + LUAU_ASSERT(FFlag::LuauStoreCSTData); + if (lexer.current().type == ',') + return CstExprTable::Comma; + else if (lexer.current().type == ';') + return CstExprTable::Semicolon; + else + return std::nullopt; +} + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -2717,6 +3113,7 @@ LUAU_NOINLINE void Parser::reportAmbiguousCallError() AstExpr* Parser::parseTableConstructor() { TempVector items(scratchItem); + TempVector cstItems(scratchCstItem); Location start = lexer.current().location; @@ -2730,23 +3127,29 @@ AstExpr* Parser::parseTableConstructor() if (lexer.current().type == '[') { + Position indexerOpenPosition = lexer.current().location.begin; MatchLexeme matchLocationBracket = lexer.current(); nextLexeme(); AstExpr* key = parseExpr(); + Position indexerClosePosition = lexer.current().location.begin; expectMatchAndConsume(']', matchLocationBracket); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "table field"); AstExpr* value = parseExpr(); items.push_back({AstExprTable::Item::General, key, value}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({indexerOpenPosition, indexerClosePosition, equalsPosition, tableSeparator(), lexer.current().location.begin}); } else if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == '=') { Name name = parseName("table field"); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "table field"); AstArray nameString; @@ -2760,12 +3163,16 @@ AstExpr* Parser::parseTableConstructor() func->debugname = name.name; items.push_back({AstExprTable::Item::Record, key, value}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({std::nullopt, std::nullopt, equalsPosition, tableSeparator(), lexer.current().location.begin}); } else { AstExpr* expr = parseExpr(); items.push_back({AstExprTable::Item::List, nullptr, expr}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({std::nullopt, std::nullopt, std::nullopt, tableSeparator(), lexer.current().location.begin}); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -2787,7 +3194,17 @@ AstExpr* Parser::parseTableConstructor() if (!expectMatchAndConsume('}', matchBrace)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(items)); + if (FFlag::LuauStoreCSTData) + { + AstExprTable* node = allocator.alloc(Location(start, end), copy(items)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(cstItems)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(items)); + } } AstExpr* Parser::parseIfElseExpr() @@ -2799,11 +3216,14 @@ AstExpr* Parser::parseIfElseExpr() AstExpr* condition = parseExpr(); + Position thenPosition = lexer.current().location.begin; bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if then else expression"); AstExpr* trueExpr = parseExpr(); AstExpr* falseExpr = nullptr; + Position elsePosition = lexer.current().location.begin; + bool isElseIf = false; if (lexer.current().type == Lexeme::ReservedElseif) { unsigned int oldRecursionCount = recursionCounter; @@ -2811,6 +3231,8 @@ AstExpr* Parser::parseIfElseExpr() hasElse = true; falseExpr = parseIfElseExpr(); recursionCounter = oldRecursionCount; + if (FFlag::LuauStoreCSTData) + isElseIf = true; } else { @@ -2820,7 +3242,17 @@ AstExpr* Parser::parseIfElseExpr() Location end = falseExpr->location; - return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + if (FFlag::LuauStoreCSTData) + { + AstExprIfElse* node = allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(thenPosition, elsePosition, isElseIf); + return node; + } + else + { + return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + } } // Name @@ -2970,13 +3402,15 @@ std::pair, AstArray> Parser::parseG return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams(Position* openingPosition, TempVector* commaPositions, Position* closingPosition) { TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); + if (FFlag::LuauStoreCSTData && openingPosition) + *openingPosition = begin.location.begin; nextLexeme(); while (true) @@ -3022,7 +3456,15 @@ AstArray Parser::parseTypeParams() // the next lexeme is one that follows a type // (&, |, ?), then assume that this was actually a // parenthesized type. - parameters.push_back({parseTypeSuffix(explicitTypePack->typeList.types.data[0], begin), {}}); + if (FFlag::LuauAstTypeGroup) + { + auto parenthesizedType = explicitTypePack->typeList.types.data[0]; + parameters.push_back( + {parseTypeSuffix(allocator.alloc(parenthesizedType->location, parenthesizedType), begin), {}} + ); + } + else + parameters.push_back({parseTypeSuffix(explicitTypePack->typeList.types.data[0], begin), {}}); } else { @@ -3064,18 +3506,24 @@ AstArray Parser::parseTypeParams() } if (lexer.current().type == ',') + { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); + } else break; } + if (FFlag::LuauStoreCSTData && closingPosition) + *closingPosition = lexer.current().location.begin; expectMatchAndConsume('>', begin); } return copy(parameters); } -std::optional> Parser::parseCharArray() +std::optional> Parser::parseCharArray(AstArray* originalString) { LUAU_ASSERT( lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || @@ -3083,6 +3531,11 @@ std::optional> Parser::parseCharArray() ); scratchData.assign(lexer.current().data, lexer.current().getLength()); + if (FFlag::LuauStoreCSTData) + { + if (originalString) + *originalString = copy(scratchData); + } if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -3120,15 +3573,38 @@ AstExpr* Parser::parseString() LUAU_ASSERT(false && "Invalid string type"); } - if (std::optional> value = parseCharArray()) - return allocator.alloc(location, *value, style); + if (FFlag::LuauStoreCSTData) + { + CstExprConstantString::QuoteStyle fullStyle; + unsigned int blockDepth; + if (options.storeCstData) + std::tie(fullStyle, blockDepth) = extractStringDetails(); + + AstArray originalString; + if (std::optional> value = parseCharArray(options.storeCstData ? &originalString : nullptr)) + { + AstExprConstantString* node = allocator.alloc(location, *value, style); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(originalString, fullStyle, blockDepth); + return node; + } + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } else - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + { + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value, style); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } } AstExpr* Parser::parseInterpString() { TempVector> strings(scratchString); + TempVector> sourceStrings(scratchString2); + TempVector stringPositions(scratchPosition); TempVector expressions(scratchExpr); Location startLocation = lexer.current().location; @@ -3146,6 +3622,12 @@ AstExpr* Parser::parseInterpString() scratchData.assign(currentLexeme.data, currentLexeme.getLength()); + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + sourceStrings.push_back(copy(scratchData)); + stringPositions.push_back(currentLexeme.location.begin); + } + if (!Lexer::fixupQuotedString(scratchData)) { nextLexeme(); @@ -3210,7 +3692,15 @@ AstExpr* Parser::parseInterpString() AstArray> stringsArray = copy(strings); AstArray expressionsArray = copy(expressions); - return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); + if (FFlag::LuauStoreCSTData) + { + AstExprInterpString* node = allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(sourceStrings), copy(stringPositions)); + return node; + } + else + return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); } AstExpr* Parser::parseNumber() @@ -3218,6 +3708,9 @@ AstExpr* Parser::parseNumber() Location start = lexer.current().location; scratchData.assign(lexer.current().data, lexer.current().getLength()); + AstArray sourceData; + if (FFlag::LuauStoreCSTData && options.storeCstData) + sourceData = copy(scratchData); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3232,7 +3725,17 @@ AstExpr* Parser::parseNumber() if (result == ConstantNumberParseResult::Malformed) return reportExprError(start, {}, "Malformed number"); - return allocator.alloc(start, value, result); + if (FFlag::LuauStoreCSTData) + { + AstExprConstantNumber* node = allocator.alloc(start, value, result); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(sourceData); + return node; + } + else + { + return allocator.alloc(start, value, result); + } } AstLocal* Parser::pushLocal(const Binding& binding) @@ -3509,7 +4012,7 @@ void Parser::report(const Location& location, const char* format, va_list args) parseErrors.emplace_back(location, message); - if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit)) + if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit) && (!FFlag::ParserNoErrorLimit || !options.noErrorLimit)) ParseError::raise(location, "Reached error limit (%d)", int(FInt::LuauParseErrorLimit)); } diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index 137c554c..3e980566 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -20,9 +20,21 @@ #elif defined(__linux__) || defined(__APPLE__) -// Defined in unwind.h which may not be easily discoverable on various platforms -extern "C" void __register_frame(const void*) __attribute__((weak)); -extern "C" void __deregister_frame(const void*) __attribute__((weak)); +// __register_frame and __deregister_frame are defined in libgcc or libc++ +// (depending on how it's built). We want to declare them as weak symbols +// so that if they're provided by a shared library, we'll use them, and if +// not, we'll disable some c++ exception handling support. However, if they're +// declared as weak and the definitions are linked in a static library +// that's not linked with whole-archive, then the symbols will technically be defined here, +// and the linker won't look for the strong ones in the library. +#ifndef LUAU_ENABLE_REGISTER_FRAME +#define REGISTER_FRAME_WEAK __attribute__((weak)) +#else +#define REGISTER_FRAME_WEAK +#endif + +extern "C" void __register_frame(const void*) REGISTER_FRAME_WEAK; +extern "C" void __deregister_frame(const void*) REGISTER_FRAME_WEAK; extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); #endif @@ -121,7 +133,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz #endif #elif defined(__linux__) || defined(__APPLE__) - if (!__register_frame) + if (!&__register_frame) return nullptr; visitFdeEntries(unwindData, __register_frame); @@ -150,7 +162,7 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) #endif #elif defined(__linux__) || defined(__APPLE__) - if (!__deregister_frame) + if (!&__deregister_frame) { CODEGEN_ASSERT(!"Cannot deregister unwind information"); return; diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index cb8dbcaf..68ae1e8c 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -14,7 +14,7 @@ inline bool isFlagExperimental(const char* flag) "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative "StudioReportLuauAny2", // takes telemetry data for usage of any types - "LuauTableCloneClonesType2", // requires fixes in lua-apps code, terrifyingly + "LuauTableCloneClonesType3", // requires fixes in lua-apps code, terrifyingly "LuauSolverV2", // makes sure we always have at least one entry nullptr, diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 9fe9798e..e251447b 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -121,6 +121,10 @@ static LuauBytecodeType getType( { return LBC_TYPE_ANY; } + else if (const AstTypeGroup* group = ty->as()) + { + return getType(group->type, generics, typeAliases, resolveAliases, hostVectorType, userdataTypes, bytecode); + } return LBC_TYPE_ANY; } diff --git a/Sources.cmake b/Sources.cmake index 306d1530..1c312cb9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -17,6 +17,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/Allocator.h Ast/include/Luau/Ast.h Ast/include/Luau/Confusables.h + Ast/include/Luau/Cst.h Ast/include/Luau/Lexer.h Ast/include/Luau/Location.h Ast/include/Luau/ParseOptions.h @@ -28,6 +29,7 @@ target_sources(Luau.Ast PRIVATE Ast/src/Allocator.cpp Ast/src/Ast.cpp Ast/src/Confusables.cpp + Ast/src/Cst.cpp Ast/src/Lexer.cpp Ast/src/Location.cpp Ast/src/Parser.cpp diff --git a/VM/src/lbuflib.cpp b/VM/src/lbuflib.cpp index 7edec3ad..17ca8b0b 100644 --- a/VM/src/lbuflib.cpp +++ b/VM/src/lbuflib.cpp @@ -270,7 +270,7 @@ static int buffer_readbits(lua_State* L) uint64_t data = 0; -#if LUAU_BIG_ENDIAN +#if defined(LUAU_BIG_ENDIAN) for (int i = int(endbyte) - 1; i >= int(startbyte); i--) data = (data << 8) + uint8_t(((char*)buf)[i]); #else @@ -306,7 +306,7 @@ static int buffer_writebits(lua_State* L) uint64_t data = 0; -#if LUAU_BIG_ENDIAN +#if defined(LUAU_BIG_ENDIAN) for (int i = int(endbyte) - 1; i >= int(startbyte); i--) data = data * 256 + uint8_t(((char*)buf)[i]); #else @@ -318,7 +318,7 @@ static int buffer_writebits(lua_State* L) data = (data & ~mask) | ((uint64_t(value) << subbyteoffset) & mask); -#if LUAU_BIG_ENDIAN +#if defined(LUAU_BIG_ENDIAN) for (int i = int(startbyte); i < int(endbyte); i++) { ((char*)buf)[i] = data & 0xff; diff --git a/tests/AnyTypeSummary.test.cpp b/tests/AnyTypeSummary.test.cpp index 7184ef76..12e02264 100644 --- a/tests/AnyTypeSummary.test.cpp +++ b/tests/AnyTypeSummary.test.cpp @@ -20,7 +20,8 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAG(LuauAlwaysFillInFunctionCallDiscriminantTypes) -LUAU_FASTFLAG(LuauRemoveNotAnyHack) +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauAstTypeGroup) struct ATSFixture : BuiltinsFixture @@ -74,7 +75,22 @@ export type t8 = t0 &((true | any)->('')) LUAU_ASSERT(module->ats.typeInfo.size() == 1); LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); - LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); + if (FFlag::LuauStoreCSTData && FFlag::LuauAstTypeGroup) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0& (( true | any)->(''))"); + } + else if (FFlag::LuauStoreCSTData) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &(( true | any)->(''))"); + } + else if (FFlag::LuauAstTypeGroup) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0& ((true | any)->(''))"); + } + else + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); + } } TEST_CASE_FIXTURE(ATSFixture, "typepacks") @@ -413,7 +429,6 @@ TEST_CASE_FIXTURE(ATSFixture, "CannotExtendTable") {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, {FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes, true}, - {FFlag::LuauRemoveNotAnyHack, true}, }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -507,7 +522,6 @@ TEST_CASE_FIXTURE(ATSFixture, "racing_collision_2") {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, {FFlag::LuauAlwaysFillInFunctionCallDiscriminantTypes, true}, - {FFlag::LuauRemoveNotAnyHack, true}, }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -577,13 +591,26 @@ initialize() LUAU_ASSERT(module->ats.typeInfo.size() == 11); LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); - LUAU_ASSERT( - module->ats.typeInfo[0].node == - "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " - "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " - "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " - "end\nend" - ); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ( + module->ats.typeInfo[0].node, + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart') then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants() do\n if descendant:IsA('BasePart') then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); + } + else + { + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); + } } TEST_CASE_FIXTURE(ATSFixture, "racing_spawning_1") diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index e6e67020..de30be04 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauAstTypeGroup) + struct JsonEncoderFixture { Allocator allocator; @@ -471,10 +473,17 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") { AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); - std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; - - CHECK(toJson(statement) == expected); + if (FFlag::LuauAstTypeGroup) + { + std::string_view expected = R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeGroup","location":"0,9 - 0,36","type":{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeGroup","location":"0,22 - 0,36","type":{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}}]}}},{"type":"AstTypeGroup","location":"0,40 - 0,55","type":{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}}]},"exported":false})"; + CHECK(toJson(statement) == expected); + } + else + { + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + CHECK(toJson(statement) == expected); + } } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_type_literal") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index dd02671f..6a8bca05 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -4329,4 +4329,21 @@ local var = data;@1 CHECK_EQ(ac.context, AutocompleteContext::Statement); } +TEST_CASE_FIXTURE(ACBuiltinsFixture, "require_tracing") +{ + fileResolver.source["Module/A"] = R"( +return { x = 0 } + )"; + + fileResolver.source["Module/B"] = R"( +local result = require(script.Parent.A) +local x = 1 + result. + )"; + + auto ac = autocomplete("Module/B", Position{2, 21}); + + CHECK(ac.entryMap.size() == 1); + CHECK(ac.entryMap.count("x")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index e88c77e7..af04fb77 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -8757,6 +8757,23 @@ end ); } +TEST_CASE("TypeGroup") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: (string), foo: nil) +end + +function myfunc2(test: (string | nil), foo: nil) +end +)"), + R"( +0: function(string, nil) +1: function(string?, nil) +)" + ); +} + TEST_CASE("BuiltinFoldMathK") { // we can fold math.pi at optimization level 2 diff --git a/tests/EqSatSimplification.test.cpp b/tests/EqSatSimplification.test.cpp index 0331d067..6fe2660f 100644 --- a/tests/EqSatSimplification.test.cpp +++ b/tests/EqSatSimplification.test.cpp @@ -3,6 +3,7 @@ #include "Fixture.h" #include "Luau/EqSatSimplification.h" +#include "Luau/Type.h" using namespace Luau; @@ -76,7 +77,7 @@ TEST_CASE_FIXTURE(ESFixture, "number | string") TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1") { - TypeId ty = arena->freshType(nullptr); + TypeId ty = arena->freshType(builtinTypes, nullptr); asMutable(ty)->ty.emplace(std::vector{builtinTypes->numberType, ty}); CHECK("number" == simplifyStr(ty)); @@ -450,7 +451,7 @@ TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)") TEST_CASE_FIXTURE(ESFixture, "free & string & number") { Scope scope{builtinTypes->anyTypePack}; - const TypeId freeTy = arena->addType(FreeType{&scope}); + const TypeId freeTy = arena->freshType(builtinTypes, &scope); CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}}))); } diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index 326a5c98..58bbc16a 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -9,6 +9,7 @@ #include "Luau/Common.h" #include "Luau/Frontend.h" #include "Luau/AutocompleteTypes.h" +#include "Luau/Type.h" #include #include @@ -27,6 +28,14 @@ LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); LUAU_FASTFLAG(LexerResumesFromPosition2) LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauCloneIncrementalModule) + +LUAU_FASTFLAG(LuauIncrementalAutocompleteBugfixes) +LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver) +LUAU_FASTFLAG(LuauMixedModeDefFinderTraversesTypeOf) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) + +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -46,15 +55,25 @@ static FrontendOptions getOptions() return options; } +static ModuleResolver& getModuleResolver(Luau::Frontend& frontend) +{ + return FFlag::LuauSolverV2 ? frontend.moduleResolver : frontend.moduleResolverForAutocomplete; +} + template struct FragmentAutocompleteFixtureImpl : BaseType { - ScopedFastFlag sffs[5] = { + static_assert(std::is_base_of_v, "BaseType must be a descendant of Fixture"); + + ScopedFastFlag sffs[8] = { {FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}, {FFlag::LuauStoreSolverTypeOnModule, true}, {FFlag::LuauSymbolEquality, true}, - {FFlag::LexerResumesFromPosition2, true} + {FFlag::LexerResumesFromPosition2, true}, + {FFlag::LuauReferenceAllocatorInNewSolver, true}, + {FFlag::LuauIncrementalAutocompleteBugfixes, true}, + {FFlag::LuauBetterReverseDependencyTracking, true}, }; FragmentAutocompleteFixtureImpl() @@ -128,6 +147,26 @@ struct FragmentAutocompleteFixtureImpl : BaseType result = autocompleteFragment(updated, cursorPos, fragmentEndPosition); assertions(result); } + + std::pair typecheckFragmentForModule( + const ModuleName& module, + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + return Luau::typecheckFragment(this->frontend, module, cursorPos, getOptions(), document, fragmentEndPosition); + } + + FragmentAutocompleteResult autocompleteFragmentForModule( + const ModuleName& module, + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + return Luau::fragmentAutocomplete(this->frontend, document, module, cursorPos, getOptions(), nullCallback, fragmentEndPosition); + } }; struct FragmentAutocompleteFixture : FragmentAutocompleteFixtureImpl @@ -162,10 +201,13 @@ end // 'for autocomplete'. loadDefinition(fakeVecDecl); loadDefinition(fakeVecDecl, /* For Autocomplete Module */ true); + + addGlobalBinding(frontend.globals, "game", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "game", Binding{builtinTypes->anyType}); } }; -//NOLINTBEGIN(bugprone-unchecked-optional-access) +// NOLINTBEGIN(bugprone-unchecked-optional-access) TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") @@ -574,7 +616,7 @@ t } TEST_SUITE_END(); -//NOLINTEND(bugprone-unchecked-optional-access) +// NOLINTEND(bugprone-unchecked-optional-access) TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); @@ -710,6 +752,57 @@ tbl. CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "typecheck_fragment_handles_stale_module") +{ + const std::string sourceName = "MainModule"; + fileResolver.source[sourceName] = "local x = 5"; + + CheckResult checkResult = frontend.check(sourceName, getOptions()); + LUAU_REQUIRE_NO_ERRORS(checkResult); + + auto [result, _] = typecheckFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK_EQ(result, FragmentTypeCheckStatus::Success); + + frontend.markDirty(sourceName); + frontend.parse(sourceName); + + CHECK_NE(frontend.getSourceModule(sourceName), nullptr); + + auto [result2, __] = typecheckFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK_EQ(result2, FragmentTypeCheckStatus::SkipAutocomplete); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "typecheck_fragment_handles_unusable_module") +{ + const std::string sourceA = "MainModule"; + fileResolver.source[sourceA] = R"( +local Modules = game:GetService('Gui').Modules +local B = require(Modules.B) +return { hello = B } +)"; + + const std::string sourceB = "game/Gui/Modules/B"; + fileResolver.source[sourceB] = R"(return {hello = "hello"})"; + + CheckResult result = frontend.check(sourceA, getOptions()); + CHECK(!frontend.isDirty(sourceA, getOptions().forAutocomplete)); + + std::weak_ptr weakModule = getModuleResolver(frontend).getModule(sourceB); + REQUIRE(!weakModule.expired()); + + frontend.markDirty(sourceB); + CHECK(frontend.isDirty(sourceA, getOptions().forAutocomplete)); + + frontend.check(sourceB, getOptions()); + CHECK(weakModule.expired()); + + auto [status, _] = typecheckFragmentForModule(sourceA, fileResolver.source[sourceA], Luau::Position(0, 0)); + CHECK_EQ(status, FragmentTypeCheckStatus::SkipAutocomplete); + + auto [status2, _2] = typecheckFragmentForModule(sourceB, fileResolver.source[sourceB], Luau::Position(3, 20)); + CHECK_EQ(status2, FragmentTypeCheckStatus::Success); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteTests"); @@ -1677,4 +1770,187 @@ type A = <>random non code text here ); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_stale_module") +{ + const std::string sourceName = "MainModule"; + fileResolver.source[sourceName] = "local x = 5"; + + frontend.check(sourceName, getOptions()); + frontend.markDirty(sourceName); + frontend.parse(sourceName); + + FragmentAutocompleteResult result = autocompleteFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK(result.acResults.entryMap.empty()); + CHECK_EQ(result.incrementalModule, nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "require_tracing") +{ + fileResolver.source["MainModule/A"] = R"( +return { x = 0 } + )"; + + fileResolver.source["MainModule"] = R"( +local result = require(script.A) +local x = 1 + result. + )"; + + autocompleteFragmentInBothSolvers( + fileResolver.source["MainModule"], + fileResolver.source["MainModule"], + Position{2, 21}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.size() == 1); + CHECK(result.acResults.entryMap.count("x")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "fragment_ac_must_traverse_typeof_and_not_ice") +{ + // This test ensures that we traverse typeof expressions for defs that are being referred to in the fragment + // In this case, we want to ensure we populate the incremental environment with the reference to `m` + // Without this, we would ice as we will refer to the local `m` before it's declaration + ScopedFastFlag sff{FFlag::LuauMixedModeDefFinderTraversesTypeOf, true}; + const std::string source = R"( +--!strict +local m = {} +-- and here +function m:m1() end +type nt = typeof(m) + +return m +)"; + const std::string updated = R"( +--!strict +local m = {} +-- and here +function m:m1() end +type nt = typeof(m) +l +return m +)"; + + autocompleteFragmentInBothSolvers(source, updated, Position{6, 2}, [](auto& _) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "generalization_crash_when_old_solver_freetypes_have_no_bounds_set") +{ + ScopedFastFlag sff{FFlag::LuauFreeTypesMustHaveBounds, true}; + const std::string source = R"( +local UserInputService = game:GetService("UserInputService"); + +local Camera = workspace.CurrentCamera; + +UserInputService.InputBegan:Connect(function(Input) + if (Input.KeyCode == Enum.KeyCode.One) then + local Up = Input.Foo + local Vector = -(Up:Unit) + end +end) +)"; + + const std::string dest = R"( +local UserInputService = game:GetService("UserInputService"); + +local Camera = workspace.CurrentCamera; + +UserInputService.InputBegan:Connect(function(Input) + if (Input.KeyCode == Enum.KeyCode.One) then + local Up = Input.Foo + local Vector = -(Up:Unit()) + end +end) +)"; + + autocompleteFragmentInBothSolvers(source, dest, Position{8, 36}, [](auto& _) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_ensures_memory_isolation") +{ + ScopedFastFlag sff{FFlag::LuauCloneIncrementalModule, true}; + ToStringOptions opt; + opt.exhaustive = true; + opt.exhaustive = true; + opt.functionTypeArguments = true; + opt.maxTableLength = 0; + opt.maxTypeLength = 0; + + auto checkAndExamine = [&](const std::string& src, const std::string& idName, const std::string& idString) + { + check(src, getOptions()); + auto id = getType(idName, true); + LUAU_ASSERT(id); + CHECK_EQ(Luau::toString(*id, opt), idString); + }; + + auto getTypeFromModule = [](ModulePtr module, const std::string& name) -> std::optional + { + if (!module->hasModuleScope()) + return std::nullopt; + return lookupName(module->getModuleScope(), name); + }; + + auto fragmentACAndCheck = [&](const std::string& updated, const Position& pos, const std::string& idName) + { + FragmentAutocompleteResult result = autocompleteFragment(updated, pos, std::nullopt); + auto fragId = getTypeFromModule(result.incrementalModule, idName); + LUAU_ASSERT(fragId); + + auto srcId = getType(idName, true); + LUAU_ASSERT(srcId); + + CHECK((*fragId)->owningArena != (*srcId)->owningArena); + CHECK(&(result.incrementalModule->internalTypes) == (*fragId)->owningArena); + }; + + const std::string source = R"(local module = {} +f +return module)"; + + const std::string updated1 = R"(local module = {} +function module.a +return module)"; + + const std::string updated2 = R"(local module = {} +function module.ab +return module)"; + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + fragmentACAndCheck(updated1, Position{1, 17}, "module"); + fragmentACAndCheck(updated2, Position{1, 18}, "module"); + } + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + fragmentACAndCheck(updated1, Position{1, 17}, "module"); + fragmentACAndCheck(updated2, Position{1, 18}, "module"); + } +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_shouldnt_crash_on_cross_module_mutation") +{ + ScopedFastFlag sff{FFlag::LuauCloneIncrementalModule, true}; + const std::string source = R"(local module = {} +function module. +return module +)"; + + const std::string updated = R"(local module = {} +function module.f +return module +)"; + + autocompleteFragmentInBothSolvers(source, updated, Position{1, 18}, [](FragmentAutocompleteResult& result) {}); +} + + TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index fce96e48..9491e28a 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/DenseHash.h" #include "Luau/Frontend.h" #include "Luau/RequireTracer.h" @@ -17,6 +18,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver); LUAU_FASTFLAG(LuauSelectivelyRetainDFGArena) +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking); namespace { @@ -1572,4 +1574,207 @@ return {x = a, y = b, z = c} CHECK(mod->keyArena.allocator.empty()); } +TEST_CASE_FIXTURE(FrontendFixture, "test_traverse_dependents") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + fileResolver.source["game/Gui/Modules/D"] = R"( + local Modules = game:GetService('Gui').Modules + local C = require(Modules.C) + return {d_value = C.c_value} + )"; + + frontend.check("game/Gui/Modules/D"); + + std::vector visited; + frontend.traverseDependents( + "game/Gui/Modules/B", + [&visited](SourceNode& node) + { + visited.push_back(node.name); + return true; + } + ); + + CHECK_EQ(std::vector{"game/Gui/Modules/B", "game/Gui/Modules/C", "game/Gui/Modules/D"}, visited); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_traverse_dependents_early_exit") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + + frontend.check("game/Gui/Modules/C"); + + std::vector visited; + frontend.traverseDependents( + "game/Gui/Modules/A", + [&visited](SourceNode& node) + { + visited.push_back(node.name); + return node.name != "game/Gui/Modules/B"; + } + ); + + CHECK_EQ(std::vector{"game/Gui/Modules/A", "game/Gui/Modules/B"}, visited); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_dependents_stored_on_node_as_graph_updates") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + auto updateSource = [&](const std::string& name, const std::string& source) + { + fileResolver.source[name] = source; + frontend.markDirty(name); + }; + + auto validateMatchesRequireLists = [&](const std::string& message) + { + DenseHashMap> dependents{{}}; + for (const auto& module : frontend.sourceNodes) + { + for (const auto& dep : module.second->requireSet) + dependents[dep].push_back(module.first); + } + + for (const auto& module : frontend.sourceNodes) + { + Set& dependentsForModule = module.second->dependents; + for (const auto& dep : dependents[module.first]) + CHECK_MESSAGE(1 == dependentsForModule.count(dep), "Mismatch in dependents for " << module.first << ": " << message); + } + }; + + auto validateSecondDependsOnFirst = [&](const std::string& from, const std::string& to, bool expected) + { + SourceNode& fromNode = *frontend.sourceNodes[from]; + CHECK_MESSAGE( + fromNode.dependents.count(to) == int(expected), + "Expected " << from << " to " << (expected ? std::string() : std::string("not ")) << "have a reverse dependency on " << to + ); + }; + + // C -> B -> A + { + updateSource("game/Gui/Modules/A", "return {hello=5, world=true}"); + updateSource("game/Gui/Modules/B", R"( + return require(game:GetService('Gui').Modules.A) + )"); + updateSource("game/Gui/Modules/C", R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B} + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Initial check"); + + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", true); + validateSecondDependsOnFirst("game/Gui/Modules/B", "game/Gui/Modules/C", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/A", false); + } + + // C -> B, A + { + updateSource("game/Gui/Modules/B", R"( + return 1 + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Removing dependency B->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", false); + } + + // C -> B -> A + { + updateSource("game/Gui/Modules/B", R"( + return require(game:GetService('Gui').Modules.A) + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Adding back B->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", true); + } + + // C -> B -> A, D -> (C,B,A) + { + updateSource("game/Gui/Modules/D", R"( + local C = require(game:GetService('Gui').Modules.C) + local B = require(game:GetService('Gui').Modules.B) + local A = require(game:GetService('Gui').Modules.A) + return {d_value = C.c_value} + )"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding D->C, D->B, D->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/B", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", true); + } + + // B -> A, C <-> D + { + updateSource("game/Gui/Modules/D", "return require(game:GetService('Gui').Modules.C)"); + updateSource("game/Gui/Modules/C", "return require(game:GetService('Gui').Modules.D)"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding cycle D->C, C->D"); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/D", "game/Gui/Modules/C", true); + } + + // B -> A, C -> D, D -> error + { + updateSource("game/Gui/Modules/D", "return require(game:GetService('Gui').Modules.C.)"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding error dependency D->C."); + validateSecondDependsOnFirst("game/Gui/Modules/D", "game/Gui/Modules/C", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", false); + } +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_invalid_dependency_tracking_per_module_resolver") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + ScopedFastFlag newSolver{FFlag::LuauSolverV2, false}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = "return require(game:GetService('Gui').Modules.A)"; + + FrontendOptions opts; + opts.forAutocomplete = false; + + frontend.check("game/Gui/Modules/B", opts); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/B", opts.forAutocomplete)); + CHECK(!frontend.allModuleDependenciesValid("game/Gui/Modules/B", !opts.forAutocomplete)); + + opts.forAutocomplete = true; + frontend.check("game/Gui/Modules/A", opts); + + CHECK(!frontend.allModuleDependenciesValid("game/Gui/Modules/B", opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/B", !opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/A", !opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/A", opts.forAutocomplete)); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 1388b900..b9e4eaf1 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -179,9 +179,9 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") { // t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i)) - TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); TypeId unionType = arena.addType(UnionType{{h, j}}); getMutable(h)->upperBound = i; getMutable(h)->lowerBound = builtinTypes.neverType; @@ -196,9 +196,9 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_crash") { // t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i)) - TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); TypeId intersectionType = arena.addType(IntersectionType{{h, j}}); getMutable(h)->upperBound = i; diff --git a/tests/Instantiation2.test.cpp b/tests/Instantiation2.test.cpp index fff98e60..fcd136fb 100644 --- a/tests/Instantiation2.test.cpp +++ b/tests/Instantiation2.test.cpp @@ -4,6 +4,7 @@ #include "Fixture.h" #include "ClassFixture.h" +#include "Luau/Type.h" #include "ScopedFlags.h" #include "doctest.h" @@ -29,7 +30,7 @@ TEST_CASE_FIXTURE(Fixture, "weird_cyclic_instantiation") DenseHashMap genericSubstitutions{nullptr}; DenseHashMap genericPackSubstitutions{nullptr}; - TypeId freeTy = arena.freshType(&scope); + TypeId freeTy = arena.freshType(builtinTypes, &scope); FreeType* ft = getMutable(freeTy); REQUIRE(ft); ft->lowerBound = idTy; diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 803a9e97..6133305d 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -248,4 +248,185 @@ TEST_CASE("string_interpolation_with_unicode_escape") CHECK_EQ(lexer.next().type, Lexeme::Eof); } +TEST_CASE("single_quoted_string") +{ + const std::string testInput = "'test'"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::QuotedString); + CHECK_EQ(lexeme.getQuoteStyle(), Lexeme::QuoteStyle::Single); +} + +TEST_CASE("double_quoted_string") +{ + const std::string testInput = R"("test")"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::QuotedString); + CHECK_EQ(lexeme.getQuoteStyle(), Lexeme::QuoteStyle::Double); +} + +TEST_CASE("lexer_determines_string_block_depth_0") +{ + const std::string testInput = "[[ test ]]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_1") +{ + const std::string testInput = R"([[ test + ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_2") +{ + const std::string testInput = R"([[ + test + ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_3") +{ + const std::string testInput = R"([[ + test ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_1") +{ + const std::string testInput = "[=[[%s]]=]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 1); +} + +TEST_CASE("lexer_determines_string_block_depth_2") +{ + const std::string testInput = "[==[ test ]==]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_1") +{ + const std::string testInput = R"([==[ test + ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_2") +{ + const std::string testInput = R"([==[ + test + ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_3") +{ + const std::string testInput = R"([==[ + + test ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + + +TEST_CASE("lexer_determines_comment_block_depth_0") +{ + const std::string testInput = "--[[ test ]]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_1") +{ + const std::string testInput = "--[=[ μέλλον ]=]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 1); +} + +TEST_CASE("lexer_determines_string_block_depth_2") +{ + const std::string testInput = "--[==[ test ]==]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 786d57b8..0e026edf 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -12,7 +12,7 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauNormalizationTracksCyclicPairsThroughInhabitance) +LUAU_FASTFLAG(LuauFixNormalizedIntersectionOfNegatedClass) using namespace Luau; namespace @@ -851,17 +851,17 @@ TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { + ScopedFastFlag _{FFlag::LuauFixNormalizedIntersectionOfNegatedClass, true}; createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); CHECK("((class & ~Child) | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); - CHECK("Child" == toString(normal("Not & Child"))); + CHECK("never" == toString(normal("Not & Child"))); CHECK("((class & ~Parent) | Child | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not | Child"))); CHECK("(boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); CHECK( "(Parent | Unrelated | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not & Not & Not>")) ); - CHECK("Child" == toString(normal("(Child | Unrelated) & Not"))); } @@ -962,7 +962,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "final_types_are_cached") TEST_CASE_FIXTURE(NormalizeFixture, "non_final_types_can_be_normalized_but_are_not_cached") { - TypeId a = arena.freshType(&globalScope); + TypeId a = arena.freshType(builtinTypes, &globalScope); std::shared_ptr na1 = normalizer.normalize(a); std::shared_ptr na2 = normalizer.normalize(a); @@ -1034,7 +1034,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "normalizer_should_be_able_to_detect_cyclic_t if (!FFlag::LuauSolverV2) return; ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; - ScopedFastFlag sff{FFlag::LuauNormalizationTracksCyclicPairsThroughInhabitance, true}; + CheckResult result = check(R"( --!strict diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 32476e86..2395efb6 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -21,6 +21,8 @@ LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) LUAU_FASTFLAG(LuauErrorRecoveryForClassNames) LUAU_FASTFLAG(LuauFixFunctionNameStartPosition) LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(LuauAstTypeGroup) namespace { @@ -369,7 +371,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_ AstTypeIntersection* returnAnnotation = annotation->returnTypes.types.data[0]->as(); REQUIRE(returnAnnotation != nullptr); - CHECK(returnAnnotation->types.data[0]->as()); + if (FFlag::LuauAstTypeGroup) + CHECK(returnAnnotation->types.data[0]->as()); + else + CHECK(returnAnnotation->types.data[0]->as()); CHECK(returnAnnotation->types.data[1]->as()); } @@ -2418,6 +2423,91 @@ TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); } +TEST_CASE_FIXTURE(Fixture, "leading_union_intersection_with_single_type_preserves_the_union_intersection_ast_node") +{ + ScopedFastFlag _{FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType, true}; + AstStatBlock* block = parse(R"( + type Foo = | string + type Bar = & number + )"); + + REQUIRE_EQ(2, block->body.size); + + const auto alias1 = block->body.data[0]->as(); + REQUIRE(alias1); + + const auto unionType = alias1->type->as(); + REQUIRE(unionType); + CHECK_EQ(1, unionType->types.size); + + const auto alias2 = block->body.data[1]->as(); + REQUIRE(alias2); + + const auto intersectionType = alias2->type->as(); + REQUIRE(intersectionType); + CHECK_EQ(1, intersectionType->types.size); +} + +TEST_CASE_FIXTURE(Fixture, "parse_simple_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup, true}; + + AstStatBlock* stat = parse(R"( + type Foo = (string) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto group1 = alias1->type->as(); + REQUIRE(group1); + CHECK(group1->type->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_nested_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup, true}; + + AstStatBlock* stat = parse(R"( + type Foo = ((string)) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto group1 = alias1->type->as(); + REQUIRE(group1); + + auto group2 = group1->type->as(); + REQUIRE(group2); + CHECK(group2->type->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_return_type_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup, true}; + + AstStatBlock* stat = parse(R"( + type Foo = () -> (string) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto funcType = alias1->type->as(); + REQUIRE(funcType); + + REQUIRE_EQ(1, funcType->returnTypes.types.size); + REQUIRE(!funcType->returnTypes.tailType); + CHECK(funcType->returnTypes.types.data[0]->is()); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -3688,7 +3778,14 @@ TEST_CASE_FIXTURE(Fixture, "grouped_function_type") auto unionTy = paramTy.type->as(); LUAU_ASSERT(unionTy); CHECK_EQ(unionTy->types.size, 2); - CHECK(unionTy->types.data[0]->is()); // () -> () + if (FFlag::LuauAstTypeGroup) + { + auto groupTy = unionTy->types.data[0]->as(); // (() -> ()) + REQUIRE(groupTy); + CHECK(groupTy->type->is()); // () -> () + } + else + CHECK(unionTy->types.data[0]->is()); // () -> () CHECK(unionTy->types.data[1]->is()); // nil } diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index c9eb3450..11027e6f 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -324,8 +324,7 @@ TEST_CASE_FIXTURE(Fixture, "free") { DOES_NOT_PASS_NEW_SOLVER_GUARD(); - Type type{TypeVariant{FreeType{TypeLevel{0, 0}}}}; - + Type type{TypeVariant{FreeType{TypeLevel{0, 0}, builtinTypes->neverType, builtinTypes->unknownType}}}; ToDotOptions opts; opts.showPointers = false; CHECK_EQ( diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index dc63be77..3505f96d 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,6 +12,11 @@ using namespace Luau; +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauAstTypeGroup); +LUAU_FASTFLAG(LexerFixInterpStringStart) + TEST_SUITE_BEGIN("TranspilerTests"); TEST_CASE("test_1") @@ -42,6 +47,37 @@ TEST_CASE("string_literals_containing_utf8") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("if_stmt_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( if This then Once() end)"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( if This then Once() end)"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( if This then Once() end)"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( if This then Once() end)"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( if This then Once() else Other() end)"; + CHECK_EQ(five, transpile(five).code); + + const std::string six = R"( if This then Once() else Other() end)"; + CHECK_EQ(six, transpile(six).code); + + const std::string seven = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(seven, transpile(seven).code); + + const std::string eight = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(eight, transpile(eight).code); + + const std::string nine = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(nine, transpile(nine).code); +} + TEST_CASE("elseif_chains_indent_sensibly") { const std::string code = R"( @@ -62,17 +98,31 @@ TEST_CASE("elseif_chains_indent_sensibly") TEST_CASE("strips_type_annotations") { const std::string code = R"( local s: string= 'hello there' )"; - const std::string expected = R"( local s = 'hello there' )"; - - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + const std::string expected = R"( local s = 'hello there' )"; + CHECK_EQ(expected, transpile(code).code); + } + else + { + const std::string expected = R"( local s = 'hello there' )"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("strips_type_assertion_expressions") { const std::string code = R"( local s= some_function() :: any+ something_else() :: number )"; - const std::string expected = R"( local s= some_function() + something_else() )"; - - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + const std::string expected = R"( local s= some_function() + something_else() )"; + CHECK_EQ(expected, transpile(code).code); + } + else + { + const std::string expected = R"( local s= some_function() + something_else() )"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("function_taking_ellipsis") @@ -97,24 +147,89 @@ TEST_CASE("for_loop") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("for_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( for index = 1 , 10 do call(index) end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( for index = 1, 10 , 3 do call(index) end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(five, transpile(five).code); +} + TEST_CASE("for_in_loop") { const std::string code = R"( for k, v in ipairs(x)do end )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("for_in_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( for k , v in ipairs(x) do end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( for k, v in next , t do end )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(five, transpile(five).code); +} + TEST_CASE("while_loop") { const std::string code = R"( while f(x)do print() end )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("while_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( while f(x) do print() end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( while f(x) do print() end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( while f(x) do print() end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( while f(x) do print() end )"; + CHECK_EQ(four, transpile(four).code); +} + TEST_CASE("repeat_until_loop") { const std::string code = R"( repeat print() until f(x) )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("repeat_until_loop_condition_on_new_line") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + repeat + print() + until + f(x) )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("lambda") { const std::string one = R"( local p=function(o, m, g) return 77 end )"; @@ -124,6 +239,43 @@ TEST_CASE("lambda") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("local_assignment") +{ + const std::string one = R"( local x = 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local x, y, z = 1, 2, 3 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( local x )"; + CHECK_EQ(three, transpile(three).code); +} + +TEST_CASE("local_assignment_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( local x = 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local x = 1 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( local x = 1 )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( local x , y = 1, 2 )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( local x, y = 1, 2 )"; + CHECK_EQ(five, transpile(five).code); + + const std::string six = R"( local x, y = 1 , 2 )"; + CHECK_EQ(six, transpile(six).code); + + const std::string seven = R"( local x, y = 1, 2 )"; + CHECK_EQ(seven, transpile(seven).code); +} + TEST_CASE("local_function") { const std::string one = R"( local function p(o, m, g) return 77 end )"; @@ -133,6 +285,16 @@ TEST_CASE("local_function") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("local_function_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( local function p(o, m, ...) end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local function p(o, m, ...) end )"; + CHECK_EQ(two, transpile(two).code); +} + TEST_CASE("function") { const std::string one = R"( function p(o, m, g) return 77 end )"; @@ -142,6 +304,19 @@ TEST_CASE("function") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("returns_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( return 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( return 1 , 2 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( return 1, 2 )"; + CHECK_EQ(three, transpile(three).code); +} + TEST_CASE("table_literals") { const std::string code = R"( local t={1, 2, 3, foo='bar', baz=99,[5.5]='five point five', 'end'} )"; @@ -184,6 +359,59 @@ TEST_CASE("table_literal_closing_brace_at_correct_position") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("table_literal_with_semicolon_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1; y = 2 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_trailing_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1, y = 2, } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_spaces_around_separator") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1 , y = 2 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_spaces_around_equals") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_multiline_with_indexers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { + ["my first value"] = "x"; + ["my second value"] = "y"; + } + )"; + + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("method_calls") { const std::string code = R"( foo.bar.baz:quux() )"; @@ -201,8 +429,15 @@ TEST_CASE("spaces_between_keywords_even_if_it_pushes_the_line_estimation_off") // Luau::Parser doesn't exactly preserve the string representation of numbers in Lua, so we can find ourselves // falling out of sync with the original code. We need to push keywords out so that there's at least one space between them. const std::string code = R"( if math.abs(raySlope) < .01 then return 0 end )"; - const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("numbers") @@ -214,8 +449,70 @@ TEST_CASE("numbers") TEST_CASE("infinity") { const std::string code = R"( local a = 1e500 local b = 1e400 )"; - const std::string expected = R"( local a = 1e500 local b = 1e500 )"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + const std::string expected = R"( local a = 1e500 local b = 1e500 )"; + CHECK_EQ(expected, transpile(code).code); + } +} + +TEST_CASE("numbers_with_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 123_456_789 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("hexadecimal_numbers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 0xFFFF )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("binary_numbers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 0b0101 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("single_quoted_strings") +{ + const std::string code = R"( local a = 'hello world' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("double_quoted_strings") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = "hello world" )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("simple_interp_string") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = `hello world` )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("raw_strings") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = [[ hello world ]] )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("raw_strings_with_blocks") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = [==[ hello world ]==] )"; + CHECK_EQ(code, transpile(code).code); } TEST_CASE("escaped_strings") @@ -230,6 +527,33 @@ TEST_CASE("escaped_strings_2") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_newline") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + print("foo \ + bar") + )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("escaped_strings_raw") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local x = [=[\v<((do|load)file|require)\s*\(?['"]\zs[^'"]+\ze['"]]=] )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("position_correctly_updated_when_writing_multiline_string") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + call([[ + testing + ]]) )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; @@ -242,6 +566,86 @@ TEST_CASE("binary_keywords") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("function_call_parentheses_no_args") +{ + const std::string code = R"( call() )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_one_arg") +{ + const std::string code = R"( call(arg) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args") +{ + const std::string code = R"( call(arg1, arg3, arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args_no_space") +{ + const std::string code = R"( call(arg1,arg3,arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args_space_before_commas") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call(arg1 ,arg3 ,arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_spaces_before_parentheses") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call () )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_spaces_within_parentheses") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call( ) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_double_quotes") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call "string" )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_single_quotes") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call 'string' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_no_space") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call'string' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_table_literal") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call { x = 1 } )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_table_literal_no_space") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call{x=1} )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -258,6 +662,19 @@ TEST_CASE("do_blocks") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("nested_do_block") +{ + const std::string code = R"( + do + do + local x = 1 + end + end + )"; + + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") { const std::string code = R"( @@ -267,6 +684,106 @@ TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") CHECK_EQ(code, transpile(code).code); } +TEST_CASE_FIXTURE(Fixture, "parentheses_multiline") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( +local test = ( + x +) + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( local test = 1; )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local test = 1 ; )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "do_block_ending_with_semicolon") +{ + std::string code = R"( + do + return; + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + if init then + x = string.sub(x, utf8.offset(x, init)); + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon_2") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + if (t < 1) then return c/2*t*t + b end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + for i,v in ... do + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "while_do_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + while true do + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "function_definition_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + function foo() + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE("roundtrip_types") { const std::string code = R"( @@ -337,9 +854,16 @@ TEST_CASE("a_table_key_can_be_the_empty_string") TEST_CASE("always_emit_a_space_after_local_keyword") { std::string code = "do local aZZZZ = Workspace.P1.Shape local bZZZZ = Enum.PartType.Cylinder end"; - std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE_FIXTURE(Fixture, "types_should_not_be_considered_cyclic_if_they_are_not_recursive") @@ -427,6 +951,80 @@ TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") CHECK_EQ(code, transpile(code).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else_multiple_conditions") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else_multiple_conditions_2") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + local x = if yes + then nil + else if no + then if this + then that + else other + else nil + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_spaces_between_else_if") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + return + if a then "was a" else + if b then "was b" else + if c then "was c" else + "was nothing!" + )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") { fileResolver.source["game/A"] = R"( @@ -442,6 +1040,34 @@ local a: Import.Type CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( local _: Foo.Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Foo .Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Foo. Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type <> )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type< > )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type< number> )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") { std::string code = R"( @@ -471,7 +1097,10 @@ TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") { std::string code = "local a: nil | (string & number)"; - CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); + if (FFlag::LuauAstTypeGroup) + CHECK_EQ("local a: (string & number)?", transpile(code, {}, true).code); + else + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); } TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") @@ -495,6 +1124,26 @@ TEST_CASE_FIXTURE(Fixture, "transpile_varargs") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "index_name_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "local _ = a.name"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = "local _ = a .name"; + CHECK_EQ(two, transpile(two, {}, true).code); + + std::string three = "local _ = a. name"; + CHECK_EQ(three, transpile(three, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "index_name_ends_with_digit") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "sparkles.Color = Color3.new()"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") { std::string code = "local a = {1, 2, 3} local b = a[2]"; @@ -502,6 +1151,22 @@ TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "index_expr_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "local _ = a[2]"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = "local _ = a [2]"; + CHECK_EQ(two, transpile(two, {}, true).code); + + std::string three = "local _ = a[ 2]"; + CHECK_EQ(three, transpile(three, {}, true).code); + + std::string four = "local _ = a[2 ]"; + CHECK_EQ(four, transpile(four, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_unary") { std::string code = R"( @@ -516,6 +1181,32 @@ local d = #e CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "unary_spaces_around_tokens") +{ + std::string code = R"( +local _ = -1 +local _ = - 1 +local _ = not true +local _ = not true +local _ = #e +local _ = # e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "binary_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( +local _ = 1+1 +local _ = 1 +1 +local _ = 1+ 1 + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") { std::string code = R"( @@ -546,6 +1237,16 @@ a ..= ' - result' CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "compound_assignment_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = R"( a += 1 )"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = R"( a += 1 )"; + CHECK_EQ(two, transpile(two, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") { std::string code = "a, b, c = 1, 2, 3"; @@ -553,6 +1254,31 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_assign_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "a = 1"; + CHECK_EQ(one, transpile(one).code); + + std::string two = "a = 1"; + CHECK_EQ(two, transpile(two).code); + + std::string three = "a = 1"; + CHECK_EQ(three, transpile(three).code); + + std::string four = "a , b = 1, 2"; + CHECK_EQ(four, transpile(four).code); + + std::string five = "a, b = 1, 2"; + CHECK_EQ(five, transpile(five).code); + + std::string six = "a, b = 1 , 2"; + CHECK_EQ(six, transpile(six).code); + + std::string seven = "a, b = 1, 2"; + CHECK_EQ(seven, transpile(seven).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") { std::string code = R"( @@ -682,13 +1408,58 @@ TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") TEST_CASE_FIXTURE(Fixture, "transpile_string_interp") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = `hello {name}` )"; CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline") +{ + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; + std::string code = R"( local _ = `hello { + name + }!` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_on_new_line") +{ + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; + std::string code = R"( + error( + `a {b} c` + ) + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline_escape") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( local _ = `hello \ + world!` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = ` bracket = \{, backtick = \` = {'ok'} ` )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -701,4 +1472,186 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_typeof_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type X = typeof(x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof(x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof (x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof( x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof(x ) )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_single_quoted_string_types") +{ + const std::string code = R"( type a = 'hello world' )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_double_quoted_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( type a = "hello world" )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_raw_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type a = [[ hello world ]] )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type a = [==[ hello world ]==] )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_escaped_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( type a = "\\b\\t\\n\\\\" )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_semicolon_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + type Foo = { + bar: number; + baz: number; + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_access_modifiers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + read bar: number, + write baz: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { read string } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { + read [string]: number, + read ["property"]: number + } )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_spaces_between_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar : number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number , } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [ string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string ]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string] : number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_original_indexer_style") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + [number]: string + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { { number } } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_indexer_location") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + [number]: string, + property: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { + property: number, + [number]: string, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { + property: number, + [number]: string, + property2: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_property_definition_style") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + ["$$typeof1"]: string, + ['$$typeof2']: string, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index 0d717629..096b3876 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -14,6 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) +LUAU_FASTFLAG(LuauMetatableTypeFunctions) struct TypeFunctionFixture : Fixture { @@ -1262,4 +1263,195 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_len_type_function_follow") )"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_assigns_correct_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __index: {} }> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __index: { } }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __index: { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_assigns_correct_metatable_2") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __index: {} }> + type FooBar = setmetatable<{}, Identity> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __index: { } }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __index: { } }"); + + TypeId foobar = requireTypeAlias("FooBar"); + const MetatableType* mt2 = get(foobar); + REQUIRE(mt2); + CHECK_EQ(toString(mt2->metatable, {true}), "{ @metatable { __index: { } }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_metatable_with_metatable_metamethod") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __metatable: "blocked" }> + type Bad = setmetatable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __metatable: \"blocked\" }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __metatable: \"blocked\" }"); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_invalid_set") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_nontable_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, string> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_type_function_returns_nil_if_no_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type TableWithNoMetatable = getmetatable<{}> + type NumberWithNoMetatable = getmetatable + type BooleanWithNoMetatable = getmetatable + type BooleanLiteralWithNoMetatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tableResult = requireTypeAlias("TableWithNoMetatable"); + CHECK_EQ(toString(tableResult), "nil"); + + auto numberResult = requireTypeAlias("NumberWithNoMetatable"); + CHECK_EQ(toString(numberResult), "nil"); + + auto booleanResult = requireTypeAlias("BooleanWithNoMetatable"); + CHECK_EQ(toString(booleanResult), "nil"); + + auto booleanLiteralResult = requireTypeAlias("BooleanLiteralWithNoMetatable"); + CHECK_EQ(toString(booleanLiteralResult), "nil"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + local metatable = { __index = { w = 4 } } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type Metatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), "{ __index: { w: number } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable_for_union") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, {}> + type Metatable = getmetatable + type IntersectMetatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const PrimitiveType* stringType = get(builtinTypes->stringType); + REQUIRE(stringType->metatable); + + TypeArena arena = TypeArena{}; + + std::string expected1 = toString(arena.addType(UnionType{{*stringType->metatable, builtinTypes->emptyTableType}}), {true}); + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), expected1); + + std::string expected2 = toString(arena.addType(IntersectionType{{*stringType->metatable, builtinTypes->emptyTableType}}), {true}); + CHECK_EQ(toString(requireTypeAlias("IntersectMetatable"), {true}), expected2); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable_for_string") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Metatable = getmetatable + type Metatable2 = getmetatable<"foo"> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const PrimitiveType* stringType = get(builtinTypes->stringType); + REQUIRE(stringType->metatable); + + std::string expected = toString(*stringType->metatable, {true}); + + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), expected); + CHECK_EQ(toString(requireTypeAlias("Metatable2"), {true}), expected); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_respects_metatable_metamethod") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + local metatable = { __metatable = "Test" } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type Metatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAlias("Metatable")), "string"); +} + TEST_SUITE_END(); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 98c8d2a8..bdde63f5 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -8,7 +8,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) LUAU_FASTFLAG(LuauUserTypeFunFixInner) LUAU_FASTFLAG(LuauUserTypeFunGenerics) LUAU_FASTFLAG(LuauUserTypeFunCloneTail) @@ -667,7 +666,6 @@ TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag udtfRwFix{FFlag::LuauUserTypeFunFixNoReadWrite, true}; CheckResult result = check(R"( type function getclass(arg) diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index a9e5951a..3972fd6b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/AstQuery.h" @@ -9,6 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFixInfiniteRecursionInNormalization) TEST_SUITE_BEGIN("TypeAliases"); @@ -1178,4 +1180,33 @@ TEST_CASE_FIXTURE(Fixture, "bound_type_in_alias_segfault") )")); } +TEST_CASE_FIXTURE(BuiltinsFixture, "gh1632_no_infinite_recursion_in_normalization") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauFixInfiniteRecursionInNormalization, true}, + }; + + CheckResult result = check(R"( + type Node = { + value: T, + next: Node?, + -- remove `prev`, solves issue + prev: Node?, + }; + + type List = { + head: Node? + } + + local function IsFront(list: List, nodeB: Node) + -- remove if statement below, solves issue + if (list.head == nodeB) then + end + end + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1e67739d..96443aeb 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -10,10 +10,9 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauTypestateBuiltins2) -LUAU_FASTFLAG(LuauStringFormatArityFix) -LUAU_FASTFLAG(LuauTableCloneClonesType2) +LUAU_FASTFLAG(LuauTableCloneClonesType3) LUAU_FASTFLAG(LuauStringFormatErrorSuppression) +LUAU_FASTFLAG(LuauFreezeIgnorePersistent) TEST_SUITE_BEGIN("BuiltinTests"); @@ -809,8 +808,6 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_trivial_arity") { - ScopedFastFlag sff{FFlag::LuauStringFormatArityFix, true}; - CheckResult result = check(R"( string.format() )"); @@ -1134,15 +1131,13 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) CHECK("Key 'b' not found in table '{ read a: number }'" == toString(result.errors[0])); - else if (FFlag::LuauSolverV2) - CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0])); else CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) { CHECK_EQ("{ read a: number }", toString(requireTypeAtPosition({15, 19}))); CHECK_EQ("{ read b: string }", toString(requireTypeAtPosition({16, 19}))); @@ -1178,7 +1173,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_does_not_retroactively_block_mu LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) { CHECK_EQ("{ a: number, q: string } | { read a: number, read q: string }", toString(requireType("t1"), {/*exhaustive */ true})); // before the assignment, it's `t1` @@ -1208,8 +1203,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_no_generic_table") end )"); - if (FFlag::LuauTypestateBuiltins2) - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_on_metatable") @@ -1236,13 +1230,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_no_args") table.freeze() )"); - // this does not error in the new solver without the typestate builtins functionality. - if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins2) - { - LUAU_REQUIRE_NO_ERRORS(result); - return; - } - LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK(get(result.errors[0])); @@ -1255,25 +1242,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_non_tables") table.freeze(42) )"); - // this does not error in the new solver without the typestate builtins functionality. - if (FFlag::LuauSolverV2 && !FFlag::LuauTypestateBuiltins2) - { - LUAU_REQUIRE_NO_ERRORS(result); - return; - } - LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) CHECK_EQ(toString(tm->wantedType), "table"); else CHECK_EQ(toString(tm->wantedType), "{- -}"); CHECK_EQ(toString(tm->givenType), "number"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_persistent_skip") +{ + ScopedFastFlag luauFreezeIgnorePersistent{FFlag::LuauFreezeIgnorePersistent, true}; + + CheckResult result = check(R"( + table.freeze(table) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_persistent_skip") +{ + ScopedFastFlag luauFreezeIgnorePersistent{FFlag::LuauFreezeIgnorePersistent, true}; + + CheckResult result = check(R"( + table.clone(table) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { // In the new solver, nil can certainly be used where a generic is required, so all generic parameters are optional. @@ -1599,7 +1601,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauTableCloneClonesType2) + if (FFlag::LuauTableCloneClonesType3) { CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, z: number }"); CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number }"); @@ -1611,6 +1613,40 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states") } } +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_should_not_break") +{ + CheckResult result = check(R"( + local Immutable = {} + + function Immutable.Set(dictionary, key, value) + local new = table.clone(dictionary) + + new[key] = value + + return new + end + + return Immutable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_should_not_break_2") +{ + CheckResult result = check(R"( + function set(dictionary, key, value) + local new = table.clone(dictionary) + + new[key] = value + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any") { ScopedFastFlag _{FFlag::LuauSolverV2, true}; diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2ab90ab5..ce1cef29 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -10,6 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauNewSolverPrePopulateClasses) +LUAU_FASTFLAG(LuauClipNestedAndRecursiveUnion) TEST_SUITE_BEGIN("DefinitionTests"); @@ -541,4 +542,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") CHECK_EQ(ctv->definitionModuleName, "@test"); } +TEST_CASE_FIXTURE(Fixture, "recursive_redefinition_reduces_rightfully") +{ + ScopedFastFlag _{FFlag::LuauClipNestedAndRecursiveUnion, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local t: {[string]: string} = {} + + local function f() + t = t + end + + t = t + )")); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index bc1d55dd..942ef6a7 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -19,7 +19,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTarjanChildLimit) -LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(DebugLuauEqSatSimplification) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -3012,9 +3011,6 @@ local u,v = id(3), id(id(44)) TEST_CASE_FIXTURE(Fixture, "hidden_variadics_should_not_break_subtyping") { - // Only applies to new solver. - ScopedFastFlag sff{FFlag::LuauRetrySubtypingWithoutHiddenPack, true}; - CheckResult result = check(R"( --!strict type FooType = { diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index a9109e1d..e5fdbdd3 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) using namespace Luau; @@ -185,10 +184,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_table_freeze") ModulePtr b = frontend.moduleResolver.getModule("game/B"); REQUIRE(b != nullptr); // confirm that no cross-module mutation happened here! - if (FFlag::LuauSolverV2 && FFlag::LuauTypestateBuiltins2) + if (FFlag::LuauSolverV2) CHECK(toString(b->returnType) == "{ read a: number }"); - else if (FFlag::LuauSolverV2) - CHECK(toString(b->returnType) == "{ a: number }"); else CHECK(toString(b->returnType) == "{| a: number |}"); } diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index dea027c2..7460434e 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -17,6 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauDoNotGeneralizeInTypeFunctions) TEST_SUITE_BEGIN("TypeInferOperators"); @@ -814,19 +815,19 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") { - CheckResult result = check(R"( - --!strict - local t = {} - while true and t[1] do - print(t[1].test) - end - )"); + ScopedFastFlag _{FFlag::LuauDoNotGeneralizeInTypeFunctions, true}; - // This infers a type for `t` of `{unknown}`, and so it makes sense that `t[1].test` would error. - if (FFlag::LuauSolverV2) - LUAU_REQUIRE_ERROR_COUNT(1, result); - else - LUAU_REQUIRE_NO_ERRORS(result); + // `t` will be inferred to be of type `{ { test: unknown } }` which is + // reasonable, in that it's empty with no bounds on its members. Optimally + // we might emit an error here that the `print(...)` expression is + // unreachable. + LUAU_REQUIRE_NO_ERRORS(check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )")); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index a61b0fd1..ee7d713d 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + #include "Luau/TypeInfer.h" #include "Luau/RecursionCounter.h" @@ -12,6 +13,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauEqSatSimplification); +LUAU_FASTFLAG(LuauStoreCSTData); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauTarjanChildLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); @@ -46,7 +48,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string expected = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean' then + local a1:boolean=a + elseif a.fn() then + local a2:{fn:()->(a,b...)}=a + end + end + )" + : R"( function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -56,7 +67,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expectedWithNewSolver = R"( + const std::string expectedWithNewSolver = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean' then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn() then + local a2:{fn:()->(unknown,...unknown)}&(class|function|nil|number|string|thread|buffer|table)=a + end + end + )" + : R"( function f(a:{fn:()->(unknown,...unknown)}): () if type(a) == 'boolean'then local a1:{fn:()->(unknown,...unknown)}&boolean=a @@ -66,7 +86,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expectedWithEqSat = R"( + const std::string expectedWithEqSat = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean' then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn() then + local a2:{fn:()->(unknown,...unknown)}&negate=a + end + end + )" + : R"( function f(a:{fn:()->(unknown,...unknown)}): () if type(a) == 'boolean'then local a1:{fn:()->(unknown,...unknown)}&boolean=a @@ -535,10 +564,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId free1 = arena.freshType(builtinTypes, scope.get()); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId free2 = arena.freshType(builtinTypes, scope.get()); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; @@ -965,10 +994,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId free1 = arena.freshType(builtinTypes, scope.get()); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId free2 = arena.freshType(builtinTypes, scope.get()); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; @@ -1284,7 +1313,7 @@ TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cache TableType* table = getMutable(tableTy); REQUIRE(table); - TypeId freeTy = arena.freshType(&globalScope); + TypeId freeTy = arena.freshType(builtinTypes, &globalScope); table->props["foo"] = Property::rw(freeTy); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 735f75ed..ea893528 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,7 +18,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering) -LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) LUAU_FASTFLAG(LuauTableKeysAreRValues) LUAU_FASTFLAG(LuauAllowNilAssignmentToIndexer) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) @@ -2378,7 +2377,7 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table local c : string = t.m("hi") )"); - if (FFlag::LuauSolverV2 && FFlag::LuauRetrySubtypingWithoutHiddenPack) + if (FFlag::LuauSolverV2) { LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -2387,15 +2386,6 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table // This is not actually the expected behavior, but the typemismatch we were seeing before was for the wrong reason. // The behavior of this test is just regressed generally in the new solver, and will need to be consciously addressed. } - else if (FFlag::LuauSolverV2) - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK(get(result.errors[0])); - CHECK(Location{{6, 45}, {6, 46}} == result.errors[0].location); - - CHECK(get(result.errors[1])); - } // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping else if (FFlag::LuauInstantiateInSubtyping) diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 297a5153..2ff97a25 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -24,6 +24,7 @@ LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAG(InferGlobalTypes) +LUAU_FASTFLAG(LuauAstTypeGroup) using namespace Luau; @@ -1201,7 +1202,10 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") CHECK(3 == result.errors.size()); CHECK(Location{{2, 22}, {2, 41}} == result.errors[0].location); CHECK(Location{{3, 22}, {3, 42}} == result.errors[1].location); - CHECK(Location{{3, 23}, {3, 40}} == result.errors[2].location); + if (FFlag::LuauAstTypeGroup) + CHECK(Location{{3, 22}, {3, 40}} == result.errors[2].location); + else + CHECK(Location{{3, 23}, {3, 40}} == result.errors[2].location); CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[1])); CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[2])); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index ccfa6923..48f5d3ea 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -42,12 +42,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { - Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + Type functionOne{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; - Type functionTwo{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))} - }; + Type functionTwo{TypeVariant{FunctionType( + arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}) + )}}; state.tryUnify(&functionTwo, &functionOne); CHECK(!state.failure); @@ -60,14 +61,16 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { - TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + TypePackVar argPackOne{TypePack{{arena.freshType(builtinTypes, globalScope->level)}, std::nullopt}}; + Type functionOne{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; Type functionOneSaved = functionOne.clone(); - TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) + TypePackVar argPackTwo{TypePack{{arena.freshType(builtinTypes, globalScope->level)}, std::nullopt}}; + Type functionTwo{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) }}; Type functionTwoSaved = functionTwo.clone(); @@ -83,11 +86,11 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { Type tableOne{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + TableType{{{"foo", {arena.freshType(builtinTypes, globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; Type tableTwo{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + TableType{{{"foo", {arena.freshType(builtinTypes, globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); @@ -106,7 +109,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { Type tableOne{TypeVariant{ TableType{ - {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, + {{"foo", {arena.freshType(builtinTypes, globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed @@ -115,7 +118,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") Type tableTwo{TypeVariant{ TableType{ - {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, + {{"foo", {arena.freshType(builtinTypes, globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, std::nullopt, globalScope->level, TableState::Unsealed @@ -295,7 +298,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") { - Type redirect{FreeType{TypeLevel{}}}; + Type redirect{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; Type table{TableType{}}; Type metatable{MetatableType{&redirect, &table}}; redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type @@ -318,7 +321,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { - TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); + TypeId a = arena.freshType(builtinTypes, TypeLevel{}); TypeId b = builtinTypes->numberType; state.tryUnify(a, b); @@ -381,7 +384,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") TypePackVar packTmp{TypePack{{builtinTypes->anyType}, &variadicAny}}; TypePackVar packSub{TypePack{{builtinTypes->anyType, builtinTypes->anyType}, &packTmp}}; - Type freeTy{FreeType{TypeLevel{}}}; + Type freeTy{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; TypePackVar packSuper{TypePack{{&freeTy}, &freeTp}}; @@ -438,10 +441,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_creat const std::shared_ptr scope = globalScope; const std::shared_ptr nestedScope = std::make_shared(scope); - const TypeId outerType = arena.freshType(scope.get()); - const TypeId outerType2 = arena.freshType(scope.get()); + const TypeId outerType = arena.freshType(builtinTypes, scope.get()); + const TypeId outerType2 = arena.freshType(builtinTypes, scope.get()); - const TypeId innerType = arena.freshType(nestedScope.get()); + const TypeId innerType = arena.freshType(builtinTypes, nestedScope.get()); state.enableNewSolver(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 037c0103..247894d1 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -621,7 +621,7 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypeId badCyclicUnionTy = arena.freshType(frontend.globals.globalScope.get()); + TypeId badCyclicUnionTy = arena.freshType(builtinTypes, frontend.globals.globalScope.get()); UnionType u; u.options.push_back(badCyclicUnionTy); diff --git a/tests/TypePath.test.cpp b/tests/TypePath.test.cpp index bf831621..b281dcab 100644 --- a/tests/TypePath.test.cpp +++ b/tests/TypePath.test.cpp @@ -17,6 +17,7 @@ using namespace Luau::TypePath; LUAU_FASTFLAG(LuauSolverV2); LUAU_DYNAMIC_FASTINT(LuauTypePathMaximumTraverseSteps); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds); struct TypePathFixture : Fixture { @@ -277,7 +278,7 @@ TEST_CASE_FIXTURE(TypePathFixture, "bounds") TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypeId ty = arena.freshType(frontend.globals.globalScope.get()); + TypeId ty = arena.freshType(frontend.builtinTypes, frontend.globals.globalScope.get()); FreeType* ft = getMutable(ty); SUBCASE("upper") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 9e21b1e0..1e5fdaf1 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -219,7 +219,7 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_only_cyclic_union") */ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - Type ftv11{FreeType{TypeLevel{}}}; + Type ftv11{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; TypePackVar tp24{TypePack{{&ftv11}}}; TypePackVar tp17{TypePack{}}; @@ -469,8 +469,8 @@ TEST_CASE("content_reassignment") myAny.documentationSymbol = "@global/any"; TypeArena arena; - - TypeId futureAny = arena.addType(FreeType{TypeLevel{}}); + BuiltinTypes builtinTypes; + TypeId futureAny = arena.freshType(NotNull{&builtinTypes}, TypeLevel{}); asMutable(futureAny)->reassign(myAny); CHECK(get(futureAny) != nullptr); diff --git a/tests/VisitType.test.cpp b/tests/VisitType.test.cpp index 186afaa5..86063ae8 100644 --- a/tests/VisitType.test.cpp +++ b/tests/VisitType.test.cpp @@ -4,6 +4,7 @@ #include "Luau/RecursionCounter.h" +#include "Luau/Type.h" #include "doctest.h" using namespace Luau; @@ -54,7 +55,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") TEST_CASE_FIXTURE(Fixture, "some_free_types_do_not_have_bounds") { - Type t{FreeType{TypeLevel{}}}; + Type t{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; (void)toString(&t); }