diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 3047b905..d1a9cfcc 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -392,7 +392,7 @@ private: **/ std::vector> createGenerics( const ScopePtr& scope, - AstArray generics, + AstArray generics, bool useCache = false, bool addTypes = true ); @@ -409,7 +409,7 @@ private: **/ std::vector> createGenericPacks( const ScopePtr& scope, - AstArray packs, + AstArray packs, bool useCache = false, bool addTypes = true ); diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 7c0e81ac..1f28abe9 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -221,8 +221,8 @@ private: void visitTypeList(AstTypeList l); - void visitGenerics(AstArray g); - void visitGenericPacks(AstArray g); + void visitGenerics(AstArray g); + void visitGenericPacks(AstArray g); }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 7346a422..ebce78cf 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -21,6 +21,12 @@ LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) namespace Luau { +using LogLuauProc = void (*)(std::string_view); +extern LogLuauProc logLuau; + +void setLogLuau(LogLuauProc ll); +void resetLogLuauProc(); + struct Module; struct AnyTypeSummary; diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 871471a4..0c52b1f1 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -175,7 +175,7 @@ private: void visit(AstExprInterpString* interpString); void visit(AstExprError* expr); TypeId flattenPack(TypePackId pack); - void visitGenerics(AstArray generics, AstArray genericPacks); + void visitGenerics(AstArray generics, AstArray genericPacks); void visit(AstType* ty); void visit(AstTypeReference* ty); void visit(AstTypeTable* table); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 7f2e29b5..2b8dbc3a 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -399,8 +399,8 @@ private: const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, - const AstArray& genericPackNames, + const AstArray& genericNames, + const AstArray& genericPackNames, bool useCache = false ); diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index de69c17c..eb16b2fa 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -49,6 +49,27 @@ struct UnifierSharedState DenseHashSet tempSeenTp{nullptr}; UnifierCounters counters; + + bool reentrantTypeReduction = false; + +}; + +struct TypeReductionRentrancyGuard final +{ + explicit TypeReductionRentrancyGuard(NotNull sharedState) + : sharedState{sharedState} + { + sharedState->reentrantTypeReduction = true; + } + ~TypeReductionRentrancyGuard() + { + sharedState->reentrantTypeReduction = false; + } + TypeReductionRentrancyGuard(const TypeReductionRentrancyGuard&) = delete; + TypeReductionRentrancyGuard(TypeReductionRentrancyGuard&&) = delete; + +private: + NotNull sharedState; }; } // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 7aee25ce..756451a7 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -34,6 +34,7 @@ LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2) LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent) +LUAU_FASTFLAGVARIABLE(LuauFollowTableFreeze) namespace Luau { @@ -1459,7 +1460,8 @@ bool MagicClone::infer(const MagicFunctionCallContext& context) static std::optional freezeTable(TypeId inputType, const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; - + if (FFlag::LuauFollowTableFreeze) + inputType = follow(inputType); if (auto mt = get(inputType)) { std::optional frozenTable = freezeTable(mt->table, context); diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index f77d7944..d830fac3 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -37,8 +37,8 @@ LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTypesOnScope) -LUAU_FASTFLAGVARIABLE(InferGlobalTypes) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) +LUAU_FASTFLAGVARIABLE(LuauInferLocalTypesInMultipleAssignments) namespace Luau { @@ -1025,37 +1025,49 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat TypePackId rvaluePack = checkPack(scope, statLocal->values, expectedTypes).tp; Checkpoint end = checkpoint(this); - if (hasAnnotation) + if (FFlag::LuauInferLocalTypesInMultipleAssignments) { + std::vector deferredTypes; + auto [head, tail] = flatten(rvaluePack); + for (size_t i = 0; i < statLocal->vars.size; ++i) { LUAU_ASSERT(get(assignees[i])); TypeIds* localDomain = localTypes.find(assignees[i]); LUAU_ASSERT(localDomain); - localDomain->insert(annotatedTypes[i]); + + if (statLocal->vars.data[i]->annotation) + { + localDomain->insert(annotatedTypes[i]); + } + else + { + if (i < head.size()) + { + localDomain->insert(head[i]); + } + else if (tail) + { + deferredTypes.push_back(arena->addType(BlockedType{})); + localDomain->insert(deferredTypes.back()); + } + else + { + localDomain->insert(builtinTypes->nilType); + } + } } - TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); - addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); - } - else - { - std::vector valueTypes; - valueTypes.reserve(statLocal->vars.size); - - auto [head, tail] = flatten(rvaluePack); - - if (head.size() >= statLocal->vars.size) + if (hasAnnotation) { - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(head[i]); + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); + addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); } - else - { - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); - auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + if (!deferredTypes.empty()) + { + LUAU_ASSERT(tail); + NotNull uc = addConstraint(scope, statLocal->location, UnpackConstraint{deferredTypes, *tail}); forEachConstraint( start, @@ -1063,20 +1075,69 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat this, [&uc](const ConstraintPtr& runBefore) { - uc->dependencies.push_back(NotNull{runBefore.get()}); + uc->dependencies.emplace_back(runBefore.get()); } ); - for (TypeId t : valueTypes) + for (TypeId t : deferredTypes) getMutable(t)->setOwner(uc); } - - for (size_t i = 0; i < statLocal->vars.size; ++i) + } + else + { + if (hasAnnotation) { - LUAU_ASSERT(get(assignees[i])); - TypeIds* localDomain = localTypes.find(assignees[i]); - LUAU_ASSERT(localDomain); - localDomain->insert(valueTypes[i]); + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(annotatedTypes[i]); + } + + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); + addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); + } + else + { + std::vector valueTypes; + valueTypes.reserve(statLocal->vars.size); + + auto [head, tail] = flatten(rvaluePack); + + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + forEachConstraint( + start, + end, + this, + [&uc](const ConstraintPtr& runBefore) + { + uc->dependencies.push_back(NotNull{runBefore.get()}); + } + ); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(valueTypes[i]); + } } } @@ -2810,13 +2871,10 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob DefId def = dfg->getDef(global); rootScope->lvalueTypes[def] = rhsType; - if (FFlag::InferGlobalTypes) - { - // Sketchy: We're specifically looking for BlockedTypes that were - // initially created by ConstraintGenerator::prepopulateGlobalScope. - if (auto bt = get(follow(*annotatedTy)); bt && !bt->getOwner()) - emplaceType(asMutable(*annotatedTy), rhsType); - } + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(follow(*annotatedTy)); bt && !bt->getOwner()) + emplaceType(asMutable(*annotatedTy), rhsType); addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); } @@ -3535,33 +3593,34 @@ TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const Ast std::vector> ConstraintGenerator::createGenerics( const ScopePtr& scope, - AstArray generics, + AstArray generics, bool useCache, bool addTypes ) { std::vector> result; - for (const auto& generic : generics) + for (const auto* generic : generics) { TypeId genericTy = nullptr; - if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) + if (auto it = scope->parent->typeAliasTypeParameters.find(generic->name.value); + useCache && it != scope->parent->typeAliasTypeParameters.end()) genericTy = it->second; else { - genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); - scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; + genericTy = arena->addType(GenericType{scope.get(), generic->name.value}); + scope->parent->typeAliasTypeParameters[generic->name.value] = genericTy; } std::optional defaultTy = std::nullopt; - if (generic.defaultValue) - defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + if (generic->defaultValue) + defaultTy = resolveType(scope, generic->defaultValue, /* inTypeArguments */ false); if (addTypes) - scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + scope->privateTypeBindings[generic->name.value] = TypeFun{genericTy}; - result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); + result.emplace_back(generic->name.value, GenericTypeDefinition{genericTy, defaultTy}); } return result; @@ -3569,34 +3628,34 @@ std::vector> ConstraintGenerator::createG std::vector> ConstraintGenerator::createGenericPacks( const ScopePtr& scope, - AstArray generics, + AstArray generics, bool useCache, bool addTypes ) { std::vector> result; - for (const auto& generic : generics) + for (const auto* generic : generics) { TypePackId genericTy; - if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); + if (auto it = scope->parent->typeAliasTypePackParameters.find(generic->name.value); useCache && it != scope->parent->typeAliasTypePackParameters.end()) genericTy = it->second; else { - genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); - scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; + genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic->name.value}}); + scope->parent->typeAliasTypePackParameters[generic->name.value] = genericTy; } std::optional defaultTy = std::nullopt; - if (generic.defaultValue) - defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + if (generic->defaultValue) + defaultTy = resolveTypePack(scope, generic->defaultValue, /* inTypeArguments */ false); if (addTypes) - scope->privateTypePackBindings[generic.name.value] = genericTy; + scope->privateTypePackBindings[generic->name.value] = genericTy; - result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); + result.emplace_back(generic->name.value, GenericTypePackDefinition{genericTy, defaultTy}); } return result; @@ -3739,18 +3798,15 @@ struct GlobalPrepopulator : AstVisitor bool visit(AstStatAssign* assign) override { - if (FFlag::InferGlobalTypes) + for (const Luau::AstExpr* expr : assign->vars) { - for (const Luau::AstExpr* expr : assign->vars) + if (const AstExprGlobal* g = expr->as()) { - if (const AstExprGlobal* g = expr->as()) - { - if (!globalScope->lookup(g->name)) - globalScope->globalsToWarn.insert(g->name.value); + if (!globalScope->lookup(g->name)) + globalScope->globalsToWarn.insert(g->name.value); - TypeId bt = arena->addType(BlockedType{}); - globalScope->bindings[g->name] = Binding{bt, g->location}; - } + TypeId bt = arena->addType(BlockedType{}); + globalScope->bindings[g->name] = Binding{bt, g->location}; } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index cb2f6bbf..6f7bd132 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -37,6 +37,7 @@ LUAU_FASTFLAGVARIABLE(LuauAllowNilAssignmentToIndexer) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAGVARIABLE(LuauAlwaysFillInFunctionCallDiscriminantTypes) LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTablesOnScope) +LUAU_FASTFLAGVARIABLE(LuauPrecalculateMutatedFreeTypes) namespace Luau { @@ -438,6 +439,10 @@ void ConstraintSolver::run() snapshot = logger->prepareStepSnapshot(rootScope, c, force, unsolvedConstraints); } + std::optional> mutatedFreeTypes = std::nullopt; + if (FFlag::LuauPrecalculateMutatedFreeTypes) + mutatedFreeTypes = c->getMaybeMutatedFreeTypes(); + bool success = tryDispatch(c, force); progress |= success; @@ -447,20 +452,42 @@ void ConstraintSolver::run() unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i)); - // decrement the referenced free types for this constraint if we dispatched successfully! - for (auto ty : c->getMaybeMutatedFreeTypes()) + if (FFlag::LuauPrecalculateMutatedFreeTypes) { - size_t& refCount = unresolvedConstraints[ty]; - if (refCount > 0) - refCount -= 1; + for (auto ty : c->getMaybeMutatedFreeTypes()) + mutatedFreeTypes->insert(ty); + for (auto ty : *mutatedFreeTypes) + { + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; - // We have two constraints that are designed to wait for the - // refCount on a free type to be equal to 1: the - // PrimitiveTypeConstraint and ReduceConstraint. We - // therefore wake any constraint waiting for a free type's - // refcount to be 1 or 0. - if (refCount <= 1) - unblock(ty, Location{}); + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); + } + } + else + { + // decrement the referenced free types for this constraint if we dispatched successfully! + for (auto ty : c->getMaybeMutatedFreeTypes()) + { + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); + } } if (logger) diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index cff87858..46c87845 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -1260,21 +1260,21 @@ void DataFlowGraphBuilder::visitTypeList(AstTypeList l) visitTypePack(l.tailType); } -void DataFlowGraphBuilder::visitGenerics(AstArray g) +void DataFlowGraphBuilder::visitGenerics(AstArray g) { - for (AstGenericType generic : g) + for (AstGenericType* generic : g) { - if (generic.defaultValue) - visitType(generic.defaultValue); + if (generic->defaultValue) + visitType(generic->defaultValue); } } -void DataFlowGraphBuilder::visitGenericPacks(AstArray g) +void DataFlowGraphBuilder::visitGenericPacks(AstArray g) { - for (AstGenericTypePack generic : g) + for (AstGenericTypePack* generic : g) { - if (generic.defaultValue) - visitTypePack(generic.defaultValue); + if (generic->defaultValue) + visitTypePack(generic->defaultValue); } } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 0042d6fb..595e4905 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,205 +2,11 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(LuauBufferBitMethods2) -LUAU_FASTFLAGVARIABLE(LuauMathMapDefinition) LUAU_FASTFLAG(LuauVector2Constructor) namespace Luau { -static const std::string kBuiltinDefinitionLuaSrcChecked_DEPRECATED = R"BUILTIN_SRC( - -declare bit32: { - band: @checked (...number) -> number, - bor: @checked (...number) -> number, - bxor: @checked (...number) -> number, - btest: @checked (number, ...number) -> boolean, - rrotate: @checked (x: number, disp: number) -> number, - lrotate: @checked (x: number, disp: number) -> number, - lshift: @checked (x: number, disp: number) -> number, - arshift: @checked (x: number, disp: number) -> number, - rshift: @checked (x: number, disp: number) -> number, - bnot: @checked (x: number) -> number, - extract: @checked (n: number, field: number, width: number?) -> number, - replace: @checked (n: number, v: number, field: number, width: number?) -> number, - countlz: @checked (n: number) -> number, - countrz: @checked (n: number) -> number, - byteswap: @checked (n: number) -> number, -} - -declare math: { - frexp: @checked (n: number) -> (number, number), - ldexp: @checked (s: number, e: number) -> number, - fmod: @checked (x: number, y: number) -> number, - modf: @checked (n: number) -> (number, number), - pow: @checked (x: number, y: number) -> number, - exp: @checked (n: number) -> number, - - ceil: @checked (n: number) -> number, - floor: @checked (n: number) -> number, - abs: @checked (n: number) -> number, - sqrt: @checked (n: number) -> number, - - log: @checked (n: number, base: number?) -> number, - log10: @checked (n: number) -> number, - - rad: @checked (n: number) -> number, - deg: @checked (n: number) -> number, - - sin: @checked (n: number) -> number, - cos: @checked (n: number) -> number, - tan: @checked (n: number) -> number, - sinh: @checked (n: number) -> number, - cosh: @checked (n: number) -> number, - tanh: @checked (n: number) -> number, - atan: @checked (n: number) -> number, - acos: @checked (n: number) -> number, - asin: @checked (n: number) -> number, - atan2: @checked (y: number, x: number) -> number, - - min: @checked (number, ...number) -> number, - max: @checked (number, ...number) -> number, - - pi: number, - huge: number, - - randomseed: @checked (seed: number) -> (), - random: @checked (number?, number?) -> number, - - sign: @checked (n: number) -> number, - clamp: @checked (n: number, min: number, max: number) -> number, - noise: @checked (x: number, y: number?, z: number?) -> number, - round: @checked (n: number) -> number, - map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number, -} - -type DateTypeArg = { - year: number, - month: number, - day: number, - hour: number?, - min: number?, - sec: number?, - isdst: boolean?, -} - -type DateTypeResult = { - year: number, - month: number, - wday: number, - yday: number, - day: number, - hour: number, - min: number, - sec: number, - isdst: boolean, -} - -declare os: { - time: (time: DateTypeArg?) -> number, - date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), - difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, - clock: () -> number, -} - -@checked declare function require(target: any): any - -@checked declare function getfenv(target: any): { [string]: any } - -declare _G: any -declare _VERSION: string - -declare function gcinfo(): number - -declare function print(...: T...) - -declare function type(value: T): string -declare function typeof(value: T): string - --- `assert` has a magic function attached that will give more detailed type information -declare function assert(value: T, errorMessage: string?): T -declare function error(message: T, level: number?): never - -declare function tostring(value: T): string -declare function tonumber(value: T, radix: number?): number? - -declare function rawequal(a: T1, b: T2): boolean -declare function rawget(tab: {[K]: V}, k: K): V -declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} -declare function rawlen(obj: {[K]: V} | string): number - -declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? - -declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number) - -declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) - --- FIXME: The actual type of `xpcall` is: --- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) --- Since we can't represent the return value, we use (boolean, R1...). -declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) - --- `select` has a magic function attached to provide more detailed type information -declare function select(i: string | number, ...: A...): ...any - --- FIXME: This type is not entirely correct - `loadstring` returns a function or --- (nil, string). -declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - -@checked declare function newproxy(mt: boolean?): any - -declare coroutine: { - create: (f: (A...) -> R...) -> thread, - resume: (co: thread, A...) -> (boolean, R...), - running: () -> thread, - status: @checked (co: thread) -> "dead" | "running" | "normal" | "suspended", - wrap: (f: (A...) -> R...) -> ((A...) -> R...), - yield: (A...) -> R..., - isyieldable: () -> boolean, - close: @checked (co: thread) -> (boolean, any) -} - -declare table: { - concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, - insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), - maxn: (t: {V}) -> number, - remove: (t: {V}, number?) -> V?, - sort: (t: {V}, comp: ((V, V) -> boolean)?) -> (), - create: (count: number, value: V?) -> {V}, - find: (haystack: {V}, needle: V, init: number?) -> number?, - - unpack: (list: {V}, i: number?, j: number?) -> ...V, - pack: (...V) -> { n: number, [number]: V }, - - getn: (t: {V}) -> number, - foreach: (t: {[K]: V}, f: (K, V) -> ()) -> (), - foreachi: ({V}, (number, V) -> ()) -> (), - - move: (src: {V}, a: number, b: number, t: number, dst: {V}?) -> {V}, - clear: (table: {[K]: V}) -> (), - - isfrozen: (t: {[K]: V}) -> boolean, -} - -declare debug: { - info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), - traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), -} - -declare utf8: { - char: @checked (...number) -> string, - charpattern: string, - codes: @checked (str: string) -> ((string, number) -> (number, number), string, number), - codepoint: @checked (str: string, i: number?, j: number?) -> ...number, - len: @checked (s: string, i: number?, j: number?) -> (number?, number?), - offset: @checked (s: string, n: number?, i: number?) -> number, -} - --- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. -declare function unpack(tab: {V}, i: number?, j: number?): ...V - -)BUILTIN_SRC"; - static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC( @checked declare function require(target: any): any @@ -549,18 +355,15 @@ declare vector: { std::string getBuiltinDefinitionSource() { - std::string result = FFlag::LuauMathMapDefinition ? kBuiltinDefinitionBaseSrc : kBuiltinDefinitionLuaSrcChecked_DEPRECATED; + std::string result = kBuiltinDefinitionBaseSrc; - if (FFlag::LuauMathMapDefinition) - { - result += kBuiltinDefinitionBit32Src; - result += kBuiltinDefinitionMathSrc; - result += kBuiltinDefinitionOsSrc; - result += kBuiltinDefinitionCoroutineSrc; - result += kBuiltinDefinitionTableSrc; - result += kBuiltinDefinitionDebugSrc; - result += kBuiltinDefinitionUtf8Src; - } + result += kBuiltinDefinitionBit32Src; + result += kBuiltinDefinitionMathSrc; + result += kBuiltinDefinitionOsSrc; + result += kBuiltinDefinitionCoroutineSrc; + result += kBuiltinDefinitionTableSrc; + result += kBuiltinDefinitionDebugSrc; + result += kBuiltinDefinitionUtf8Src; result += FFlag::LuauBufferBitMethods2 ? kBuiltinDefinitionBufferSrc : kBuiltinDefinitionBufferSrc_DEPRECATED; diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index bc82d750..c864b836 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -34,7 +34,7 @@ LUAU_FASTFLAG(LuauReferenceAllocatorInNewSolver) LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf) LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule) - +LUAU_FASTFLAGVARIABLE(LogFragmentsFromAutocomplete) namespace { template @@ -335,6 +335,8 @@ std::optional parseFragment( FragmentParseResult fragmentResult; fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength); // For the duration of the incremental parse, we want to allow the name table to re-use duplicate names + if (FFlag::LogFragmentsFromAutocomplete) + logLuau(dbg); ParseOptions opts; opts.allowDeclarationSyntax = false; @@ -650,7 +652,8 @@ FragmentAutocompleteResult fragmentAutocomplete( return {}; auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get(); - + if (FFlag::LogFragmentsFromAutocomplete) + logLuau(src); TypeArena arenaForFragmentAutocomplete; auto result = Luau::autocomplete_( tcResult.incrementalModule, diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 3209fd08..1dbd6608 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -20,6 +20,26 @@ LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection) namespace Luau { +static void defaultLogLuau(std::string_view input) +{ + // The default is to do nothing because we don't want to mess with + // the xml parsing done by the dcr script. +} + +Luau::LogLuauProc logLuau = &defaultLogLuau; + +void setLogLuau(LogLuauProc ll) +{ + logLuau = ll; +} + +void resetLogLuauProc() +{ + logLuau = &defaultLogLuau; +} + + + static bool contains_DEPRECATED(Position pos, Comment comment) { if (comment.location.contains(pos)) diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 0645e4e2..453b552c 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -14,13 +14,15 @@ #include "Luau/TypeFunction.h" #include "Luau/Def.h" #include "Luau/ToString.h" -#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" #include #include LUAU_FASTFLAGVARIABLE(LuauCountSelfCallsNonstrict) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) +LUAU_FASTFLAGVARIABLE(LuauNonStrictVisitorImprovements) +LUAU_FASTFLAGVARIABLE(LuauNewNonStrictWarnOnUnknownGlobals) namespace Luau { @@ -342,8 +344,9 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatIf* ifStatement) { - NonStrictContext condB = visit(ifStatement->condition); + NonStrictContext condB = visit(ifStatement->condition, ValueContext::RValue); NonStrictContext branchContext; + // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error if (ifStatement->elsebody) { @@ -351,17 +354,32 @@ struct NonStrictTypeChecker NonStrictContext elseBody = visit(ifStatement->elsebody); branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); } + return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); } NonStrictContext visit(AstStatWhile* whileStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext condition = visit(whileStatement->condition, ValueContext::RValue); + NonStrictContext body = visit(whileStatement->body); + return NonStrictContext::disjunction(builtinTypes, arena, condition, body); + } + else + return {}; } NonStrictContext visit(AstStatRepeat* repeatStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext body = visit(repeatStatement->body); + NonStrictContext condition = visit(repeatStatement->condition, ValueContext::RValue); + return NonStrictContext::disjunction(builtinTypes, arena, body, condition); + } + else + return {}; } NonStrictContext visit(AstStatBreak* breakStatement) @@ -376,49 +394,94 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatReturn* returnStatement) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + // TODO: this is believing existing code, but i'm not sure if this makes sense + // for how the contexts are handled + for (AstExpr* expr : returnStatement->list) + visit(expr, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstStatExpr* expr) { - return visit(expr->expr); + return visit(expr->expr, ValueContext::RValue); } NonStrictContext visit(AstStatLocal* local) { for (AstExpr* rhs : local->values) - visit(rhs); + visit(rhs, ValueContext::RValue); return {}; } NonStrictContext visit(AstStatFor* forStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + // TODO: throwing out context based on same principle as existing code? + if (forStatement->from) + visit(forStatement->from, ValueContext::RValue); + if (forStatement->to) + visit(forStatement->to, ValueContext::RValue); + if (forStatement->step) + visit(forStatement->step, ValueContext::RValue); + return visit(forStatement->body); + } + else + { + return {}; + } } NonStrictContext visit(AstStatForIn* forInStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* rhs : forInStatement->values) + visit(rhs, ValueContext::RValue); + return visit(forInStatement->body); + } + else + { + return {}; + } } NonStrictContext visit(AstStatAssign* assign) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* lhs : assign->vars) + visit(lhs, ValueContext::LValue); + for (AstExpr* rhs : assign->values) + visit(rhs, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstStatCompoundAssign* compoundAssign) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + visit(compoundAssign->var, ValueContext::LValue); + visit(compoundAssign->value, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstStatFunction* statFn) { - return visit(statFn->func); + return visit(statFn->func, ValueContext::RValue); } NonStrictContext visit(AstStatLocalFunction* localFn) { - return visit(localFn->func); + return visit(localFn->func, ValueContext::RValue); } NonStrictContext visit(AstStatTypeAlias* typeAlias) @@ -448,14 +511,22 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatError* error) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstStat* stat : error->statements) + visit(stat); + for (AstExpr* expr : error->expressions) + visit(expr, ValueContext::RValue); + } + return {}; } - NonStrictContext visit(AstExpr* expr) + NonStrictContext visit(AstExpr* expr, ValueContext context) { auto pusher = pushStack(expr); if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -465,17 +536,17 @@ struct NonStrictTypeChecker else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -499,9 +570,12 @@ struct NonStrictTypeChecker } } - NonStrictContext visit(AstExprGroup* group) + NonStrictContext visit(AstExprGroup* group, ValueContext context) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(group->expr, context); + else + return {}; } NonStrictContext visit(AstExprConstantNil* expr) @@ -524,17 +598,30 @@ struct NonStrictTypeChecker return {}; } - NonStrictContext visit(AstExprLocal* local) + NonStrictContext visit(AstExprLocal* local, ValueContext context) { return {}; } - NonStrictContext visit(AstExprGlobal* global) + NonStrictContext visit(AstExprGlobal* global, ValueContext context) { + if (FFlag::LuauNewNonStrictWarnOnUnknownGlobals) + { + // We don't file unknown symbols for LValues. + if (context == ValueContext::LValue) + return {}; + + NotNull scope = stack.back(); + if (!scope->lookup(global->name)) + { + reportError(UnknownSymbol{global->name.value, UnknownSymbol::Binding}, global->location); + } + } + return {}; } - NonStrictContext visit(AstExprVarargs* global) + NonStrictContext visit(AstExprVarargs* varargs) { return {}; } @@ -763,14 +850,24 @@ struct NonStrictTypeChecker return fresh; } - NonStrictContext visit(AstExprIndexName* indexName) + NonStrictContext visit(AstExprIndexName* indexName, ValueContext context) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(indexName->expr, context); + else + return {}; } - NonStrictContext visit(AstExprIndexExpr* indexExpr) + NonStrictContext visit(AstExprIndexExpr* indexExpr, ValueContext context) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext expr = visit(indexExpr->expr, context); + NonStrictContext index = visit(indexExpr->index, ValueContext::RValue); + return NonStrictContext::disjunction(builtinTypes, arena, expr, index); + } + else + return {}; } NonStrictContext visit(AstExprFunction* exprFn) @@ -789,39 +886,74 @@ struct NonStrictTypeChecker NonStrictContext visit(AstExprTable* table) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (auto [_, key, value] : table->items) + { + if (key) + visit(key, ValueContext::RValue); + visit(value, ValueContext::RValue); + } + } + return {}; } NonStrictContext visit(AstExprUnary* unary) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(unary->expr, ValueContext::RValue); + else + return {}; } NonStrictContext visit(AstExprBinary* binary) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext lhs = visit(binary->left, ValueContext::RValue); + NonStrictContext rhs = visit(binary->right, ValueContext::RValue); + return NonStrictContext::disjunction(builtinTypes, arena, lhs, rhs); + } + else + return {}; } NonStrictContext visit(AstExprTypeAssertion* typeAssertion) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(typeAssertion->expr, ValueContext::RValue); + else + return {}; } NonStrictContext visit(AstExprIfElse* ifElse) { - NonStrictContext condB = visit(ifElse->condition); - NonStrictContext thenB = visit(ifElse->trueExpr); - NonStrictContext elseB = visit(ifElse->falseExpr); + NonStrictContext condB = visit(ifElse->condition, ValueContext::RValue); + NonStrictContext thenB = visit(ifElse->trueExpr, ValueContext::RValue); + NonStrictContext elseB = visit(ifElse->falseExpr, ValueContext::RValue); return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB)); } NonStrictContext visit(AstExprInterpString* interpString) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* expr : interpString->expressions) + visit(expr, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstExprError* error) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* expr : error->expressions) + visit(expr, ValueContext::RValue); + } + return {}; } diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index 5c7ea4d9..d5a4a804 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -10,6 +10,7 @@ #include "Luau/Unifier2.h" LUAU_FASTFLAGVARIABLE(LuauDontInPlaceMutateTableType) +LUAU_FASTFLAGVARIABLE(LuauAllowNonSharedTableTypesInLiteral) namespace Luau { @@ -251,8 +252,19 @@ TypeId matchLiteralType( Property& prop = it->second; - // Table literals always initially result in shared read-write types - LUAU_ASSERT(prop.isShared()); + if (FFlag::LuauAllowNonSharedTableTypesInLiteral) + { + // If we encounter a duplcate property, we may have already + // set it to be read-only. If that's the case, the only thing + // that will definitely crash is trying to access a write + // only property. + LUAU_ASSERT(!prop.isWriteOnly()); + } + else + { + // Table literals always initially result in shared read-write types + LUAU_ASSERT(prop.isShared()); + } TypeId propTy = *prop.readTy; auto it2 = expectedTableTy->props.find(keyStr); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index eee41f24..1c8c5e85 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -117,9 +117,9 @@ struct StringWriter : Writer index = newlinePos + 1; } - pos.line += unsigned(numLines); + pos.line += numLines; if (numLines > 0) - pos.column = unsigned(s.size()) - unsigned(index); + pos.column = unsigned(s.size()) - index; else pos.column += unsigned(s.size()); } @@ -898,14 +898,14 @@ struct Printer_DEPRECATED { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); - if (o.defaultValue) + if (o->defaultValue) { - writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.maybeSpace(o->defaultValue->location.begin, 2); writer.symbol("="); - visualizeTypeAnnotation(*o.defaultValue); + visualizeTypeAnnotation(*o->defaultValue); } } @@ -913,15 +913,15 @@ struct Printer_DEPRECATED { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); - if (o.defaultValue) + if (o->defaultValue) { - writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.maybeSpace(o->defaultValue->location.begin, 2); writer.symbol("="); - visualizeTypePackAnnotation(*o.defaultValue, false); + visualizeTypePackAnnotation(*o->defaultValue, false); } } @@ -978,15 +978,15 @@ struct Printer_DEPRECATED { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); } writer.symbol(">"); @@ -1115,15 +1115,15 @@ struct Printer_DEPRECATED { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); } writer.symbol(">"); @@ -1690,7 +1690,10 @@ struct Printer if (writeTypes) { - writer.maybeSpace(a->annotation->location.begin, 2); + if (const auto* cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + else + writer.maybeSpace(a->annotation->location.begin, 2); writer.symbol("::"); visualizeTypeAnnotation(*a->annotation); } @@ -2047,14 +2050,14 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); - if (o.defaultValue) + if (o->defaultValue) { - writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.maybeSpace(o->defaultValue->location.begin, 2); writer.symbol("="); - visualizeTypeAnnotation(*o.defaultValue); + visualizeTypeAnnotation(*o->defaultValue); } } @@ -2062,15 +2065,15 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); - if (o.defaultValue) + if (o->defaultValue) { - writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.maybeSpace(o->defaultValue->location.begin, 2); writer.symbol("="); - visualizeTypePackAnnotation(*o.defaultValue, false); + visualizeTypePackAnnotation(*o->defaultValue, false); } } @@ -2131,15 +2134,15 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); } writer.symbol(">"); @@ -2312,15 +2315,15 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 6fc60b2f..0d038694 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -261,24 +261,24 @@ public: if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); - AstArray generics; + AstArray generics; generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { if (auto gtv = get(*it)) - generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; + generics.data[numGenerics++] = allocator->alloc(Location(), AstName(gtv->name.c_str()), nullptr); } - AstArray genericPacks; + AstArray genericPacks; genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { if (auto gtv = get(*it)) - genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; + genericPacks.data[numGenericPacks++] = allocator->alloc(Location(), AstName(gtv->name.c_str()), nullptr); } AstArray argTypes; diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 32c0f4db..0f578954 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -29,7 +29,6 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) -LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAGVARIABLE(LuauTableKeysAreRValues) LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) @@ -1357,7 +1356,7 @@ void TypeChecker2::visit(AstExprGlobal* expr) { reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); } - else if (FFlag::InferGlobalTypes) + else { if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value)) { @@ -2379,30 +2378,30 @@ TypeId TypeChecker2::flattenPack(TypePackId pack) ice->ice("flattenPack got a weird pack!"); } -void TypeChecker2::visitGenerics(AstArray generics, AstArray genericPacks) +void TypeChecker2::visitGenerics(AstArray generics, AstArray genericPacks) { DenseHashSet seen{AstName{}}; - for (const auto& g : generics) + for (const auto* g : generics) { - if (seen.contains(g.name)) - reportError(DuplicateGenericParameter{g.name.value}, g.location); + if (seen.contains(g->name)) + reportError(DuplicateGenericParameter{g->name.value}, g->location); else - seen.insert(g.name); + seen.insert(g->name); - if (g.defaultValue) - visit(g.defaultValue); + if (g->defaultValue) + visit(g->defaultValue); } - for (const auto& g : genericPacks) + for (const auto* g : genericPacks) { - if (seen.contains(g.name)) - reportError(DuplicateGenericParameter{g.name.value}, g.location); + if (seen.contains(g->name)) + reportError(DuplicateGenericParameter{g->name.value}, g->location); else - seen.insert(g.name); + seen.insert(g->name); - if (g.defaultValue) - visit(g.defaultValue); + if (g->defaultValue) + visit(g->defaultValue); } } diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index a5f69460..5a2e77a5 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -50,6 +50,7 @@ LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_FASTFLAGVARIABLE(LuauMetatableTypeFunctions) LUAU_FASTFLAGVARIABLE(LuauClipNestedAndRecursiveUnion) LUAU_FASTFLAGVARIABLE(LuauDoNotGeneralizeInTypeFunctions) +LUAU_FASTFLAGVARIABLE(LuauPreventReentrantTypeFunctionReduction) namespace Luau { @@ -446,19 +447,49 @@ static FunctionGraphReductionResult reduceFunctionsInternal( TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force}; int iterationCount = 0; - while (!reducer.done()) + if (FFlag::LuauPreventReentrantTypeFunctionReduction) { - reducer.step(); + // If we are reducing a type function while reducing a type function, + // we're probably doing something clowny. One known place this can + // occur is type function reduction => overload selection => subtyping + // => back to type function reduction. At worst, if there's a reduction + // that _doesn't_ loop forever and _needs_ reentrancy, we'll fail to + // handle that and potentially emit an error when we didn't need to. + if (ctx.normalizer->sharedState->reentrantTypeReduction) + return {}; - ++iterationCount; - if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + TypeReductionRentrancyGuard _{ctx.normalizer->sharedState}; + while (!reducer.done()) { - reducer.result.errors.emplace_back(location, CodeTooComplex{}); - break; + reducer.step(); + + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.emplace_back(location, CodeTooComplex{}); + break; + } } + + return std::move(reducer.result); + } + else + { + while (!reducer.done()) + { + reducer.step(); + + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.emplace_back(location, CodeTooComplex{}); + break; + } + } + + return std::move(reducer.result); } - return std::move(reducer.result); } FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location location, TypeFunctionContext ctx, bool force) diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp index c5c54477..14f13d5e 100644 --- a/Analysis/src/TypeFunctionRuntime.cpp +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -14,9 +14,6 @@ #include LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunFixInner) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunGenerics) -LUAU_FASTFLAGVARIABLE(LuauUserTypeFunCloneTail) namespace Luau { @@ -158,7 +155,7 @@ static std::string getTag(lua_State* L, TypeFunctionTypeId ty) return "function"; else if (get(ty)) return "class"; - else if (FFlag::LuauUserTypeFunGenerics && get(ty)) + else if (get(ty)) return "generic"; LUAU_UNREACHABLE(); @@ -427,21 +424,11 @@ static int getNegatedValue(lua_State* L) luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); TypeFunctionTypeId self = getTypeUserData(L, 1); - - if (FFlag::LuauUserTypeFunFixInner) - { - if (auto tfnt = get(self); tfnt) - allocTypeUserData(L, tfnt->type->type); - else - luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); - } + + if (auto tfnt = get(self); tfnt) + allocTypeUserData(L, tfnt->type->type); else - { - if (auto tfnt = get(self); !tfnt) - allocTypeUserData(L, tfnt->type->type); - else - luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); - } + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); return 1; } @@ -941,99 +928,6 @@ static void pushTypePack(lua_State* L, TypeFunctionTypePackId tp) } } -static int createFunction_DEPRECATED(lua_State* L) -{ - int argumentCount = lua_gettop(L); - if (argumentCount > 2) - luaL_error(L, "types.newfunction: expected 0-2 arguments, but got %d", argumentCount); - - TypeFunctionTypePackId argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); - if (lua_istable(L, 1)) - { - std::vector head{}; - lua_getfield(L, 1, "head"); - if (lua_istable(L, -1)) - { - int argSize = lua_objlen(L, -1); - for (int i = 1; i <= argSize; i++) - { - lua_pushinteger(L, i); - lua_gettable(L, -2); - - if (lua_isnil(L, -1)) - { - lua_pop(L, 1); - break; - } - - TypeFunctionTypeId ty = getTypeUserData(L, -1); - head.push_back(ty); - - lua_pop(L, 1); // Remove `ty` from stack - } - } - lua_pop(L, 1); // Pop the "head" field - - std::optional tail; - lua_getfield(L, 1, "tail"); - if (auto type = optionalTypeUserData(L, -1)) - tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); - lua_pop(L, 1); // Pop the "tail" field - - if (head.size() == 0 && tail.has_value()) - argTypes = *tail; - else - argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); - } - else if (!lua_isnoneornil(L, 1)) - luaL_typeerrorL(L, 1, "table"); - - TypeFunctionTypePackId retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); - if (lua_istable(L, 2)) - { - std::vector head{}; - lua_getfield(L, 2, "head"); - if (lua_istable(L, -1)) - { - int argSize = lua_objlen(L, -1); - for (int i = 1; i <= argSize; i++) - { - lua_pushinteger(L, i); - lua_gettable(L, -2); - - if (lua_isnil(L, -1)) - { - lua_pop(L, 1); - break; - } - - TypeFunctionTypeId ty = getTypeUserData(L, -1); - head.push_back(ty); - - lua_pop(L, 1); // Remove `ty` from stack - } - } - lua_pop(L, 1); // Pop the "head" field - - std::optional tail; - lua_getfield(L, 2, "tail"); - if (auto type = optionalTypeUserData(L, -1)) - tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); - lua_pop(L, 1); // Pop the "tail" field - - if (head.size() == 0 && tail.has_value()) - retTypes = *tail; - else - retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); - } - else if (!lua_isnoneornil(L, 2)) - luaL_typeerrorL(L, 2, "table"); - - allocTypeUserData(L, TypeFunctionFunctionType{{}, {}, argTypes, retTypes}); - - return 1; -} - // Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}, generics: {type}?) -> type` // Returns the type instance representing a function static int createFunction(lua_State* L) @@ -1102,45 +996,7 @@ static int setFunctionParameters(lua_State* L) if (!tfft) luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - if (FFlag::LuauUserTypeFunGenerics) - { - tfft->argTypes = getTypePack(L, 2, 3); - } - else - { - std::vector head{}; - if (lua_istable(L, 2)) - { - int argSize = lua_objlen(L, 2); - for (int i = 1; i <= argSize; i++) - { - lua_pushinteger(L, i); - lua_gettable(L, 2); - - if (lua_isnil(L, -1)) - { - lua_pop(L, 1); - break; - } - - TypeFunctionTypeId ty = getTypeUserData(L, -1); - head.push_back(ty); - - lua_pop(L, 1); // Remove `ty` from stack - } - } - else if (!lua_isnoneornil(L, 2)) - luaL_typeerrorL(L, 2, "table"); - - std::optional tail; - if (auto type = optionalTypeUserData(L, 3)) - tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); - - if (head.size() == 0 && tail.has_value()) // Make argTypes a variadic type pack - tfft->argTypes = *tail; - else // Make argTypes a type pack - tfft->argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); - } + tfft->argTypes = getTypePack(L, 2, 3); return 0; } @@ -1158,59 +1014,7 @@ static int getFunctionParameters(lua_State* L) if (!tfft) luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - if (FFlag::LuauUserTypeFunGenerics) - { - pushTypePack(L, tfft->argTypes); - } - else - { - if (auto tftp = get(tfft->argTypes)) - { - int size = 0; - if (tftp->head.size() > 0) - size++; - if (tftp->tail.has_value()) - size++; - - lua_createtable(L, 0, size); - - int argSize = (int)tftp->head.size(); - if (argSize > 0) - { - lua_createtable(L, argSize, 0); - for (int i = 0; i < argSize; i++) - { - allocTypeUserData(L, tftp->head[i]->type); - lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed - } - lua_setfield(L, -2, "head"); - } - - if (tftp->tail.has_value()) - { - auto tfvp = get(*tftp->tail); - if (!tfvp) - LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); - - allocTypeUserData(L, tfvp->type->type); - lua_setfield(L, -2, "tail"); - } - - return 1; - } - - if (auto tfvp = get(tfft->argTypes)) - { - lua_createtable(L, 0, 1); - - allocTypeUserData(L, tfvp->type->type); - lua_setfield(L, -2, "tail"); - - return 1; - } - - lua_createtable(L, 0, 0); - } + pushTypePack(L, tfft->argTypes); return 1; } @@ -1228,45 +1032,7 @@ static int setFunctionReturns(lua_State* L) if (!tfft) luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - if (FFlag::LuauUserTypeFunGenerics) - { - tfft->retTypes = getTypePack(L, 2, 3); - } - else - { - std::vector head{}; - if (lua_istable(L, 2)) - { - int argSize = lua_objlen(L, 2); - for (int i = 1; i <= argSize; i++) - { - lua_pushinteger(L, i); - lua_gettable(L, 2); - - if (lua_isnil(L, -1)) - { - lua_pop(L, 1); - break; - } - - TypeFunctionTypeId ty = getTypeUserData(L, -1); - head.push_back(ty); - - lua_pop(L, 1); // Remove `ty` from stack - } - } - else if (!lua_isnoneornil(L, 2)) - luaL_typeerrorL(L, 2, "table"); - - std::optional tail; - if (auto type = optionalTypeUserData(L, 3)) - tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); - - if (head.size() == 0 && tail.has_value()) // Make retTypes a variadic type pack - tfft->retTypes = *tail; - else // Make retTypes a type pack - tfft->retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); - } + tfft->retTypes = getTypePack(L, 2, 3); return 0; } @@ -1284,59 +1050,7 @@ static int getFunctionReturns(lua_State* L) if (!tfft) luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); - if (FFlag::LuauUserTypeFunGenerics) - { - pushTypePack(L, tfft->retTypes); - } - else - { - if (auto tftp = get(tfft->retTypes)) - { - int size = 0; - if (tftp->head.size() > 0) - size++; - if (tftp->tail.has_value()) - size++; - - lua_createtable(L, 0, size); - - int argSize = (int)tftp->head.size(); - if (argSize > 0) - { - lua_createtable(L, argSize, 0); - for (int i = 0; i < argSize; i++) - { - allocTypeUserData(L, tftp->head[i]->type); - lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed - } - lua_setfield(L, -2, "head"); - } - - if (tftp->tail.has_value()) - { - auto tfvp = get(*tftp->tail); - if (!tfvp) - LUAU_ASSERT(!"We should only be supporting variadic packs as TypeFunctionTypePack.tail at the moment"); - - allocTypeUserData(L, tfvp->type->type); - lua_setfield(L, -2, "tail"); - } - - return 1; - } - - if (auto tfvp = get(tfft->retTypes)) - { - lua_createtable(L, 0, 1); - - allocTypeUserData(L, tfvp->type->type); - lua_setfield(L, -2, "tail"); - - return 1; - } - - lua_createtable(L, 0, 0); - } + pushTypePack(L, tfft->retTypes); return 1; } @@ -1761,9 +1475,9 @@ void registerTypesLibrary(lua_State* L) {"unionof", createUnion}, {"intersectionof", createIntersection}, {"newtable", createTable}, - {"newfunction", FFlag::LuauUserTypeFunGenerics ? createFunction : createFunction_DEPRECATED}, + {"newfunction", createFunction}, {"copy", deepCopy}, - {FFlag::LuauUserTypeFunGenerics ? "generic" : nullptr, FFlag::LuauUserTypeFunGenerics ? createGeneric : nullptr}, + {"generic", createGeneric}, {nullptr, nullptr} }; @@ -1838,12 +1552,12 @@ void registerTypeUserData(lua_State* L) {"parent", getClassParent}, // Function type methods (cont.) - {FFlag::LuauUserTypeFunGenerics ? "setgenerics" : nullptr, FFlag::LuauUserTypeFunGenerics ? setFunctionGenerics : nullptr}, - {FFlag::LuauUserTypeFunGenerics ? "generics" : nullptr, FFlag::LuauUserTypeFunGenerics ? getFunctionGenerics : nullptr}, + {"setgenerics", setFunctionGenerics}, + {"generics", getFunctionGenerics}, // Generic type methods - {FFlag::LuauUserTypeFunGenerics ? "name" : nullptr, FFlag::LuauUserTypeFunGenerics ? getGenericName : nullptr}, - {FFlag::LuauUserTypeFunGenerics ? "ispack" : nullptr, FFlag::LuauUserTypeFunGenerics ? getGenericIsPack : nullptr}, + {"name", getGenericName}, + {"ispack", getGenericIsPack}, {nullptr, nullptr} }; @@ -2097,25 +1811,22 @@ bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunc if (seenSetContains(seen, &lhs, &rhs)) return true; - if (FFlag::LuauUserTypeFunGenerics) + if (lhs.generics.size() != rhs.generics.size()) + return false; + + for (auto l = lhs.generics.begin(), r = rhs.generics.begin(); l != lhs.generics.end() && r != rhs.generics.end(); ++l, ++r) { - if (lhs.generics.size() != rhs.generics.size()) + if (!areEqual(seen, **l, **r)) return false; + } - for (auto l = lhs.generics.begin(), r = rhs.generics.begin(); l != lhs.generics.end() && r != rhs.generics.end(); ++l, ++r) - { - if (!areEqual(seen, **l, **r)) - return false; - } + if (lhs.genericPacks.size() != rhs.genericPacks.size()) + return false; - if (lhs.genericPacks.size() != rhs.genericPacks.size()) + for (auto l = lhs.genericPacks.begin(), r = rhs.genericPacks.begin(); l != lhs.genericPacks.end() && r != rhs.genericPacks.end(); ++l, ++r) + { + if (!areEqual(seen, **l, **r)) return false; - - for (auto l = lhs.genericPacks.begin(), r = rhs.genericPacks.begin(); l != lhs.genericPacks.end() && r != rhs.genericPacks.end(); ++l, ++r) - { - if (!areEqual(seen, **l, **r)) - return false; - } } if (bool(lhs.argTypes) != bool(rhs.argTypes)) @@ -2218,14 +1929,11 @@ bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType return areEqual(seen, *lf, *rf); } - if (FFlag::LuauUserTypeFunGenerics) { - { - const TypeFunctionGenericType* lg = get(&lhs); - const TypeFunctionGenericType* rg = get(&rhs); - if (lg && rg) - return lg->isNamed == rg->isNamed && lg->isPack == rg->isPack && lg->name == rg->name; - } + const TypeFunctionGenericType* lg = get(&lhs); + const TypeFunctionGenericType* rg = get(&rhs); + if (lg && rg) + return lg->isNamed == rg->isNamed && lg->isPack == rg->isPack && lg->name == rg->name; } return false; @@ -2274,14 +1982,11 @@ bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunct return areEqual(seen, *lv, *rv); } - if (FFlag::LuauUserTypeFunGenerics) { - { - const TypeFunctionGenericTypePack* lg = get(&lhs); - const TypeFunctionGenericTypePack* rg = get(&rhs); - if (lg && rg) - return lg->isNamed == rg->isNamed && lg->name == rg->name; - } + const TypeFunctionGenericTypePack* lg = get(&lhs); + const TypeFunctionGenericTypePack* rg = get(&rhs); + if (lg && rg) + return lg->isNamed == rg->isNamed && lg->name == rg->name; } return false; @@ -2510,7 +2215,7 @@ private: } else if (auto c = get(ty)) target = ty; // Don't copy a class since they are immutable - else if (auto g = get(ty); FFlag::LuauUserTypeFunGenerics && g) + else if (auto g = get(ty)) target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->isNamed, g->isPack, g->name}); else LUAU_ASSERT(!"Unknown type"); @@ -2531,7 +2236,7 @@ private: target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); else if (auto vPack = get(tp)) target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); - else if (auto gPack = get(tp); gPack && FFlag::LuauUserTypeFunGenerics) + else if (auto gPack = get(tp)) target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->isNamed, gPack->name}); else LUAU_ASSERT(!"Unknown type"); @@ -2565,8 +2270,7 @@ private: cloneChildren(f1, f2); else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) cloneChildren(c1, c2); - else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; - FFlag::LuauUserTypeFunGenerics && g1 && g2) + else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; g1 && g2) cloneChildren(g1, g2); else LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types @@ -2580,7 +2284,7 @@ private: vPack1 && vPack2) cloneChildren(vPack1, vPack2); else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; - FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) + gPack1 && gPack2) cloneChildren(gPack1, gPack2); else LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types @@ -2662,16 +2366,13 @@ private: void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) { - if (FFlag::LuauUserTypeFunGenerics) - { - f2->generics.reserve(f1->generics.size()); - for (auto ty : f1->generics) - f2->generics.push_back(shallowClone(ty)); + f2->generics.reserve(f1->generics.size()); + for (auto ty : f1->generics) + f2->generics.push_back(shallowClone(ty)); - f2->genericPacks.reserve(f1->genericPacks.size()); - for (auto tp : f1->genericPacks) - f2->genericPacks.push_back(shallowClone(tp)); - } + f2->genericPacks.reserve(f1->genericPacks.size()); + for (auto tp : f1->genericPacks) + f2->genericPacks.push_back(shallowClone(tp)); f2->argTypes = shallowClone(f1->argTypes); f2->retTypes = shallowClone(f1->retTypes); @@ -2692,11 +2393,8 @@ private: for (TypeFunctionTypeId& ty : t1->head) t2->head.push_back(shallowClone(ty)); - if (FFlag::LuauUserTypeFunCloneTail) - { - if (t1->tail) - t2->tail = shallowClone(*t1->tail); - } + if (t1->tail) + t2->tail = shallowClone(*t1->tail); } void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp index 6b7fa419..e9d8e41f 100644 --- a/Analysis/src/TypeFunctionRuntimeBuilder.cpp +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -20,8 +20,6 @@ // currently, controls serialization, deserialization, and `type.copy` LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); -LUAU_FASTFLAG(LuauUserTypeFunGenerics) - namespace Luau { @@ -212,7 +210,7 @@ private: state->classesSerialized[c->name] = ty; target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, c->name}); } - else if (auto g = get(ty); FFlag::LuauUserTypeFunGenerics && g) + else if (auto g = get(ty)) { Name name = g->name; @@ -245,7 +243,7 @@ private: target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); else if (auto vPack = get(tp)) target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); - else if (auto gPack = get(tp); FFlag::LuauUserTypeFunGenerics && gPack) + else if (auto gPack = get(tp)) { Name name = gPack->name; @@ -291,8 +289,7 @@ private: serializeChildren(f1, f2); else if (auto [c1, c2] = std::tuple{get(ty), getMutable(tfti)}; c1 && c2) serializeChildren(c1, c2); - else if (auto [g1, g2] = std::tuple{get(ty), getMutable(tfti)}; - FFlag::LuauUserTypeFunGenerics && g1 && g2) + else if (auto [g1, g2] = std::tuple{get(ty), getMutable(tfti)}; g1 && g2) serializeChildren(g1, g2); else { // Either this or ty and tfti do not represent the same type @@ -307,8 +304,7 @@ private: serializeChildren(tPack1, tPack2); else if (auto [vPack1, vPack2] = std::tuple{get(tp), getMutable(tftp)}; vPack1 && vPack2) serializeChildren(vPack1, vPack2); - else if (auto [gPack1, gPack2] = std::tuple{get(tp), getMutable(tftp)}; - FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) + else if (auto [gPack1, gPack2] = std::tuple{get(tp), getMutable(tftp)}; gPack1 && gPack2) serializeChildren(gPack1, gPack2); else { // Either this or ty and tfti do not represent the same type @@ -399,16 +395,13 @@ private: void serializeChildren(const FunctionType* f1, TypeFunctionFunctionType* f2) { - if (FFlag::LuauUserTypeFunGenerics) - { - f2->generics.reserve(f1->generics.size()); - for (auto ty : f1->generics) - f2->generics.push_back(shallowSerialize(ty)); + f2->generics.reserve(f1->generics.size()); + for (auto ty : f1->generics) + f2->generics.push_back(shallowSerialize(ty)); - f2->genericPacks.reserve(f1->genericPacks.size()); - for (auto tp : f1->genericPacks) - f2->genericPacks.push_back(shallowSerialize(tp)); - } + f2->genericPacks.reserve(f1->genericPacks.size()); + for (auto tp : f1->genericPacks) + f2->genericPacks.push_back(shallowSerialize(tp)); f2->argTypes = shallowSerialize(f1->argTypes); f2->retTypes = shallowSerialize(f1->retTypes); @@ -573,14 +566,11 @@ private: deserializeChildren(tfti, ty); - if (FFlag::LuauUserTypeFunGenerics) + // If we have completed working on all children of a function, remove the generic parameters from scope + if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty()) { - // If we have completed working on all children of a function, remove the generic parameters from scope - if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty()) - { - closeFunctionScope(functionScopes.back().function); - functionScopes.pop_back(); - } + closeFunctionScope(functionScopes.back().function); + functionScopes.pop_back(); } } } @@ -702,7 +692,7 @@ private: else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized"); } - else if (auto g = get(ty); FFlag::LuauUserTypeFunGenerics && g) + else if (auto g = get(ty)) { if (g->isPack) { @@ -752,7 +742,7 @@ private: { target = state->ctx->arena->addTypePack(VariadicTypePack{}); } - else if (auto gPack = get(tp); FFlag::LuauUserTypeFunGenerics && gPack) + else if (auto gPack = get(tp)) { auto it = std::find_if( genericPacks.rbegin(), @@ -809,8 +799,7 @@ private: deserializeChildren(f2, f1); else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) deserializeChildren(c2, c1); - else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; - FFlag::LuauUserTypeFunGenerics && g1 && g2) + else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; g1 && g2) deserializeChildren(g2, g1); else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); @@ -823,8 +812,7 @@ private: else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; vPack1 && vPack2) deserializeChildren(vPack2, vPack1); - else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; - FFlag::LuauUserTypeFunGenerics && gPack1 && gPack2) + else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; gPack1 && gPack2) deserializeChildren(gPack2, gPack1); else state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); @@ -909,64 +897,60 @@ private: void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) { - if (FFlag::LuauUserTypeFunGenerics) + functionScopes.push_back({queue.size(), f2}); + + std::set> genericNames; + + // Introduce generic function parameters into scope + for (auto ty : f2->generics) { - functionScopes.push_back({queue.size(), f2}); + auto gty = get(ty); + LUAU_ASSERT(gty && !gty->isPack); - std::set> genericNames; + std::pair nameKey = std::make_pair(gty->isNamed, gty->name); - // Introduce generic function parameters into scope - for (auto ty : f2->generics) + // Duplicates are not allowed + if (genericNames.find(nameKey) != genericNames.end()) { - auto gty = get(ty); - LUAU_ASSERT(gty && !gty->isPack); - - std::pair nameKey = std::make_pair(gty->isNamed, gty->name); - - // Duplicates are not allowed - if (genericNames.find(nameKey) != genericNames.end()) - { - state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str())); - return; - } - - genericNames.insert(nameKey); - - TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{})); - genericTypes.push_back({gty->isNamed, gty->name, mapping}); + state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str())); + return; } - for (auto tp : f2->genericPacks) - { - auto gtp = get(tp); - LUAU_ASSERT(gtp); + genericNames.insert(nameKey); - std::pair nameKey = std::make_pair(gtp->isNamed, gtp->name); - - // Duplicates are not allowed - if (genericNames.find(nameKey) != genericNames.end()) - { - state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str())); - return; - } - - genericNames.insert(nameKey); - - TypePackId mapping = - state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{}) - ); - genericPacks.push_back({gtp->isNamed, gtp->name, mapping}); - } - - f1->generics.reserve(f2->generics.size()); - for (auto ty : f2->generics) - f1->generics.push_back(shallowDeserialize(ty)); - - f1->genericPacks.reserve(f2->genericPacks.size()); - for (auto tp : f2->genericPacks) - f1->genericPacks.push_back(shallowDeserialize(tp)); + TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{})); + genericTypes.push_back({gty->isNamed, gty->name, mapping}); } + for (auto tp : f2->genericPacks) + { + auto gtp = get(tp); + LUAU_ASSERT(gtp); + + std::pair nameKey = std::make_pair(gtp->isNamed, gtp->name); + + // Duplicates are not allowed + if (genericNames.find(nameKey) != genericNames.end()) + { + state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str())); + return; + } + + genericNames.insert(nameKey); + + TypePackId mapping = + state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{})); + genericPacks.push_back({gtp->isNamed, gtp->name, mapping}); + } + + f1->generics.reserve(f2->generics.size()); + for (auto ty : f2->generics) + f1->generics.push_back(shallowDeserialize(ty)); + + f1->genericPacks.reserve(f2->genericPacks.size()); + for (auto tp : f2->genericPacks) + f1->genericPacks.push_back(shallowDeserialize(tp)); + if (f2->argTypes) f1->argTypes = shallowDeserialize(f2->argTypes); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 25d1f5c2..73f8b1be 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5913,8 +5913,8 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, - const AstArray& genericPackNames, + const AstArray& genericNames, + const AstArray& genericPackNames, bool useCache ) { @@ -5924,14 +5924,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( std::vector generics; - for (const AstGenericType& generic : genericNames) + for (const AstGenericType* generic : genericNames) { std::optional defaultValue; - if (generic.defaultValue) - defaultValue = resolveType(scope, *generic.defaultValue); + if (generic->defaultValue) + defaultValue = resolveType(scope, *generic->defaultValue); - Name n = generic.name.value; + Name n = generic->name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic types have the same name. @@ -5960,14 +5960,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( std::vector genericPacks; - for (const AstGenericTypePack& genericPack : genericPackNames) + for (const AstGenericTypePack* genericPack : genericPackNames) { std::optional defaultValue; - if (genericPack.defaultValue) - defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + if (genericPack->defaultValue) + defaultValue = resolveTypePack(scope, *genericPack->defaultValue); - Name n = genericPack.name.value; + Name n = genericPack->name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic types have the same name. diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index d4764656..34f0072e 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -120,20 +120,6 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName -struct AstGenericType -{ - AstName name; - Location location; - AstType* defaultValue = nullptr; -}; - -struct AstGenericTypePack -{ - AstName name; - Location location; - AstTypePack* defaultValue = nullptr; -}; - extern int gAstRttiIndex; template @@ -253,6 +239,32 @@ public: bool hasSemicolon; }; +class AstGenericType : public AstNode +{ +public: + LUAU_RTTI(AstGenericType) + + explicit AstGenericType(const Location& location, AstName name, AstType* defaultValue = nullptr); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstType* defaultValue = nullptr; +}; + +class AstGenericTypePack : public AstNode +{ +public: + LUAU_RTTI(AstGenericTypePack) + + explicit AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue = nullptr); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstTypePack* defaultValue = nullptr; +}; + class AstExprGroup : public AstExpr { public: @@ -424,8 +436,8 @@ public: AstExprFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, @@ -443,8 +455,8 @@ public: bool hasNativeAttribute() const; AstArray attributes; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; std::optional returnAnnotation; @@ -857,8 +869,8 @@ public: const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported ); @@ -867,8 +879,8 @@ public: AstName name; Location nameLocation; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -911,8 +923,8 @@ public: const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -925,8 +937,8 @@ public: const AstArray& attributes, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -942,8 +954,8 @@ public: AstArray attributes; AstName name; Location nameLocation; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; bool vararg = false; @@ -1074,8 +1086,8 @@ public: AstTypeFunction( const Location& location, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -1084,8 +1096,8 @@ public: AstTypeFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -1096,8 +1108,8 @@ public: bool isCheckedFunction() const; AstArray attributes; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; @@ -1276,6 +1288,16 @@ public: return visit(static_cast(node)); } + virtual bool visit(class AstGenericType* node) + { + return visit(static_cast(node)); + } + + virtual bool visit(class AstGenericTypePack* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); diff --git a/Ast/include/Luau/Cst.h b/Ast/include/Luau/Cst.h index bea3df90..af3198cc 100644 --- a/Ast/include/Luau/Cst.h +++ b/Ast/include/Luau/Cst.h @@ -141,6 +141,16 @@ public: Position opPosition; }; +class CstExprTypeAssertion : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprTypeAssertion) + + explicit CstExprTypeAssertion(Position opPosition); + + Position opPosition; +}; + class CstExprIfElse : public CstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 584782ee..8c7fac74 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -222,8 +222,8 @@ private: AstType* parseFunctionTypeTail( const Lexeme& begin, const AstArray& attributes, - AstArray generics, - AstArray genericPacks, + AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation @@ -294,7 +294,7 @@ private: Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); + std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' Type[, ...] `>' AstArray parseTypeParams( @@ -474,8 +474,8 @@ private: std::vector scratchItem; std::vector scratchCstItem; std::vector scratchArgName; - std::vector scratchGenericTypes; - std::vector scratchGenericTypePacks; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; std::vector scratchPosition; std::string scratchData; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 5fa63149..ab42ec8c 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -28,6 +28,38 @@ void AstAttr::visit(AstVisitor* visitor) int gAstRttiIndex = 0; +AstGenericType::AstGenericType(const Location& location, AstName name, AstType* defaultValue) + : AstNode(ClassIndex(), location) + , name(name) + , defaultValue(defaultValue) +{ +} + +void AstGenericType::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + if (defaultValue) + defaultValue->visit(visitor); + } +} + +AstGenericTypePack::AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue) + : AstNode(ClassIndex(), location) + , name(name) + , defaultValue(defaultValue) +{ +} + +void AstGenericTypePack::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + if (defaultValue) + defaultValue->visit(visitor); + } +} + AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) : AstExpr(ClassIndex(), location) , expr(expr) @@ -185,8 +217,8 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) AstExprFunction::AstExprFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, @@ -721,8 +753,8 @@ AstStatTypeAlias::AstStatTypeAlias( const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported ) @@ -740,16 +772,14 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (const AstGenericType& el : generics) + for (AstGenericType* el : generics) { - if (el.defaultValue) - el.defaultValue->visit(visitor); + el->visit(visitor); } - for (const AstGenericTypePack& el : genericPacks) + for (AstGenericTypePack* el : genericPacks) { - if (el.defaultValue) - el.defaultValue->visit(visitor); + el->visit(visitor); } type->visit(visitor); @@ -795,8 +825,8 @@ AstStatDeclareFunction::AstStatDeclareFunction( const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -822,8 +852,8 @@ AstStatDeclareFunction::AstStatDeclareFunction( const AstArray& attributes, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -970,8 +1000,8 @@ void AstTypeTable::visit(AstVisitor* visitor) AstTypeFunction::AstTypeFunction( const Location& location, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -990,8 +1020,8 @@ AstTypeFunction::AstTypeFunction( AstTypeFunction::AstTypeFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes diff --git a/Ast/src/Cst.cpp b/Ast/src/Cst.cpp index e2faf6e7..9f951ef2 100644 --- a/Ast/src/Cst.cpp +++ b/Ast/src/Cst.cpp @@ -50,6 +50,12 @@ CstExprOp::CstExprOp(Position opPosition) { } +CstExprTypeAssertion::CstExprTypeAssertion(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + CstExprIfElse::CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf) : CstNode(CstClassIndex()) , thenPosition(thenPosition) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3fa0ccc9..1fce2216 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1066,8 +1066,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; @@ -2035,8 +2035,8 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray AstType* Parser::parseFunctionTypeTail( const Lexeme& begin, const AstArray& attributes, - AstArray generics, - AstArray genericPacks, + AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation @@ -2824,9 +2824,18 @@ AstExpr* Parser::parseAssertionExpr() if (lexer.current().type == Lexeme::DoubleColon) { + CstExprTypeAssertion* cstNode = nullptr; + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + Position opPosition = lexer.current().location.begin; + cstNode = allocator.alloc(opPosition); + } nextLexeme(); AstType* annotation = parseType(); - return allocator.alloc(Location(start, annotation->location), expr, annotation); + AstExprTypeAssertion* node = allocator.alloc(Location(start, annotation->location), expr, annotation); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[node] = cstNode; + return node; } else return expr; @@ -3305,10 +3314,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) +std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) { - TempVector names{scratchGenericTypes}; - TempVector namePacks{scratchGenericTypePacks}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { @@ -3340,7 +3349,7 @@ std::pair, AstArray> Parser::parseG { AstTypePack* typePack = parseTypePack(); - namePacks.push_back({name, nameLocation, typePack}); + namePacks.push_back(allocator.alloc(nameLocation, name, typePack)); } else { @@ -3349,7 +3358,7 @@ std::pair, AstArray> Parser::parseG if (type) report(type->location, "Expected type pack after '=', got type"); - namePacks.push_back({name, nameLocation, typePack}); + namePacks.push_back(allocator.alloc(nameLocation, name, typePack)); } } else @@ -3357,7 +3366,7 @@ std::pair, AstArray> Parser::parseG if (seenDefault) report(lexer.current().location, "Expected default type pack after type pack name"); - namePacks.push_back({name, nameLocation, nullptr}); + namePacks.push_back(allocator.alloc(nameLocation, name, nullptr)); } } else @@ -3369,14 +3378,14 @@ std::pair, AstArray> Parser::parseG AstType* defaultType = parseType(); - names.push_back({name, nameLocation, defaultType}); + names.push_back(allocator.alloc(nameLocation, name, defaultType)); } else { if (seenDefault) report(lexer.current().location, "Expected default type after type name"); - names.push_back({name, nameLocation, nullptr}); + names.push_back(allocator.alloc(nameLocation, name, nullptr)); } } @@ -3397,8 +3406,8 @@ std::pair, AstArray> Parser::parseG expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 29cf5c05..5ef2bf93 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4298,8 +4298,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c AstExprFunction main( root->location, /* attributes= */ AstArray({nullptr, 0}), - /* generics= */ AstArray(), - /* genericPacks= */ AstArray(), + /* generics= */ AstArray(), + /* genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index e251447b..34a27f4f 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -6,10 +6,10 @@ namespace Luau { -static bool isGeneric(AstName name, const AstArray& generics) +static bool isGeneric(AstName name, const AstArray& generics) { - for (const AstGenericType& gt : generics) - if (gt.name == name) + for (const AstGenericType* gt : generics) + if (gt->name == name) return true; return false; @@ -39,7 +39,7 @@ static LuauBytecodeType getPrimitiveType(AstName name) static LuauBytecodeType getType( const AstType* ty, - const AstArray& generics, + const AstArray& generics, const DenseHashMap& typeAliases, bool resolveAliases, const char* hostVectorType, diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 4262eb49..1d23b155 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLibWhereErrorAutoreserve) + // convert a stack index to positive #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -67,6 +69,7 @@ static l_noret tag_error(lua_State* L, int narg, int tag) luaL_typeerrorL(L, narg, lua_typename(L, tag)); } +// Can be called without stack space reservation void luaL_where(lua_State* L, int level) { lua_Debug ar; @@ -75,9 +78,14 @@ void luaL_where(lua_State* L, int level) lua_pushfstring(L, "%s:%d: ", ar.short_src, ar.currentline); return; } + + if (FFlag::LuauLibWhereErrorAutoreserve) + lua_rawcheckstack(L, 1); + lua_pushliteral(L, ""); // else, no information available... } +// Can be called without stack space reservation l_noret luaL_errorL(lua_State* L, const char* fmt, ...) { va_list argp; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index b5a4bd13..5c9402f9 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStringFormatFixC, false) + // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -999,8 +1001,17 @@ static int str_format(lua_State* L) { case 'c': { - snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); - break; + if (DFFlag::LuauStringFormatFixC) + { + int count = snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); + luaL_addlstring(&b, buff, count); + continue; // skip the 'luaL_addlstring' at the end + } + else + { + snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); + break; + } } case 'd': case 'i': diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 60c9b313..53cbd07b 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -31,6 +31,7 @@ extern int optimizationLevel; void luaC_fullgc(lua_State* L); void luaC_validate(lua_State* L); +LUAU_FASTFLAG(LuauLibWhereErrorAutoreserve) LUAU_FASTFLAG(LuauMathLerp) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) @@ -40,7 +41,7 @@ LUAU_FASTFLAG(LuauVectorLibNativeDot) LUAU_FASTFLAG(LuauVector2Constructor) LUAU_FASTFLAG(LuauBufferBitMethods2) LUAU_FASTFLAG(LuauCodeGenLimitLiveSlotReuse) -LUAU_FASTFLAG(LuauMathMapDefinition) +LUAU_DYNAMIC_FASTFLAG(LuauStringFormatFixC) static lua_CompileOptions defaultOptions() { @@ -718,6 +719,8 @@ TEST_CASE("Clear") TEST_CASE("Strings") { + ScopedFastFlag luauStringFormatFixC{DFFlag::LuauStringFormatFixC, true}; + runConformance("strings.lua"); } @@ -988,7 +991,6 @@ TEST_CASE("Types") { ScopedFastFlag luauVector2Constructor{FFlag::LuauVector2Constructor, true}; ScopedFastFlag luauMathLerp{FFlag::LuauMathLerp, true}; - ScopedFastFlag luauMathMapDefinition{FFlag::LuauMathMapDefinition, true}; runConformance( "types.lua", @@ -1719,7 +1721,31 @@ TEST_CASE("ApiBuffer") lua_pop(L, 1); } -TEST_CASE("AllocApi") +int slowlyOverflowStack(lua_State* L) +{ + for (int i = 0; i < LUAI_MAXCSTACK * 2; i++) + { + luaL_checkstack(L, 1, "test"); + lua_pushnumber(L, 1.0); + } + + return 0; +} + +TEST_CASE("ApiStack") +{ + ScopedFastFlag luauLibWhereErrorAutoreserve{FFlag::LuauLibWhereErrorAutoreserve, true}; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_pushcfunction(L, slowlyOverflowStack, "foo"); + int result = lua_pcall(L, 0, 0, 0); + REQUIRE(result == LUA_ERRRUN); + CHECK(strcmp(luaL_checkstring(L, -1), "stack overflow (test)") == 0); +} + +TEST_CASE("ApiAlloc") { int ud = 0; StateRef globalState(lua_newstate(limitedRealloc, &ud), lua_close); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index f613e750..b7d67d74 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -6,6 +6,7 @@ #include "Luau/Ast.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Error.h" #include "Luau/IostreamHelpers.h" #include "Luau/ModuleResolver.h" #include "Luau/VisitType.h" @@ -16,6 +17,8 @@ LUAU_FASTFLAG(LuauCountSelfCallsNonstrict) LUAU_FASTFLAG(LuauVector2Constructor) +LUAU_FASTFLAG(LuauNewNonStrictWarnOnUnknownGlobals) +LUAU_FASTFLAG(LuauNonStrictVisitorImprovements) using namespace Luau; @@ -490,6 +493,40 @@ foo.bar("hi") NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 8), "foo.bar", result); } +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "exprgroup_is_checked") +{ + ScopedFastFlag sff{FFlag::LuauNonStrictVisitorImprovements, true}; + + CheckResult result = checkNonStrict(R"( + local foo = (abs("foo")) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto r1 = get(result.errors[0]); + LUAU_ASSERT(r1); + CHECK_EQ("abs", r1->checkedFunctionName); + CHECK_EQ("number", toString(r1->expected)); + CHECK_EQ("string", toString(r1->passed)); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "binop_is_checked") +{ + ScopedFastFlag sff{FFlag::LuauNonStrictVisitorImprovements, true}; + + CheckResult result = checkNonStrict(R"( + local foo = 4 + abs("foo") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto r1 = get(result.errors[0]); + LUAU_ASSERT(r1); + CHECK_EQ("abs", r1->checkedFunctionName); + CHECK_EQ("number", toString(r1->expected)); + CHECK_EQ("string", toString(r1->passed)); +} + TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "incorrect_arg_count") { CheckResult result = checkNonStrict(R"( @@ -602,4 +639,22 @@ TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_method_calls") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unknown_globals_in_non_strict") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauNonStrictVisitorImprovements, true}, + {FFlag::LuauNewNonStrictWarnOnUnknownGlobals, true} + }; + + CheckResult result = check(Mode::Nonstrict, R"( + foo = 5 + local wrong1 = foob + + local x = 12 + local wrong2 = x + foblm + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index ebbfd4dc..d726dfb2 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -944,6 +944,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "type_assertion_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = 5 :: number"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = "local a = 5 :: number"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") { std::string code = "local a = if 1 then 2 else 3"; @@ -1408,6 +1418,10 @@ TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") TEST_CASE_FIXTURE(Fixture, "transpile_string_interp") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = `hello {name}` )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -1452,6 +1466,10 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline_escape") TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = ` bracket = \{, backtick = \` = {'ok'} ` )"; CHECK_EQ(code, transpile(code, {}, true).code); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index bdde63f5..8f48961d 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -8,9 +8,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserTypeFunFixInner) -LUAU_FASTFLAG(LuauUserTypeFunGenerics) -LUAU_FASTFLAG(LuauUserTypeFunCloneTail) LUAU_FASTFLAG(DebugLuauEqSatSimplification) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); @@ -475,7 +472,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") TEST_CASE_FIXTURE(ClassFixture, "udtf_negation_inner") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunFixInner{FFlag::LuauUserTypeFunFixInner, true}; CheckResult result = check(R"( type function pass(t) @@ -1404,7 +1400,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_no_result") TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_1") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1422,7 +1417,6 @@ local function ok(idx: pass): test return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_2") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1440,7 +1434,6 @@ local function ok(idx: pass): test return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_3") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1462,7 +1455,6 @@ local function ok(idx: pass): test return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_cloning_1") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1480,8 +1472,6 @@ local function ok(idx: pass): test return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_cloning_2") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; - ScopedFastFlag luauUserTypeFunCloneTail{FFlag::LuauUserTypeFunCloneTail, true}; CheckResult result = check(R"( type function pass(arg) @@ -1499,7 +1489,6 @@ local function ok(idx: pass): test return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_equality") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1517,7 +1506,6 @@ local function ok(idx: pass): true return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_1") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1537,7 +1525,6 @@ local function ok(idx: pass): (T) -> (T) return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_2") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1561,7 +1548,6 @@ local function ok(idx: pass): (T, T) -> (T) return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_3") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass() @@ -1591,7 +1577,6 @@ local function ok(idx: pass<>): (T, U...) -> (T, V...) return idx TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_4") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass() @@ -1618,7 +1603,6 @@ local function ok(idx: pass<>): test return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_5") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass() @@ -1635,7 +1619,6 @@ local function ok(idx: pass<>): (T) -> () return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_6") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1663,7 +1646,6 @@ local function ok(idx: pass): (T) -> (U) return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_7") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1686,7 +1668,6 @@ local function ok(idx: pass): (T, U...) -> (T, U...) return idx e TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_8") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1709,7 +1690,6 @@ local function ok(idx: pass): (T, T) -> (T) return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_equality_2") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1730,7 +1710,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_1") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1750,7 +1729,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_2") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1767,7 +1745,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_3") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1789,7 +1766,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_4") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1806,7 +1782,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_5") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1823,7 +1798,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_6") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1840,7 +1814,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_7") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function get() @@ -1857,7 +1830,6 @@ local function ok(idx: get<>): false return idx end TEST_CASE_FIXTURE(ClassFixture, "udtf_variadic_api") { ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; - ScopedFastFlag luauUserTypeFunGenerics{FFlag::LuauUserTypeFunGenerics, true}; CheckResult result = check(R"( type function pass(arg) @@ -1878,7 +1850,7 @@ local function ok(idx: pass): (number, ...string) -> (string, ...number) r TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_eqsat_opaque") { - ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauUserTypeFunGenerics, true}, {FFlag::DebugLuauEqSatSimplification, true}}; + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::DebugLuauEqSatSimplification, true}}; CheckResult _ = check(R"( type function t0(a) diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index ce1cef29..84cfb97b 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; LUAU_FASTFLAG(LuauNewSolverPrePopulateClasses) LUAU_FASTFLAG(LuauClipNestedAndRecursiveUnion) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauPreventReentrantTypeFunctionReduction) TEST_SUITE_BEGIN("DefinitionTests"); @@ -557,5 +559,32 @@ TEST_CASE_FIXTURE(Fixture, "recursive_redefinition_reduces_rightfully") )")); } +TEST_CASE_FIXTURE(Fixture, "vector3_overflow") +{ + ScopedFastFlag _{FFlag::LuauPreventReentrantTypeFunctionReduction, true}; + // We set this to zero to ensure that we either run to completion or stack overflow here. + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; + + loadDefinition(R"( + declare class Vector3 + function __add(self, other: Vector3): Vector3 + end + )"); + + CheckResult result = check(R"( +--!strict +local function graphPoint(t : number, points : { Vector3 }) : Vector3 + local n : number = #points - 1 + local p : Vector3 = (nil :: any) + for i = 0, n do + local x = points[i + 1] + p = p and p + x or x + end + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 535b9961..8f816bc2 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -9,7 +9,6 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(DebugLuauEqSatSimplification) -LUAU_FASTFLAG(InferGlobalTypes) LUAU_FASTFLAG(LuauGeneralizationRemoveRecursiveUpperBound) using namespace Luau; @@ -1881,8 +1880,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "refine_a_param_that_got_resolved_duri TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") { - ScopedFastFlag sff{FFlag::InferGlobalTypes, true}; - CheckResult result = check(R"( foo = { bar = 5 :: number? } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index ea893528..4cf027b1 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -23,6 +23,9 @@ LUAU_FASTFLAG(LuauAllowNilAssignmentToIndexer) LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) LUAU_FASTFLAG(LuauTrackInteriorFreeTablesOnScope) LUAU_FASTFLAG(LuauDontInPlaceMutateTableType) +LUAU_FASTFLAG(LuauAllowNonSharedTableTypesInLiteral) +LUAU_FASTFLAG(LuauFollowTableFreeze) +LUAU_FASTFLAG(LuauPrecalculateMutatedFreeTypes) TEST_SUITE_BEGIN("TableTests"); @@ -5005,17 +5008,22 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_union_type") TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager") { - ScopedFastFlag _{FFlag::LuauSolverV2, true}; + // NOTE: All of these examples should have no errors, but + // bidirectional inference is known to be broken. + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPrecalculateMutatedFreeTypes, true}, + }; - // CLI-121540: All of these examples should have no errors. - - LUAU_CHECK_ERROR_COUNT(3, check(R"( + auto result = check(R"( local function doTheThing(_: { [string]: unknown }) end doTheThing({ ['foo'] = 5, ['bar'] = 'heyo', }) - )")); + )"); + LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_CHECK_NO_ERROR(result, ConstraintSolvingIncompleteError); LUAU_CHECK_ERROR_COUNT(1, check(R"( type Input = { [string]: unknown } @@ -5028,7 +5036,7 @@ TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager") // This example previously asserted due to eagerly mutating the underlying // table type. - LUAU_CHECK_ERROR_COUNT(3, check(R"( + result = check(R"( type Input = { [string]: unknown } local function doTheThing(_: Input) end @@ -5037,7 +5045,9 @@ TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager") [('%s'):format('3.14')]=5, ['stringField']='Heyo' }) - )")); + )"); + LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_CHECK_NO_ERROR(result, ConstraintSolvingIncompleteError); } @@ -5091,4 +5101,56 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "multiple_fields_in_literal") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "multiple_fields_from_fuzzer") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontInPlaceMutateTableType, true}, + {FFlag::LuauAllowNonSharedTableTypesInLiteral, true}, + }; + + // This would trigger an assert previously, so we really only care that + // there are errors (and there will be: lots of syntax errors). + LUAU_CHECK_ERRORS(check(R"( + function _(l0,l0) _(_,{n0=_,n0=_,},if l0:n0()[_] then _) + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "write_only_table_field_duplicate") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontInPlaceMutateTableType, true}, + {FFlag::LuauAllowNonSharedTableTypesInLiteral, true}, + }; + + auto result = check(R"( + type WriteOnlyTable = { write x: number } + local wo: WriteOnlyTable = { + x = 42, + x = 13, + } + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + CHECK_EQ("write keyword is illegal here", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_musnt_assert") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauFollowTableFreeze, true}, + }; + + auto result = check(R"( + local m = {} + function m.foo() + local self = { entries = entries, _caches = {}} + local self = setmetatable(self, {}) + table.freeze(self) + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 2ff97a25..9bb22798 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,15 +16,16 @@ #include -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTINT(LuauNormalizeCacheLimit); -LUAU_FASTINT(LuauRecursionLimit); -LUAU_FASTINT(LuauTypeInferRecursionLimit); -LUAU_FASTFLAG(InferGlobalTypes) +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTINT(LuauCheckRecursionLimit) +LUAU_FASTINT(LuauNormalizeCacheLimit) +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauAstTypeGroup) +LUAU_FASTFLAG(LuauNewNonStrictWarnOnUnknownGlobals) +LUAU_FASTFLAG(LuauInferLocalTypesInMultipleAssignments) using namespace Luau; @@ -819,7 +820,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") end )"); - if (FFlag::LuauSolverV2) + if (FFlag::LuauSolverV2 && !FFlag::LuauNewNonStrictWarnOnUnknownGlobals) LUAU_REQUIRE_NO_ERRORS(result); else LUAU_REQUIRE_ERRORS(result); @@ -1770,7 +1771,6 @@ TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") TEST_CASE_FIXTURE(BuiltinsFixture, "infer_types_of_globals") { ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; - ScopedFastFlag sff_InferGlobalTypes{FFlag::InferGlobalTypes, true}; CheckResult result = check(R"( --!strict @@ -1784,4 +1784,25 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "infer_types_of_globals") CHECK_EQ("Unknown global 'foo'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "multiple_assignment") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag sff_InferLocalTypesInMultipleAssignments{FFlag::LuauInferLocalTypesInMultipleAssignments, true}; + + CheckResult result = check(R"( + local function requireString(arg: string) end + local function requireNumber(arg: number) end + + local function f(): ...number end + + local w: "a", x, y, z = "a", 1, f() + requireString(w) + requireNumber(x) + requireNumber(y) + requireNumber(z) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.lua index 370641d9..857a4bb9 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.lua @@ -61,7 +61,7 @@ assert(#"\0\0\0" == 3) assert(#"1234567890" == 10) assert(string.byte("a") == 97) -assert(string.byte("á") > 127) +assert(string.byte("\xe4") > 127) assert(string.byte(string.char(255)) == 255) assert(string.byte(string.char(0)) == 0) assert(string.byte("\0") == 0) @@ -76,10 +76,10 @@ assert(string.byte("hi", 9, 10) == nil) assert(string.byte("hi", 2, 1) == nil) assert(string.char() == "") assert(string.char(0, 255, 0) == "\0\255\0") -assert(string.char(0, string.byte("á"), 0) == "\0á\0") -assert(string.char(string.byte("ál\0óu", 1, -1)) == "ál\0óu") -assert(string.char(string.byte("ál\0óu", 1, 0)) == "") -assert(string.char(string.byte("ál\0óu", -10, 100)) == "ál\0óu") +assert(string.char(0, string.byte("\xe4"), 0) == "\0\xe4\0") +assert(string.char(string.byte("\xe4l\0óu", 1, -1)) == "\xe4l\0óu") +assert(string.char(string.byte("\xe4l\0óu", 1, 0)) == "") +assert(string.char(string.byte("\xe4l\0óu", -10, 100)) == "\xe4l\0óu") assert(pcall(function() return string.char(256) end) == false) assert(pcall(function() return string.char(-1) end) == false) print('+') @@ -87,7 +87,7 @@ print('+') assert(string.upper("ab\0c") == "AB\0C") assert(string.lower("\0ABCc%$") == "\0abcc%$") assert(string.rep('teste', 0) == '') -assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') +assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') assert(string.rep('', 10) == '') assert(string.rep('', 1e9) == '') assert(pcall(string.rep, 'x', 2e9) == false) @@ -115,15 +115,18 @@ assert(pcall(function() return tostring(nothing()) end) == false) print('+') -x = '"ílo"\n\\' -assert(string.format('%q%s', x, x) == '"\\"ílo\\"\\\n\\\\""ílo"\n\\') +x = '"ílo"\n\\' +assert(string.format('%q%s', x, x) == '"\\"ílo\\"\\\n\\\\""ílo"\n\\') assert(string.format('%q', "\0") == [["\000"]]) assert(string.format('%q', "\r") == [["\r"]]) -assert(string.format("\0%c\0%c%x\0", string.byte("á"), string.byte("b"), 140) == - "\0á\0b8c\0") +assert(string.format("\0%c\0%c%x\0", string.byte("\xe4"), string.byte("b"), 140) == + "\0\xe4\0b8c\0") assert(string.format('') == "") assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) == string.format("%c%c%c%c", 34, 48, 90, 100)) +assert(string.format("%c%c%c%c", 1, 0, 2, 3) == '\1\0\2\3') +assert(string.format("%5c%5c%5c%5c", 1, 0, 2, 3) == ' \1 \0 \2 \3') +assert(string.format("%-5c%-5c%-5c%-5c", 1, 0, 2, 3) == '\1 \0 \2 \3 ') assert(string.format("%s\0 is not \0%s", 'not be', 'be') == 'not be\0 is not \0be') assert(string.format("%%%d %010d", 10, 23) == "%10 0000000023") assert(tonumber(string.format("%f", 10.3)) == 10.3) @@ -184,7 +187,7 @@ assert(pcall(function() string.format("%#*", "bad form") end) == false) -assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) +assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) assert(table.concat{} == "") @@ -244,16 +247,16 @@ end if not trylocale("collate") then print("locale not supported") else - assert("alo" < "álo" and "álo" < "amo") + assert("alo" < "álo" and "álo" < "amo") end if not trylocale("ctype") then print("locale not supported") else - assert(string.gsub("áéíóú", "%a", "x") == "xxxxx") - assert(string.gsub("áÁéÉ", "%l", "x") == "xÁxÉ") - assert(string.gsub("áÁéÉ", "%u", "x") == "áxéx") - assert(string.upper"áÁé{xuxu}ção" == "ÁÁÉ{XUXU}ÇÃO") + assert(string.gsub("áéíóú", "%a", "x") == "xxxxx") + assert(string.gsub("áÃéÉ", "%l", "x") == "xÃxÉ") + assert(string.gsub("áÃéÉ", "%u", "x") == "áxéx") + assert(string.upper"áÃé{xuxu}ção" == "ÃÃÉ{XUXU}ÇÃO") end os.setlocale("C")