From 816e41a8f2eae7ca2d9d9fb4d2f847e33c087650 Mon Sep 17 00:00:00 2001 From: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com> Date: Thu, 10 Nov 2022 14:53:13 -0800 Subject: [PATCH 1/2] Sync to upstream/release/553 (#742) * Type inference of `a and b` and `a or b` has been improved (Fixes https://github.com/Roblox/luau/issues/730) * Improved error message when `for ... in x` loop iterates over a value that could be 'nil' * Return type of `next` not includes 'nil' (Fixes https://github.com/Roblox/luau/issues/706) * Improved type inference of `string` type * Luau library table type names are now reported as `typeof(string)`/etc instead of just `string` which was misleading * Added parsing error when optional type alias type parameter wasn't provided after `=` token * Improved tagged union type refinement in conditional expressions, type in `else` branch should no longer include previously handled union options --- .../include/Luau/ConstraintGraphBuilder.h | 20 +- Analysis/include/Luau/ConstraintSolver.h | 4 +- Analysis/include/Luau/Scope.h | 5 - Analysis/include/Luau/ToString.h | 2 - Analysis/include/Luau/TypeInfer.h | 4 +- Analysis/src/BuiltinDefinitions.cpp | 94 ++++++-- Analysis/src/ConstraintGraphBuilder.cpp | 191 ++++++++++++---- Analysis/src/ConstraintSolver.cpp | 108 +++++---- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 8 +- Analysis/src/Frontend.cpp | 3 +- Analysis/src/Module.cpp | 14 -- Analysis/src/ToString.cpp | 110 ++++----- Analysis/src/TypeChecker2.cpp | 35 ++- Analysis/src/TypeInfer.cpp | 202 +++++++++++++---- Analysis/src/TypeVar.cpp | 204 +++++++++++++++-- Analysis/src/Unifier.cpp | 43 +++- Ast/src/Parser.cpp | 13 +- CodeGen/include/Luau/AddressA64.h | 7 +- CodeGen/include/Luau/AssemblyBuilderA64.h | 29 ++- CodeGen/src/AssemblyBuilderA64.cpp | 210 ++++++++++++++---- CodeGen/src/CodeGenX64.cpp | 74 +++--- CodeGen/src/EmitCommonX64.h | 8 +- CodeGen/src/UnwindBuilderDwarf2.cpp | 20 +- CodeGen/src/UnwindBuilderWin.cpp | 6 +- Common/include/Luau/ExperimentalFlags.h | 2 + tests/AssemblyBuilderA64.test.cpp | 65 +++++- tests/CodeAllocator.test.cpp | 9 +- tests/ConstraintGraphBuilderFixture.cpp | 4 +- tests/Frontend.test.cpp | 8 - tests/Parser.test.cpp | 17 ++ tests/Repl.test.cpp | 17 ++ tests/ToString.test.cpp | 71 +++--- tests/TypeInfer.aliases.test.cpp | 13 +- tests/TypeInfer.builtins.test.cpp | 6 +- tests/TypeInfer.functions.test.cpp | 2 +- tests/TypeInfer.operators.test.cpp | 83 ++++++- tests/TypeInfer.provisional.test.cpp | 26 +-- tests/TypeInfer.refinements.test.cpp | 189 +++++++++++++--- tests/TypeInfer.tables.test.cpp | 99 ++++++++- tests/TypeInfer.test.cpp | 37 --- tests/TypeInfer.typePacks.cpp | 2 - tests/TypeInfer.unionTypes.test.cpp | 17 ++ tests/TypeInfer.unknownnever.test.cpp | 4 +- tools/faillist.txt | 32 +-- 44 files changed, 1535 insertions(+), 582 deletions(-) diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index cb5900ea..2bfc62f1 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -66,6 +66,14 @@ struct ConstraintGraphBuilder // The root scope of the module we're generating constraints for. // This is null when the CGB is initially constructed. Scope* rootScope; + + // Constraints that go straight to the solver. + std::vector constraints; + + // Constraints that do not go to the solver right away. Other constraints + // will enqueue them during solving. + std::vector unqueuedConstraints; + // A mapping of AST node to TypeId. DenseHashMap astTypes{nullptr}; // A mapping of AST node to TypePackId. @@ -252,16 +260,8 @@ struct ConstraintGraphBuilder void prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program); }; -/** - * Collects a vector of borrowed constraints from the scope and all its child - * scopes. It is important to only call this function when you're done adding - * constraints to the scope or its descendants, lest the borrowed pointers - * become invalid due to a container reallocation. - * @param rootScope the root scope of the scope graph to collect constraints - * from. - * @return a list of pointers to constraints contained within the scope graph. - * None of these pointers should be null. +/** Borrow a vector of pointers from a vector of owning pointers to constraints. */ -std::vector> collectConstraints(NotNull rootScope); +std::vector> borrowConstraints(const std::vector& constraints); } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 07f027ad..7b89a278 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -76,8 +76,8 @@ struct ConstraintSolver DcrLogger* logger; - explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, NotNull moduleResolver, - std::vector requireCycles, DcrLogger* logger); + explicit ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger); // Randomize the order in which to dispatch constraints void randomize(unsigned seed); diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index ccf2964c..a26f506d 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -38,11 +38,6 @@ struct Scope std::unordered_map bindings; TypePackId returnType; std::optional varargPack; - // All constraints belonging to this scope. - std::vector constraints; - // Constraints belonging to this scope that are queued manually by other - // constraints. - std::vector unqueuedConstraints; TypeLevel level; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index ff2561e6..0200a719 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -34,7 +34,6 @@ struct ToStringOptions size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); ToStringNameMap nameMap; - std::optional DEPRECATED_nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden }; @@ -42,7 +41,6 @@ struct ToStringOptions struct ToStringResult { std::string name; - ToStringNameMap DEPRECATED_nameMap; bool invalid = false; bool error = false; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index c5d7501d..4eaa5969 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -280,14 +280,14 @@ private: TypeId singletonType(bool value); TypeId singletonType(std::string value); - TypeIdPredicate mkTruthyPredicate(bool sense); + TypeIdPredicate mkTruthyPredicate(bool sense, TypeId emptySetTy); // TODO: Return TypeId only. std::optional filterMapImpl(TypeId type, TypeIdPredicate predicate); std::pair, bool> filterMap(TypeId type, TypeIdPredicate predicate); public: - std::pair, bool> pickTypesFromSense(TypeId type, bool sense); + std::pair, bool> pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy); private: TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index ee53ae6b..67e3979a 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -17,7 +17,9 @@ LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauBuiltInMetatableNoBadSynthetic, false) +LUAU_FASTFLAG(LuauOptionalNextKey) LUAU_FASTFLAG(LuauReportShadowedTypeAlias) +LUAU_FASTFLAG(LuauNewLibraryTypeNames) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -276,18 +278,38 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - addGlobalBinding(typeChecker, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); + addGlobalBinding(typeChecker, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + addGlobalBinding(typeChecker, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -319,7 +341,12 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) if (TableTypeVar* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - ttv->name = toString(pair.first); + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } } } @@ -370,18 +397,38 @@ void registerBuiltinGlobals(Frontend& frontend) addGlobalBinding(frontend, "string", it->second.type, "@luau"); - // next(t: Table, i: K?) -> (K, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - addGlobalBinding(frontend, "next", - arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + if (FFlag::LuauOptionalNextKey) + { + // next(t: Table, i: K?) -> (K?, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); + addGlobalBinding(frontend, "next", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); - TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, nextRetsTypePack}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); - // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } + else + { + // next(t: Table, i: K?) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); + addGlobalBinding(frontend, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.singletonTypes->nilType}}); + + // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) + addGlobalBinding( + frontend, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + } TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); @@ -413,7 +460,12 @@ void registerBuiltinGlobals(Frontend& frontend) if (TableTypeVar* ttv = getMutable(pair.second.typeId)) { if (!ttv->name) - ttv->name = toString(pair.first); + { + if (FFlag::LuauNewLibraryTypeNames) + ttv->name = "typeof(" + toString(pair.first) + ")"; + else + ttv->name = toString(pair.first); + } } } @@ -623,7 +675,7 @@ static std::optional> magicFunctionAssert( if (head.size() > 0) { - auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true); + auto [ty, ok] = typechecker.pickTypesFromSense(head[0], true, typechecker.singletonTypes->nilType); if (FFlag::LuauUnknownAndNeverType) { if (get(*ty)) diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 79a69ca4..42dc07f6 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -52,6 +52,70 @@ static bool matchSetmetatable(const AstExprCall& call) return true; } +struct TypeGuard +{ + bool isTypeof; + AstExpr* target; + std::string type; +}; + +static std::optional matchTypeGuard(const AstExprBinary* binary) +{ + if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + return std::nullopt; + + AstExpr* left = binary->left; + AstExpr* right = binary->right; + if (right->is()) + std::swap(left, right); + + if (!right->is()) + return std::nullopt; + + AstExprCall* call = left->as(); + AstExprConstantString* string = right->as(); + if (!call || !string) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + return TypeGuard{ + /*isTypeof*/ callee->name == "typeof", + /*target*/ call->args.data[0], + /*type*/ std::string(string->value.data, string->value.size), + }; +} + +namespace +{ + +struct Checkpoint +{ + size_t offset; +}; + +Checkpoint checkpoint(const ConstraintGraphBuilder* cgb) +{ + return Checkpoint{cgb->constraints.size()}; +} + +template +void forEachConstraint(const Checkpoint& start, const Checkpoint& end, const ConstraintGraphBuilder* cgb, F f) +{ + for (size_t i = start.offset; i < end.offset; ++i) + f(cgb->constraints[i]); +} + +} // namespace + ConstraintGraphBuilder::ConstraintGraphBuilder(const ModuleName& moduleName, ModulePtr module, TypeArena* arena, NotNull moduleResolver, NotNull singletonTypes, NotNull ice, const ScopePtr& globalScope, DcrLogger* logger, NotNull dfg) @@ -99,12 +163,12 @@ ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& paren NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, const Location& location, ConstraintV cv) { - return NotNull{scope->constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; + return NotNull{constraints.emplace_back(new Constraint{NotNull{scope.get()}, location, std::move(cv)}).get()}; } NotNull ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr c) { - return NotNull{scope->constraints.emplace_back(std::move(c)).get()}; + return NotNull{constraints.emplace_back(std::move(c)).get()}; } static void unionRefinements(const std::unordered_map& lhs, const std::unordered_map& rhs, @@ -476,6 +540,9 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatForIn* forIn) TypeId ty = freshType(loopScope); loopScope->bindings[var] = Binding{ty, var->location}; variableTypes.push_back(ty); + + if (auto def = dfg->getDef(var)) + loopScope->dcrRefinements[*def] = ty; } // It is always ok to provide too few variables, so we give this pack a free tail. @@ -506,20 +573,6 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat) check(repeatScope, repeat->condition); } -void addConstraints(Constraint* constraint, NotNull scope) -{ - scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); - - for (const auto& c : scope->constraints) - constraint->dependencies.push_back(NotNull{c.get()}); - - for (const auto& c : scope->unqueuedConstraints) - constraint->dependencies.push_back(NotNull{c.get()}); - - for (NotNull childScope : scope->children) - addConstraints(constraint, childScope); -} - void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* function) { // Local @@ -537,12 +590,17 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction* FunctionSignature sig = checkFunctionSignature(scope, function->func); sig.bodyScope->bindings[function->name] = Binding{sig.signature, function->func->location}; + auto start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); + auto end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - addConstraints(c.get(), NotNull(sig.bodyScope.get())); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); addConstraint(scope, std::move(c)); } @@ -610,12 +668,17 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct LUAU_ASSERT(functionType != nullptr); + auto start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); + auto end = checkpoint(this); NotNull constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}; std::unique_ptr c = std::make_unique(constraintScope, function->name->location, GeneralizationConstraint{functionType, sig.signature}); - addConstraints(c.get(), NotNull(sig.bodyScope.get())); + + forEachConstraint(start, end, this, [&c](const ConstraintPtr& constraint) { + c->dependencies.push_back(NotNull{constraint.get()}); + }); addConstraint(scope, std::move(c)); } @@ -947,8 +1010,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExpr* InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCall* call, const std::vector& expectedTypes) { TypeId fnType = check(scope, call->func).ty; - const size_t constraintIndex = scope->constraints.size(); - const size_t scopeIndex = scopes.size(); + auto startCheckpoint = checkpoint(this); std::vector args; @@ -977,8 +1039,7 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa } else { - const size_t constraintEndIndex = scope->constraints.size(); - const size_t scopeEndIndex = scopes.size(); + auto endCheckpoint = checkpoint(this); astOriginalCallTypes[call->func] = fnType; @@ -989,29 +1050,22 @@ InferencePack ConstraintGraphBuilder::checkPack(const ScopePtr& scope, AstExprCa FunctionTypeVar ftv(TypeLevel{}, scope.get(), argPack, rets); TypeId inferredFnType = arena->addType(ftv); - scope->unqueuedConstraints.push_back( + unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, InstantiationConstraint{instantiatedType, fnType})); - NotNull ic(scope->unqueuedConstraints.back().get()); + NotNull ic(unqueuedConstraints.back().get()); - scope->unqueuedConstraints.push_back( + unqueuedConstraints.push_back( std::make_unique(NotNull{scope.get()}, call->func->location, SubtypeConstraint{inferredFnType, instantiatedType})); - NotNull sc(scope->unqueuedConstraints.back().get()); + NotNull sc(unqueuedConstraints.back().get()); // We force constraints produced by checking function arguments to wait // until after we have resolved the constraint on the function itself. // This ensures, for instance, that we start inferring the contents of // lambdas under the assumption that their arguments and return types // will be compatible with the enclosing function call. - for (size_t ci = constraintIndex; ci < constraintEndIndex; ++ci) - scope->constraints[ci]->dependencies.push_back(sc); - - for (size_t si = scopeIndex; si < scopeEndIndex; ++si) - { - for (auto& c : scopes[si].second->constraints) - { - c->dependencies.push_back(sc); - } - } + forEachConstraint(startCheckpoint, endCheckpoint, this, [sc](const ConstraintPtr& constraint) { + constraint->dependencies.push_back(sc); + }); addConstraint(scope, call->func->location, FunctionCallConstraint{ @@ -1283,6 +1337,54 @@ std::tuple ConstraintGraphBuilder::checkBinary( return {leftType, rightType, connectiveArena.disjunction(leftConnective, rightConnective)}; } + else if (auto typeguard = matchTypeGuard(binary)) + { + TypeId leftType = check(scope, binary->left).ty; + TypeId rightType = check(scope, binary->right).ty; + + std::optional def = dfg->getDef(typeguard->target); + if (!def) + return {leftType, rightType, nullptr}; + + TypeId discriminantTy = singletonTypes->neverType; + if (typeguard->type == "nil") + discriminantTy = singletonTypes->nilType; + else if (typeguard->type == "string") + discriminantTy = singletonTypes->stringType; + else if (typeguard->type == "number") + discriminantTy = singletonTypes->numberType; + else if (typeguard->type == "boolean") + discriminantTy = singletonTypes->threadType; + else if (typeguard->type == "table") + discriminantTy = singletonTypes->neverType; // TODO: replace with top table type + else if (typeguard->type == "function") + discriminantTy = singletonTypes->functionType; + else if (typeguard->type == "userdata") + { + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof + discriminantTy = singletonTypes->neverType; // TODO: replace with top class type + } + else if (!typeguard->isTypeof && typeguard->type == "vector") + discriminantTy = singletonTypes->neverType; // TODO: figure out a way to deal with this quirky type + else if (!typeguard->isTypeof) + discriminantTy = singletonTypes->neverType; + else if (auto typeFun = globalScope->lookupType(typeguard->type); typeFun && typeFun->typeParams.empty() && typeFun->typePackParams.empty()) + { + TypeId ty = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(ty); !ctv || !ctv->parent) + discriminantTy = ty; + } + + ConnectiveId proposition = connectiveArena.proposition(*def, discriminantTy); + if (binary->op == AstExprBinary::CompareEq) + return {leftType, rightType, proposition}; + else if (binary->op == AstExprBinary::CompareNe) + return {leftType, rightType, connectiveArena.negation(proposition)}; + else + ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); + } else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) { TypeId leftType = check(scope, binary->left, expectedType, true).ty; @@ -2066,19 +2168,14 @@ void ConstraintGraphBuilder::prepopulateGlobalScope(const ScopePtr& globalScope, program->visit(&gp); } -void collectConstraints(std::vector>& result, NotNull scope) -{ - for (const auto& c : scope->constraints) - result.push_back(NotNull{c.get()}); - - for (NotNull child : scope->children) - collectConstraints(result, child); -} - -std::vector> collectConstraints(NotNull rootScope) +std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; - collectConstraints(result, rootScope); + result.reserve(constraints.size()); + + for (const auto& c : constraints) + result.emplace_back(c.get()); + return result; } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index c53ac659..533652e2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); -LUAU_FASTFLAG(LuauFixNameMaps) namespace Luau { @@ -27,34 +26,14 @@ namespace Luau { for (const auto& [k, v] : scope->bindings) { - if (FFlag::LuauFixNameMaps) - { - auto d = toString(v.typeId, opts); - printf("\t%s : %s\n", k.c_str(), d.c_str()); - } - else - { - auto d = toStringDetailed(v.typeId, opts); - opts.DEPRECATED_nameMap = d.DEPRECATED_nameMap; - printf("\t%s : %s\n", k.c_str(), d.name.c_str()); - } + auto d = toString(v.typeId, opts); + printf("\t%s : %s\n", k.c_str(), d.c_str()); } for (NotNull child : scope->children) dumpBindings(child, opts); } -static void dumpConstraints(NotNull scope, ToStringOptions& opts) -{ - for (const ConstraintPtr& c : scope->constraints) - { - printf("\t%s\n", toString(*c, opts).c_str()); - } - - for (NotNull child : scope->children) - dumpConstraints(child, opts); -} - static std::pair, std::vector> saturateArguments(TypeArena* arena, NotNull singletonTypes, const TypeFun& fn, const std::vector& rawTypeArguments, const std::vector& rawPackArguments) { @@ -219,12 +198,6 @@ size_t HashInstantiationSignature::operator()(const InstantiationSignature& sign return hash; } -void dump(NotNull rootScope, ToStringOptions& opts) -{ - printf("constraints:\n"); - dumpConstraints(rootScope, opts); -} - void dump(ConstraintSolver* cs, ToStringOptions& opts) { printf("constraints:\n"); @@ -248,17 +221,17 @@ void dump(ConstraintSolver* cs, ToStringOptions& opts) if (auto fcc = get(*c)) { for (NotNull inner : fcc->innerConstraints) - printf("\t\t\t%s\n", toString(*inner, opts).c_str()); + printf("\t ->\t\t%s\n", toString(*inner, opts).c_str()); } } } -ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, ModuleName moduleName, - NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) +ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull rootScope, std::vector> constraints, + ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger) : arena(normalizer->arena) , singletonTypes(normalizer->singletonTypes) , normalizer(normalizer) - , constraints(collectConstraints(rootScope)) + , constraints(std::move(constraints)) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) , moduleResolver(moduleResolver) @@ -267,7 +240,7 @@ ConstraintSolver::ConstraintSolver(NotNull normalizer, NotNull c : constraints) + for (NotNull c : this->constraints) { unsolvedConstraints.push_back(c); @@ -310,6 +283,8 @@ void ConstraintSolver::run() { printf("Starting solver\n"); dump(this, opts); + printf("Bindings:\n"); + dumpBindings(rootScope, opts); } if (FFlag::DebugLuauLogSolverToJson) @@ -633,13 +608,6 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullty.emplace(unionOfTypes(rightType, singletonTypes->booleanType, constraint->scope, false)); + { + TypeId leftFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->falsyType, leftType}}); + + // TODO: normaliztion here should be replaced by a more limited 'simplification' + const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{leftFilteredTy, rightType}})); + + if (!normalized) + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + } + else + { + asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); + } + unblock(resultType); return true; + } // Or evaluates to the LHS type if the LHS is truthy, and the RHS type if // LHS is falsey. case AstExprBinary::Op::Or: - asMutable(resultType)->ty.emplace(unionOfTypes(rightType, leftType, constraint->scope, true)); + { + TypeId rightFilteredTy = arena->addType(IntersectionTypeVar{{singletonTypes->truthyType, leftType}}); + + // TODO: normaliztion here should be replaced by a more limited 'simplification' + const NormalizedType* normalized = normalizer->normalize(arena->addType(UnionTypeVar{{rightFilteredTy, rightType}})); + + if (!normalized) + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(resultType)->ty.emplace(errorRecoveryType()); + } + else + { + asMutable(resultType)->ty.emplace(normalizer->typeFromNormal(*normalized)); + } + unblock(resultType); return true; + } default: iceReporter.ice("Unhandled AstExprBinary::Op for binary operation", constraint->location); break; @@ -1148,6 +1148,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) + block(ic, blockedConstraint); + } + } + asMutable(*c.innerConstraints.at(0)).c = InstantiationConstraint{instantiatedType, *callMm}; asMutable(*c.innerConstraints.at(1)).c = SubtypeConstraint{inferredFnType, instantiatedType}; @@ -1180,6 +1191,17 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNullsecond) + block(ic, blockedConstraint); + } + } + unsolvedConstraints.insert(end(unsolvedConstraints), begin(c.innerConstraints), end(c.innerConstraints)); asMutable(c.result)->ty.emplace(constraint->scope); } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 339de975..b0f21737 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -2,6 +2,7 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(LuauUnknownAndNeverType) +LUAU_FASTFLAG(LuauOptionalNextKey) namespace Luau { @@ -126,7 +127,7 @@ declare function rawlen(obj: {[K]: V} | string): number declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? -declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) +-- TODO: place ipairs definition here with removal of FFlagLuauOptionalNextKey declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) @@ -207,6 +208,11 @@ std::string getBuiltinDefinitionSource() else result += "declare function error(message: T, level: number?)\n"; + if (FFlag::LuauOptionalNextKey) + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number?, V), {V}, number)\n"; + else + result += "declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number)\n"; + return result; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 39e6428d..22a9ecfa 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -955,7 +955,8 @@ ModulePtr Frontend::check( cgb.visit(sourceModule.root); result->errors = std::move(cgb.errors); - ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), sourceModule.name, NotNull(&moduleResolver), requireCycles, logger.get()}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull(cgb.rootScope), borrowConstraints(cgb.constraints), sourceModule.name, NotNull(&moduleResolver), + requireCycles, logger.get()}; if (options.randomizeConstraintResolutionSeed) cs.randomize(*options.randomizeConstraintResolutionSeed); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 0412f007..62674aa8 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,9 +14,7 @@ #include -LUAU_FASTFLAG(LuauAnyifyModuleReturnGenerics) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); -LUAU_FASTFLAGVARIABLE(LuauForceExportSurfacesToBeNormal, false); LUAU_FASTFLAGVARIABLE(LuauClonePublicInterfaceLess, false); LUAU_FASTFLAG(LuauSubstitutionReentrant); LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution); @@ -222,18 +220,6 @@ void Module::clonePublicInterface(NotNull singletonTypes, Intern } } - if (!FFlag::LuauAnyifyModuleReturnGenerics) - { - for (TypeId ty : returnType) - { - if (get(follow(ty))) - { - auto t = asMutable(ty); - t->ty = AnyTypeVar{}; - } - } - } - for (auto& [name, ty] : declaredGlobals) { if (FFlag::LuauClonePublicInterfaceLess) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 903e156b..5fd15cf7 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauLvaluelessPath) LUAU_FASTFLAG(LuauUnknownAndNeverType) -LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) LUAU_FASTFLAGVARIABLE(LuauFunctionReturnStringificationFixup, false) LUAU_FASTFLAGVARIABLE(LuauUnseeArrayTtv, false) @@ -130,28 +129,15 @@ struct StringifierState bool exhaustive; - StringifierState(ToStringOptions& opts, ToStringResult& result, const std::optional& DEPRECATED_nameMap) + StringifierState(ToStringOptions& opts, ToStringResult& result) : opts(opts) , result(result) , exhaustive(opts.exhaustive) { - if (!FFlag::LuauFixNameMaps && DEPRECATED_nameMap) - result.DEPRECATED_nameMap = *DEPRECATED_nameMap; - - if (!FFlag::LuauFixNameMaps) - { - for (const auto& [_, v] : result.DEPRECATED_nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : result.DEPRECATED_nameMap.typePacks) - usedNames.insert(v); - } - else - { - for (const auto& [_, v] : opts.nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : opts.nameMap.typePacks) - usedNames.insert(v); - } + for (const auto& [_, v] : opts.nameMap.typeVars) + usedNames.insert(v); + for (const auto& [_, v] : opts.nameMap.typePacks) + usedNames.insert(v); } bool hasSeen(const void* tv) @@ -174,8 +160,8 @@ struct StringifierState std::string getName(TypeId ty) { - const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars.size() : result.DEPRECATED_nameMap.typeVars.size(); - std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars[ty] : result.DEPRECATED_nameMap.typeVars[ty]; + const size_t s = opts.nameMap.typeVars.size(); + std::string& n = opts.nameMap.typeVars[ty]; if (!n.empty()) return n; @@ -197,8 +183,8 @@ struct StringifierState std::string getName(TypePackId ty) { - const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks.size() : result.DEPRECATED_nameMap.typePacks.size(); - std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks[ty] : result.DEPRECATED_nameMap.typePacks[ty]; + const size_t s = opts.nameMap.typePacks.size(); + std::string& n = opts.nameMap.typePacks[ty]; if (!n.empty()) return n; @@ -405,10 +391,7 @@ struct TypeVarStringifier if (gtv.explicitName) { state.usedNames.insert(gtv.name); - if (FFlag::LuauFixNameMaps) - state.opts.nameMap.typeVars[ty] = gtv.name; - else - state.result.DEPRECATED_nameMap.typeVars[ty] = gtv.name; + state.opts.nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } else @@ -1002,10 +985,7 @@ struct TypePackStringifier if (pack.explicitName) { state.usedNames.insert(pack.name); - if (FFlag::LuauFixNameMaps) - state.opts.nameMap.typePacks[tp] = pack.name; - else - state.result.DEPRECATED_nameMap.typePacks[tp] = pack.name; + state.opts.nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } else @@ -1108,8 +1088,7 @@ ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts) ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; std::set cycles; std::set cycleTPs; @@ -1216,8 +1195,7 @@ ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) * 4. Print out the root of the type using the same algorithm as step 3. */ ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; std::set cycles; std::set cycleTPs; @@ -1303,8 +1281,7 @@ std::string toString(const TypePackVar& tp, ToStringOptions& opts) std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) { ToStringResult result; - StringifierState state = - FFlag::LuauFixNameMaps ? StringifierState{opts, result, opts.nameMap} : StringifierState{opts, result, opts.DEPRECATED_nameMap}; + StringifierState state{opts, result}; TypeVarStringifier tvs{state}; state.emit(funcName); @@ -1436,91 +1413,84 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) auto go = [&opts](auto&& c) -> std::string { using T = std::decay_t; - // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps - auto tos = [](auto&& a, ToStringOptions& opts) { - if (FFlag::LuauFixNameMaps) - return toString(a, opts); - else - { - ToStringResult tsr = toStringDetailed(a, opts); - opts.DEPRECATED_nameMap = std::move(tsr.DEPRECATED_nameMap); - return tsr.name; - } + auto tos = [&opts](auto&& a) + { + return toString(a, opts); }; if constexpr (std::is_same_v) { - std::string subStr = tos(c.subType, opts); - std::string superStr = tos(c.superType, opts); + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.subPack, opts); - std::string superStr = tos(c.superPack, opts); + std::string subStr = tos(c.subPack); + std::string superStr = tos(c.superPack); return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.generalizedType, opts); - std::string superStr = tos(c.sourceType, opts); + std::string subStr = tos(c.generalizedType); + std::string superStr = tos(c.sourceType); return subStr + " ~ gen " + superStr; } else if constexpr (std::is_same_v) { - std::string subStr = tos(c.subType, opts); - std::string superStr = tos(c.superType, opts); + std::string subStr = tos(c.subType); + std::string superStr = tos(c.superType); return subStr + " ~ inst " + superStr; } else if constexpr (std::is_same_v) { - std::string resultStr = tos(c.resultType, opts); - std::string operandStr = tos(c.operandType, opts); + std::string resultStr = tos(c.resultType); + std::string operandStr = tos(c.operandType); return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; } else if constexpr (std::is_same_v) { - std::string resultStr = tos(c.resultType, opts); - std::string leftStr = tos(c.leftType, opts); - std::string rightStr = tos(c.rightType, opts); + std::string resultStr = tos(c.resultType); + std::string leftStr = tos(c.leftType); + std::string rightStr = tos(c.rightType); return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; } else if constexpr (std::is_same_v) { - std::string iteratorStr = tos(c.iterator, opts); - std::string variableStr = tos(c.variables, opts); + std::string iteratorStr = tos(c.iterator); + std::string variableStr = tos(c.variables); return variableStr + " ~ Iterate<" + iteratorStr + ">"; } else if constexpr (std::is_same_v) { - std::string namedStr = tos(c.namedType, opts); + std::string namedStr = tos(c.namedType); return "@name(" + namedStr + ") = " + c.name; } else if constexpr (std::is_same_v) { - std::string targetStr = tos(c.target, opts); + std::string targetStr = tos(c.target); return "expand " + targetStr; } else if constexpr (std::is_same_v) { - return "call " + tos(c.fn, opts) + " with { result = " + tos(c.result, opts) + " }"; + return "call " + tos(c.fn) + " with { result = " + tos(c.result) + " }"; } else if constexpr (std::is_same_v) { - return tos(c.resultType, opts) + " ~ prim " + tos(c.expectedType, opts) + ", " + tos(c.singletonType, opts) + ", " + - tos(c.multitonType, opts); + return tos(c.resultType) + " ~ prim " + tos(c.expectedType) + ", " + tos(c.singletonType) + ", " + + tos(c.multitonType); } else if constexpr (std::is_same_v) { - return tos(c.resultType, opts) + " ~ hasProp " + tos(c.subjectType, opts) + ", \"" + c.prop + "\""; + return tos(c.resultType) + " ~ hasProp " + tos(c.subjectType) + ", \"" + c.prop + "\""; } else if constexpr (std::is_same_v) { - std::string result = tos(c.resultType, opts); - std::string discriminant = tos(c.discriminantType, opts); + std::string result = tos(c.resultType); + std::string discriminant = tos(c.discriminantType); return result + " ~ if isSingleton D then ~D else unknown where D = " + discriminant; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index dde41a65..03575c40 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -610,7 +610,7 @@ struct TypeChecker2 visit(rhs); TypeId rhsType = lookupType(rhs); - if (!isSubtype(rhsType, lhsType, stack.back(), singletonTypes, ice)) + if (!isSubtype(rhsType, lhsType, stack.back())) { reportError(TypeMismatch{lhsType, rhsType}, rhs->location); } @@ -761,7 +761,7 @@ struct TypeChecker2 TypeId actualType = lookupType(number); TypeId numberType = singletonTypes->numberType; - if (!isSubtype(numberType, actualType, stack.back(), singletonTypes, ice)) + if (!isSubtype(numberType, actualType, stack.back())) { reportError(TypeMismatch{actualType, numberType}, number->location); } @@ -772,7 +772,7 @@ struct TypeChecker2 TypeId actualType = lookupType(string); TypeId stringType = singletonTypes->stringType; - if (!isSubtype(actualType, stringType, stack.back(), singletonTypes, ice)) + if (!isSubtype(actualType, stringType, stack.back())) { reportError(TypeMismatch{actualType, stringType}, string->location); } @@ -861,7 +861,7 @@ struct TypeChecker2 FunctionTypeVar ftv{argsTp, expectedRetType}; TypeId expectedType = arena.addType(ftv); - if (!isSubtype(testFunctionType, expectedType, stack.back(), singletonTypes, ice)) + if (!isSubtype(testFunctionType, expectedType, stack.back())) { CloneState cloneState; expectedType = clone(expectedType, module->internalTypes, cloneState); @@ -880,7 +880,7 @@ struct TypeChecker2 getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true); if (ty) { - if (!isSubtype(resultType, *ty, stack.back(), singletonTypes, ice)) + if (!isSubtype(resultType, *ty, stack.back())) { reportError(TypeMismatch{resultType, *ty}, indexName->location); } @@ -913,7 +913,7 @@ struct TypeChecker2 TypeId inferredArgTy = *argIt; TypeId annotatedArgTy = lookupAnnotation(arg->annotation); - if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), singletonTypes, ice)) + if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back())) { reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); } @@ -954,7 +954,7 @@ struct TypeChecker2 if (const FunctionTypeVar* ftv = get(follow(*mm))) { TypePackId expectedArgs = module->internalTypes.addTypePack({operandType}); - reportErrors(tryUnify(scope, expr->location, ftv->argTypes, expectedArgs)); + reportErrors(tryUnify(scope, expr->location, expectedArgs, ftv->argTypes)); if (std::optional ret = first(ftv->retTypes)) { @@ -1096,7 +1096,7 @@ struct TypeChecker2 expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::Op::CompareLe || expr->op == AstExprBinary::Op::CompareLt) { TypePackId expectedRets = module->internalTypes.addTypePack({singletonTypes->booleanType}); - if (!isSubtype(ftv->retTypes, expectedRets, scope, singletonTypes, ice)) + if (!isSubtype(ftv->retTypes, expectedRets, scope)) { reportError(GenericError{format("Metamethod '%s' must return type 'boolean'", it->second)}, expr->location); } @@ -1207,10 +1207,10 @@ struct TypeChecker2 TypeId computedType = lookupType(expr->expr); // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (isSubtype(annotationType, computedType, stack.back(), singletonTypes, ice)) + if (isSubtype(annotationType, computedType, stack.back())) return; - if (isSubtype(computedType, annotationType, stack.back(), singletonTypes, ice)) + if (isSubtype(computedType, annotationType, stack.back())) return; reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); @@ -1505,12 +1505,27 @@ struct TypeChecker2 } } + template + bool isSubtype(TID subTy, TID superTy, NotNull scope) + { + UnifierSharedState sharedState{&ice}; + TypeArena arena; + Normalizer normalizer{&arena, singletonTypes, NotNull{&sharedState}}; + Unifier u{NotNull{&normalizer}, Mode::Strict, scope, Location{}, Covariant}; + u.useScopes = true; + + u.tryUnify(subTy, superTy); + const bool ok = u.errors.empty() && u.log.empty(); + return ok; + } + template ErrorVec tryUnify(NotNull scope, const Location& location, TID subTy, TID superTy) { UnifierSharedState sharedState{&ice}; Normalizer normalizer{&module->internalTypes, singletonTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, Mode::Strict, scope, location, Covariant}; + u.useScopes = true; u.tryUnify(subTy, superTy); return std::move(u.errors); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index ccb1490a..8ecd45bd 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -35,18 +35,20 @@ LUAU_FASTFLAG(LuauTypeNormalization2) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) -LUAU_FASTFLAGVARIABLE(LuauAnyifyModuleReturnGenerics, false) LUAU_FASTFLAGVARIABLE(LuauLvaluelessPath, false) +LUAU_FASTFLAGVARIABLE(LuauNilIterator, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) -LUAU_FASTFLAGVARIABLE(LuauFixVarargExprHeadType, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauReturnsFromCallsitesAreNotWidened, false) +LUAU_FASTFLAGVARIABLE(LuauTryhardAnd, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauCompleteVisitor, false) +LUAU_FASTFLAGVARIABLE(LuauOptionalNextKey, false) LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) +LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) namespace Luau { @@ -331,8 +333,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); - if (FFlag::LuauAnyifyModuleReturnGenerics) - moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); + moduleScope->returnType = anyifyModuleReturnTypePackGenerics(moduleScope->returnType); for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) typeFun.type = anyify(moduleScope, typeFun.type, Location{}); @@ -1209,10 +1210,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExpr* firstValue = forin.values.data[0]; // next is a function that takes Table and an optional index of type K - // next(t: Table, index: K | nil) -> (K, V) + // next(t: Table, index: K | nil) -> (K?, V) // however, pairs and ipairs are quite messy, but they both share the same types // pairs returns 'next, t, nil', thus the type would be - // pairs(t: Table) -> ((Table, K | nil) -> (K, V), Table, K | nil) + // pairs(t: Table) -> ((Table, K | nil) -> (K?, V), Table, K | nil) // ipairs returns 'next, t, 0', thus ipairs will also share the same type as pairs, except K = number // // we can also define our own custom iterators by by returning a wrapped coroutine that calls coroutine.yield @@ -1255,6 +1256,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location); } + if (FFlag::LuauNilIterator) + iterTy = stripFromNilAndReport(iterTy, firstValue->location); + if (std::optional iterMM = findMetatableEntry(iterTy, "__iter", firstValue->location, /* addErrors= */ true)) { // if __iter metamethod is present, it will be called and the results are going to be called as if they are functions @@ -1338,21 +1342,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) reportErrors(state.errors); } - TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - - if (forin.values.size >= 2) + if (FFlag::LuauOptionalNextKey) { - AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + TypePackId retPack = iterFunc->retTypes; - Position start = firstValue->location.begin; - Position end = values[forin.values.size - 1]->location.end; - AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + retPack = checkExprPack(scope, exprCall).type; + } + + // We need to remove 'nil' from the set of options of the first return value + // Because for loop stops when it gets 'nil', this result is never actually assigned to the first variable + if (std::optional fty = first(retPack); fty && !varTypes.empty()) + { + TypeId keyTy = follow(*fty); + + if (get(keyTy)) + { + if (std::optional ty = tryStripUnionFromNil(keyTy)) + keyTy = *ty; + } + + unify(keyTy, varTypes.front(), scope, forin.location); + + // We have already handled the first variable type, make it match in the pack check + varTypes.front() = *fty; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); - TypePackId retPack = checkExprPack(scope, exprCall).type; unify(retPack, varPack, scope, forin.location); } else - unify(iterFunc->retTypes, varPack, scope, forin.location); + { + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + TypePackId retPack = checkExprPack(scope, exprCall).type; + unify(retPack, varPack, scope, forin.location); + } + else + unify(iterFunc->retTypes, varPack, scope, forin.location); + } check(loopScope, *forin.body); } @@ -1855,18 +1899,10 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp if (get(varargPack)) { - if (FFlag::LuauFixVarargExprHeadType) - { - if (std::optional ty = first(varargPack)) - return {*ty}; + if (std::optional ty = first(varargPack)) + return {*ty}; - return {nilType}; - } - else - { - std::vector types = flatten(varargPack).first; - return {!types.empty() ? types[0] : nilType}; - } + return {nilType}; } else if (get(varargPack)) { @@ -2717,12 +2753,54 @@ TypeId TypeChecker::checkRelationalOperation( case AstExprBinary::And: if (lhsIsAny) + { return lhsType; - return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } + else if (FFlag::LuauTryhardAnd) + { + // If lhs is free, we can't tell which 'falsy' components it has, if any + if (get(lhsType)) + return unionOfTypes(addType(UnionTypeVar{{nilType, singletonType(false)}}), rhsType, scope, expr.location, false); + + auto [oty, notNever] = pickTypesFromSense(lhsType, false, neverType); // Filter out falsy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location, false); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(rhsType, booleanType, scope, expr.location, false); + } case AstExprBinary::Or: if (lhsIsAny) + { return lhsType; - return unionOfTypes(lhsType, rhsType, scope, expr.location); + } + else if (FFlag::LuauTryhardAnd) + { + auto [oty, notNever] = pickTypesFromSense(lhsType, true, neverType); // Filter out truthy types + + if (notNever) + { + LUAU_ASSERT(oty); + return unionOfTypes(*oty, rhsType, scope, expr.location); + } + else + { + return rhsType; + } + } + else + { + return unionOfTypes(lhsType, rhsType, scope, expr.location); + } default: LUAU_ASSERT(0); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); @@ -4840,9 +4918,9 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return singletonTypes->errorRecoveryTypePack(guess); } -TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense, TypeId emptySetTy) { - return [this, sense](TypeId ty) -> std::optional { + return [this, sense, emptySetTy](TypeId ty) -> std::optional { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) return ty; @@ -4860,7 +4938,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) return sense ? std::nullopt : std::optional(ty); // at this point, anything else is kept if sense is true, or replaced by nil - return sense ? ty : nilType; + return sense ? ty : emptySetTy; }; } @@ -4886,9 +4964,9 @@ std::pair, bool> TypeChecker::filterMap(TypeId type, TypeI } } -std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense) +std::pair, bool> TypeChecker::pickTypesFromSense(TypeId type, bool sense, TypeId emptySetTy) { - return filterMap(type, mkTruthyPredicate(sense)); + return filterMap(type, mkTruthyPredicate(sense, emptySetTy)); } TypeId TypeChecker::addTV(TypeVar&& tv) @@ -5657,7 +5735,7 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, RefinementMap& refis, if (ty && fromOr) return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense, nilType)); } void TypeChecker::resolve(const AndPredicate& andP, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5850,13 +5928,57 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc if (maybeSingleton(eqP.type)) { - // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. - if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) - return sense ? eqP.type : option; + if (FFlag::LuauImplicitElseRefinement) + { + bool optionIsSubtype = canUnify(option, eqP.type, scope, eqP.location).empty(); + bool targetIsSubtype = canUnify(eqP.type, option, scope, eqP.location).empty(); - // local variable works around an odd gcc 9.3 warning: may be used uninitialized - std::optional res = std::nullopt; - return res; + // terminology refresher: + // - option is the type of the expression `x`, and + // - eqP.type is the type of the expression `"hello"` + // + // "hello" == x where + // x : "hello" | "world" -> x : "hello" + // x : number | string -> x : "hello" + // x : number -> x : never + // + // "hello" ~= x where + // x : "hello" | "world" -> x : "world" + // x : number | string -> x : number | string + // x : number -> x : number + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional nope = std::nullopt; + + if (sense) + { + if (optionIsSubtype && !targetIsSubtype) + return option; + else if (!optionIsSubtype && targetIsSubtype) + return eqP.type; + else if (!optionIsSubtype && !targetIsSubtype) + return nope; + else if (optionIsSubtype && targetIsSubtype) + return eqP.type; + } + else + { + bool isOptionSingleton = get(option); + if (!isOptionSingleton) + return option; + else if (optionIsSubtype && targetIsSubtype) + return nope; + } + } + else + { + if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty()) + return sense ? eqP.type : option; + + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; + } } return option; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index de0890e1..814eca0d 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -3,6 +3,7 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" #include "Luau/RecursionCounter.h" @@ -26,6 +27,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauNoMoreGlobalSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauNewLibraryTypeNames, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) namespace Luau @@ -33,15 +35,19 @@ namespace Luau std::optional> magicFunctionFormat( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); static std::optional> magicFunctionGmatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); +static bool dcrMagicFunctionFind(MagicFunctionCallContext context); TypeId follow(TypeId t) { @@ -800,6 +806,7 @@ TypeId SingletonTypes::makeStringMetatable() FunctionTypeVar formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; formatFTV.magicFunction = &magicFunctionFormat; const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); const TypePackId emptyPack = arena->addTypePack({}); const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); @@ -814,14 +821,17 @@ TypeId SingletonTypes::makeStringMetatable() const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); const TypeId matchFunc = arena->addType( FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); const TypeId findFunc = arena->addType(FunctionTypeVar{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); TableTypeVar::Props stringLib = { {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, @@ -855,7 +865,7 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); if (TableTypeVar* ttv = getMutable(tableType)) - ttv->name = "string"; + ttv->name = FFlag::LuauNewLibraryTypeNames ? "typeof(string)" : "string"; return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -1072,7 +1082,7 @@ IntersectionTypeVarIterator end(const IntersectionTypeVar* itv) return IntersectionTypeVarIterator{}; } -static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) +static std::vector parseFormatString(NotNull singletonTypes, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs*"; @@ -1095,13 +1105,13 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha break; if (data[i] == 'q' || data[i] == 's') - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); else if (data[i] == '*') - result.push_back(typechecker.unknownType); + result.push_back(singletonTypes->unknownType); else if (strchr(options, data[i])) - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); else - result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); + result.push_back(singletonTypes->errorRecoveryType(singletonTypes->anyType)); } } @@ -1130,7 +1140,7 @@ std::optional> magicFunctionFormat( if (!fmt) return std::nullopt; - std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); + std::vector expected = parseFormatString(typechecker.singletonTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); size_t paramOffset = 1; @@ -1154,7 +1164,50 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } -static std::vector parsePatternString(TypeChecker& typechecker, const char* data, size_t size) +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->singletonTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(params[i + paramOffset], expected[i], context.solver->rootScope); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->singletonTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull singletonTypes, const char* data, size_t size) { std::vector result; int depth = 0; @@ -1186,12 +1239,12 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch if (i + 1 < size && data[i + 1] == ')') { i++; - result.push_back(typechecker.numberType); + result.push_back(singletonTypes->numberType); continue; } ++depth; - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); } else if (data[i] == ')') { @@ -1209,7 +1262,7 @@ static std::vector parsePatternString(TypeChecker& typechecker, const ch return std::vector(); if (result.empty()) - result.push_back(typechecker.stringType); + result.push_back(singletonTypes->stringType); return result; } @@ -1233,7 +1286,7 @@ static std::optional> magicFunctionGmatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1246,6 +1299,39 @@ static std::optional> magicFunctionGmatch( return WithPredicate{arena.addTypePack({iteratorType})}; } +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionTypeVar{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + static std::optional> magicFunctionMatch( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -1265,7 +1351,7 @@ static std::optional> magicFunctionMatch( if (!pattern) return std::nullopt; - std::vector returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + std::vector returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1282,6 +1368,42 @@ static std::optional> magicFunctionMatch( return WithPredicate{returnList}; } +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(params[0], context.solver->singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{context.solver->singletonTypes->nilType, context.solver->singletonTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + static std::optional> magicFunctionFind( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { @@ -1312,7 +1434,7 @@ static std::optional> magicFunctionFind( std::vector returnTypes; if (!plain) { - returnTypes = parsePatternString(typechecker, pattern->value.data, pattern->value.size); + returnTypes = parsePatternString(typechecker.singletonTypes, pattern->value.data, pattern->value.size); if (returnTypes.empty()) return std::nullopt; @@ -1336,6 +1458,60 @@ static std::optional> magicFunctionFind( return WithPredicate{returnList}; } +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull singletonTypes = context.solver->singletonTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(singletonTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(params[0], singletonTypes->stringType, context.solver->rootScope); + + const TypeId optionalNumber = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{singletonTypes->nilType, singletonTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(params[2], optionalNumber, context.solver->rootScope); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(params[3], optionalBoolean, context.solver->rootScope); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + std::vector filterMap(TypeId type, TypeIdPredicate predicate) { type = follow(type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index df5d86f1..4dc90983 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSubtypeNormalizer, false); LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauOverloadedFunctionSubtypingPerf, false); +LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) @@ -1699,8 +1700,20 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TableTypeVar* newSuperTable = log.getMutable(superTy); - TableTypeVar* newSubTable = log.getMutable(subTy); + TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(superTy) : superTy; + TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner ? log.follow(subTy) : subTy; + + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || subTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); + } + + // Otherwise, restart only the table unification + TableTypeVar* newSuperTable = log.getMutable(superTyNew); + TableTypeVar* newSubTable = log.getMutable(subTyNew); + if (superTable != newSuperTable || (subTable != newSubTable && subTable != instantiatedSubTable)) { if (errors.empty()) @@ -1862,7 +1875,9 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (reversed) std::swap(subTy, superTy); - if (auto ttv = log.get(superTy); !ttv || ttv->state != TableState::Free) + TableTypeVar* superTable = log.getMutable(superTy); + + if (!superTable || superTable->state != TableState::Free) return reportError(location, TypeMismatch{osuperTy, osubTy}); auto fail = [&](std::optional e) { @@ -1887,6 +1902,20 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child.log.follow(superTy); + + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundTypeVar{subTy}); + return; + } + } + if (auto e = hasUnificationTooComplex(child.errors)) reportError(*e); else if (!child.errors.empty()) @@ -1894,6 +1923,14 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(child.log)); + if (FFlag::LuauScalarShapeUnifyToMtOwner) + { + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child.errors.empty()) + log.replace(superTy, BoundTypeVar{subTy}); + } + return; } else diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4c0cc125..8338a04a 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -27,6 +27,8 @@ LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauCommaParenWarnings, false) LUAU_FASTFLAGVARIABLE(LuauTableConstructorRecovery, false) +LUAU_FASTFLAGVARIABLE(LuauParserErrorsOnMissingDefaultTypePackArgument, false) + bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_double_prefix_hex_integer = false; @@ -2503,7 +2505,7 @@ std::pair, AstArray> Parser::parseG namePacks.push_back({name, nameLocation, typePack}); } - else if (lexer.current().type == '(') + else if (!FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument && lexer.current().type == '(') { auto [type, typePack] = parseTypeOrPackAnnotation(); @@ -2512,6 +2514,15 @@ std::pair, AstArray> Parser::parseG namePacks.push_back({name, nameLocation, typePack}); } + else if (FFlag::LuauParserErrorsOnMissingDefaultTypePackArgument) + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(type->location, "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } } else { diff --git a/CodeGen/include/Luau/AddressA64.h b/CodeGen/include/Luau/AddressA64.h index 351e6715..53efd3c3 100644 --- a/CodeGen/include/Luau/AddressA64.h +++ b/CodeGen/include/Luau/AddressA64.h @@ -17,7 +17,6 @@ enum class AddressKindA64 : uint8_t // reg + reg << shift // reg + sext(reg) << shift // reg + uext(reg) << shift - // pc + offset }; struct AddressA64 @@ -28,8 +27,8 @@ struct AddressA64 , offset(xzr) , data(off) { - LUAU_ASSERT(base.kind == KindA64::x); - LUAU_ASSERT(off >= 0 && off < 4096); + LUAU_ASSERT(base.kind == KindA64::x || base == sp); + LUAU_ASSERT(off >= -256 && off < 4096); } AddressA64(RegisterA64 base, RegisterA64 offset) @@ -48,5 +47,7 @@ struct AddressA64 int data; }; +using mem = AddressA64; + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 9a1402be..9e12168a 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -34,15 +34,18 @@ public: // Comparisons // Note: some arithmetic instructions also have versions that update flags (ADDS etc) but we aren't using them atm - // TODO: add cmp + void cmp(RegisterA64 src1, RegisterA64 src2); + void cmp(RegisterA64 src1, int src2); - // Binary + // Bitwise // Note: shifted-register support and bitfield operations are omitted for simplicity // TODO: support immediate arguments (they have odd encoding and forbid many values) - // TODO: support not variants for and/or/eor (required to support not...) void and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void orr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void mvn(RegisterA64 dst, RegisterA64 src); + + // Shifts void lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void lsr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); void asr(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); @@ -66,11 +69,19 @@ public: // Control flow // Note: tbz/tbnz are currently not supported because they have 15-bit offsets and we don't support branch thunks + void b(Label& label); void b(ConditionA64 cond, Label& label); void cbz(RegisterA64 src, Label& label); void cbnz(RegisterA64 src, Label& label); + void br(RegisterA64 src); + void blr(RegisterA64 src); void ret(); + // Address of embedded data + void adr(RegisterA64 dst, const void* ptr, size_t size); + void adr(RegisterA64 dst, uint64_t value); + void adr(RegisterA64 dst, double value); + // Run final checks bool finalize(); @@ -97,17 +108,21 @@ private: // Instruction archetypes void place0(const char* name, uint32_t word); void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0); - void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op); + void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0); void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2); void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op); void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op); void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0); void placeA(const char* name, RegisterA64 dst, AddressA64 src, uint8_t op, uint8_t size); void placeBC(const char* name, Label& label, uint8_t op, uint8_t cond); - void placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + void placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond); + void placeBR(const char* name, RegisterA64 src, uint32_t op); + void placeADR(const char* name, RegisterA64 src, uint8_t op); void place(uint32_t word); - void placeLabel(Label& label); + + void patchLabel(Label& label); + void patchImm19(uint32_t location, int value); void commit(); LUAU_NOINLINE void extend(); @@ -123,6 +138,7 @@ private: LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, int src, int shift = 0); LUAU_NOINLINE void log(const char* opcode, RegisterA64 dst, AddressA64 src); LUAU_NOINLINE void log(const char* opcode, RegisterA64 src, Label label); + LUAU_NOINLINE void log(const char* opcode, RegisterA64 src); LUAU_NOINLINE void log(const char* opcode, Label label); LUAU_NOINLINE void log(Label label); LUAU_NOINLINE void log(RegisterA64 reg); @@ -133,6 +149,7 @@ private: std::vector labelLocations; bool finalized = false; + bool overflowed = false; size_t dataPos = 0; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index e4237f53..286800d6 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -37,7 +37,10 @@ AssemblyBuilderA64::~AssemblyBuilderA64() void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src) { - placeSR2("mov", dst, src, 0b01'01010); + if (dst == sp || src == sp) + placeR1("mov", dst, src, 0b00'100010'0'000000000000); + else + placeSR2("mov", dst, src, 0b01'01010); } void AssemblyBuilderA64::mov(RegisterA64 dst, uint16_t src, int shift) @@ -75,6 +78,20 @@ void AssemblyBuilderA64::neg(RegisterA64 dst, RegisterA64 src) placeSR2("neg", dst, src, 0b10'01011); } +void AssemblyBuilderA64::cmp(RegisterA64 src1, RegisterA64 src2) +{ + RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; + + placeSR3("cmp", dst, src1, src2, 0b11'01011); +} + +void AssemblyBuilderA64::cmp(RegisterA64 src1, int src2) +{ + RegisterA64 dst = src1.kind == KindA64::x ? xzr : wzr; + + placeI12("cmp", dst, src1, src2, 0b11'10001); +} + void AssemblyBuilderA64::and_(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeSR3("and", dst, src1, src2, 0b00'01010); @@ -90,6 +107,11 @@ void AssemblyBuilderA64::eor(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2 placeSR3("eor", dst, src1, src2, 0b10'01010); } +void AssemblyBuilderA64::mvn(RegisterA64 dst, RegisterA64 src) +{ + placeSR2("mvn", dst, src, 0b01'01010, 0b1); +} + void AssemblyBuilderA64::lsl(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { placeR3("lsl", dst, src1, src2, 0b11010110, 0b0010'00); @@ -183,6 +205,12 @@ void AssemblyBuilderA64::strh(RegisterA64 src, AddressA64 dst) placeA("strh", src, dst, 0b11100000, 0b01); } +void AssemblyBuilderA64::b(Label& label) +{ + // Note: we aren't using 'b' form since it has a 26-bit immediate which requires custom fixup logic + placeBC("b", label, 0b0101010'0, codeForCondition[int(ConditionA64::Always)]); +} + void AssemblyBuilderA64::b(ConditionA64 cond, Label& label) { placeBC(textForCondition[int(cond)], label, 0b0101010'0, codeForCondition[int(cond)]); @@ -190,12 +218,22 @@ void AssemblyBuilderA64::b(ConditionA64 cond, Label& label) void AssemblyBuilderA64::cbz(RegisterA64 src, Label& label) { - placeBR("cbz", label, 0b011010'0, src); + placeBCR("cbz", label, 0b011010'0, src); } void AssemblyBuilderA64::cbnz(RegisterA64 src, Label& label) { - placeBR("cbnz", label, 0b011010'1, src); + placeBCR("cbnz", label, 0b011010'1, src); +} + +void AssemblyBuilderA64::br(RegisterA64 src) +{ + placeBR("br", src, 0b1101011'0'0'00'11111'0000'0'0); +} + +void AssemblyBuilderA64::blr(RegisterA64 src) +{ + placeBR("blr", src, 0b1101011'0'0'01'11111'0000'0'0); } void AssemblyBuilderA64::ret() @@ -203,10 +241,41 @@ void AssemblyBuilderA64::ret() place0("ret", 0b1101011'0'0'10'11111'0000'0'0'11110'00000); } +void AssemblyBuilderA64::adr(RegisterA64 dst, const void* ptr, size_t size) +{ + size_t pos = allocateData(size, 4); + uint32_t location = getCodeSize(); + + memcpy(&data[pos], ptr, size); + placeADR("adr", dst, 0b10000); + + patchImm19(location, -int(location) - int((data.size() - pos) / 4)); +} + +void AssemblyBuilderA64::adr(RegisterA64 dst, uint64_t value) +{ + size_t pos = allocateData(8, 8); + uint32_t location = getCodeSize(); + + writeu64(&data[pos], value); + placeADR("adr", dst, 0b10000); + + patchImm19(location, -int(location) - int((data.size() - pos) / 4)); +} + +void AssemblyBuilderA64::adr(RegisterA64 dst, double value) +{ + size_t pos = allocateData(8, 8); + uint32_t location = getCodeSize(); + + writef64(&data[pos], value); + placeADR("adr", dst, 0b10000); + + patchImm19(location, -int(location) - int((data.size() - pos) / 4)); +} + bool AssemblyBuilderA64::finalize() { - bool success = true; - code.resize(codePos - code.data()); // Resolve jump targets @@ -214,15 +283,9 @@ bool AssemblyBuilderA64::finalize() { // If this assertion fires, a label was used in jmp without calling setLabel LUAU_ASSERT(labelLocations[fixup.id - 1] != ~0u); - int value = int(labelLocations[fixup.id - 1]) - int(fixup.location); - // imm19 encoding word offset, at bit offset 5 - // note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB - if (value > -(1 << 18) && value < (1 << 18)) - code[fixup.location] |= (value & ((1 << 19) - 1)) << 5; - else - success = false; // overflow + patchImm19(fixup.location, value); } size_t dataSize = data.size() - dataPos; @@ -235,7 +298,7 @@ bool AssemblyBuilderA64::finalize() finalized = true; - return success; + return !overflowed; } Label AssemblyBuilderA64::setLabel() @@ -303,7 +366,7 @@ void AssemblyBuilderA64::placeSR3(const char* name, RegisterA64 dst, RegisterA64 commit(); } -void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op) +void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2) { if (logText) log(name, dst, src); @@ -313,7 +376,7 @@ void AssemblyBuilderA64::placeSR2(const char* name, RegisterA64 dst, RegisterA64 uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; - place(dst.index | (0x1f << 5) | (src.index << 16) | (op << 24) | sf); + place(dst.index | (0x1f << 5) | (src.index << 16) | (op2 << 21) | (op << 24) | sf); commit(); } @@ -336,10 +399,10 @@ void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); - LUAU_ASSERT(dst.kind == src.kind); + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); + LUAU_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x)); - uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; + uint32_t sf = (dst.kind != KindA64::w) ? 0x80000000 : 0; place(dst.index | (src.index << 5) | (op << 10) | sf); commit(); @@ -350,11 +413,11 @@ void AssemblyBuilderA64::placeI12(const char* name, RegisterA64 dst, RegisterA64 if (logText) log(name, dst, src1, src2); - LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x); - LUAU_ASSERT(dst.kind == src1.kind); + LUAU_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp); + LUAU_ASSERT(dst.kind == src1.kind || (dst.kind == KindA64::x && src1 == sp) || (dst == sp && src1.kind == KindA64::x)); LUAU_ASSERT(src2 >= 0 && src2 < (1 << 12)); - uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; + uint32_t sf = (dst.kind != KindA64::w) ? 0x80000000 : 0; place(dst.index | (src1.index << 5) | (src2 << 10) | (op << 24) | sf); commit(); @@ -383,8 +446,12 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr switch (src.kind) { case AddressKindA64::imm: - LUAU_ASSERT(src.data % (1 << size) == 0); - place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30)); + if (src.data >= 0 && src.data % (1 << size) == 0) + place(dst.index | (src.base.index << 5) | ((src.data >> size) << 10) | (op << 22) | (1 << 24) | (size << 30)); + else if (src.data >= -256 && src.data <= 255) + place(dst.index | (src.base.index << 5) | ((src.data & ((1 << 9) - 1)) << 12) | (op << 22) | (size << 30)); + else + LUAU_ASSERT(!"Unable to encode large immediate offset"); break; case AddressKindA64::reg: place(dst.index | (src.base.index << 5) | (0b10 << 10) | (0b011 << 13) | (src.offset.index << 16) | (1 << 21) | (op << 22) | (size << 30)); @@ -396,28 +463,50 @@ void AssemblyBuilderA64::placeA(const char* name, RegisterA64 dst, AddressA64 sr void AssemblyBuilderA64::placeBC(const char* name, Label& label, uint8_t op, uint8_t cond) { - placeLabel(label); + place(cond | (op << 24)); + commit(); + + patchLabel(label); if (logText) log(name, label); - - place(cond | (op << 24)); - commit(); } -void AssemblyBuilderA64::placeBR(const char* name, Label& label, uint8_t op, RegisterA64 cond) +void AssemblyBuilderA64::placeBCR(const char* name, Label& label, uint8_t op, RegisterA64 cond) { - placeLabel(label); - - if (logText) - log(name, cond, label); - LUAU_ASSERT(cond.kind == KindA64::w || cond.kind == KindA64::x); uint32_t sf = (cond.kind == KindA64::x) ? 0x80000000 : 0; place(cond.index | (op << 24) | sf); commit(); + + patchLabel(label); + + if (logText) + log(name, cond, label); +} + +void AssemblyBuilderA64::placeBR(const char* name, RegisterA64 src, uint32_t op) +{ + if (logText) + log(name, src); + + LUAU_ASSERT(src.kind == KindA64::x); + + place((src.index << 5) | (op << 10)); + commit(); +} + +void AssemblyBuilderA64::placeADR(const char* name, RegisterA64 dst, uint8_t op) +{ + if (logText) + log(name, dst); + + LUAU_ASSERT(dst.kind == KindA64::x); + + place(dst.index | (op << 24)); + commit(); } void AssemblyBuilderA64::place(uint32_t word) @@ -426,8 +515,10 @@ void AssemblyBuilderA64::place(uint32_t word) *codePos++ = word; } -void AssemblyBuilderA64::placeLabel(Label& label) +void AssemblyBuilderA64::patchLabel(Label& label) { + uint32_t location = getCodeSize() - 1; + if (label.location == ~0u) { if (label.id == 0) @@ -436,18 +527,26 @@ void AssemblyBuilderA64::placeLabel(Label& label) labelLocations.push_back(~0u); } - pendingLabels.push_back({label.id, getCodeSize()}); + pendingLabels.push_back({label.id, location}); } else { - // note: if label has an assigned location we can in theory avoid patching it later, but - // we need to handle potential overflow of 19-bit offsets - LUAU_ASSERT(label.id != 0); - labelLocations[label.id - 1] = label.location; - pendingLabels.push_back({label.id, getCodeSize()}); + int value = int(label.location) - int(location); + + patchImm19(location, value); } } +void AssemblyBuilderA64::patchImm19(uint32_t location, int value) +{ + // imm19 encoding word offset, at bit offset 5 + // note that 18 bits of word offsets = 20 bits of byte offsets = +-1MB + if (value > -(1 << 18) && value < (1 << 18)) + code[location] |= (value & ((1 << 19) - 1)) << 5; + else + overflowed = true; +} + void AssemblyBuilderA64::commit() { LUAU_ASSERT(codePos <= codeEnd); @@ -491,8 +590,11 @@ void AssemblyBuilderA64::log(const char* opcode) void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift) { logAppend(" %-12s", opcode); - log(dst); - text.append(","); + if (dst != xzr && dst != wzr) + { + log(dst); + text.append(","); + } log(src1); text.append(","); log(src2); @@ -504,8 +606,11 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 sr void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 src1, int src2) { logAppend(" %-12s", opcode); - log(dst); - text.append(","); + if (dst != xzr && dst != wzr) + { + log(dst); + text.append(","); + } log(src1); text.append(","); logAppend("#%d", src2); @@ -549,6 +654,13 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src, Label label) logAppend(".L%d\n", label.id); } +void AssemblyBuilderA64::log(const char* opcode, RegisterA64 src) +{ + logAppend(" %-12s", opcode); + log(src); + text.append("\n"); +} + void AssemblyBuilderA64::log(const char* opcode, Label label) { logAppend(" %-12s.L%d\n", opcode, label.id); @@ -565,20 +677,24 @@ void AssemblyBuilderA64::log(RegisterA64 reg) { case KindA64::w: if (reg.index == 31) - logAppend("wzr"); + text.append("wzr"); else logAppend("w%d", reg.index); break; case KindA64::x: if (reg.index == 31) - logAppend("xzr"); + text.append("xzr"); else logAppend("x%d", reg.index); break; case KindA64::none: - LUAU_ASSERT(!"Unexpected register kind"); + if (reg.index == 31) + text.append("sp"); + else + LUAU_ASSERT(!"Unexpected register kind"); + break; } } diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 3074cce2..b23d2b38 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -14,15 +14,15 @@ * Each line is 8 bytes, stack grows downwards. * * | ... previous frames ... - * | rdx home space | (saved only on windows) - * | rcx home space | (saved only on windows) + * | rdx home space | (unused) + * | rcx home space | (unused) * | return address | - * | ... saved non-volatile registers ... + * | ... saved non-volatile registers ... <-- rsp + kStackSize + kLocalsSize * | unused | for 16 byte alignment of the stack * | sCode | - * | sClosure | <-- rbp points here - * | argument 6 | - * | argument 5 | + * | sClosure | <-- rsp + kStackSize + * | argument 6 | <-- rsp + 40 + * | argument 5 | <-- rsp + 32 * | r9 home space | * | r8 home space | * | rdx home space | @@ -48,28 +48,18 @@ bool initEntryFunction(NativeState& data) unwind.start(); - if (build.abi == ABIX64::Windows) - { - // Place arguments in home space - build.mov(qword[rsp + 16], rArg2); - unwind.spill(16, rArg2); - build.mov(qword[rsp + 8], rArg1); - unwind.spill(8, rArg1); - - // Save non-volatile registers that are specific to Windows x64 ABI - build.push(rdi); - unwind.save(rdi); - build.push(rsi); - unwind.save(rsi); - - // Once we start using non-volatile SIMD registers, we will save those here - } - // Save common non-volatile registers - build.push(rbx); - unwind.save(rbx); build.push(rbp); unwind.save(rbp); + + if (build.abi == ABIX64::SystemV) + { + build.mov(rbp, rsp); + unwind.setupFrameReg(rbp, 0); + } + + build.push(rbx); + unwind.save(rbx); build.push(r12); unwind.save(r12); build.push(r13); @@ -79,16 +69,20 @@ bool initEntryFunction(NativeState& data) build.push(r15); unwind.save(r15); - int stacksize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments - int localssize = 24; // 3 local pointers that also correctly align the stack + if (build.abi == ABIX64::Windows) + { + // Save non-volatile registers that are specific to Windows x64 ABI + build.push(rdi); + unwind.save(rdi); + build.push(rsi); + unwind.save(rsi); + + // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here + } // Allocate stack space (reg home area + local data) - build.sub(rsp, stacksize + localssize); - unwind.allocStack(stacksize + localssize); - - // Setup frame pointer - build.lea(rbp, addr[rsp + stacksize]); - unwind.setupFrameReg(rbp, stacksize); + build.sub(rsp, kStackSize + kLocalsSize); + unwind.allocStack(kStackSize + kLocalsSize); unwind.finish(); @@ -113,13 +107,7 @@ bool initEntryFunction(NativeState& data) Label returnOff = build.setLabel(); // Cleanup and exit - build.lea(rsp, addr[rbp + localssize]); - build.pop(r15); - build.pop(r14); - build.pop(r13); - build.pop(r12); - build.pop(rbp); - build.pop(rbx); + build.add(rsp, kStackSize + kLocalsSize); if (build.abi == ABIX64::Windows) { @@ -127,6 +115,12 @@ bool initEntryFunction(NativeState& data) build.pop(rdi); } + build.pop(r15); + build.pop(r14); + build.pop(r13); + build.pop(r12); + build.pop(rbx); + build.pop(rbp); build.ret(); build.finalize(); diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index 071ef6af..61544855 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -32,8 +32,12 @@ constexpr RegisterX64 rNativeContext = r13; // NativeContext* context constexpr RegisterX64 rConstants = r12; // TValue* k // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point -constexpr OperandX64 sClosure = qword[rbp + 0]; // Closure* cl -constexpr OperandX64 sCode = qword[rbp + 8]; // Instruction* code +// See CodeGenX64.cpp for layout +constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments +constexpr unsigned kLocalsSize = 24; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) + +constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl +constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code // TODO: These should be replaced with a portable call function that checks the ABI at runtime and reorders moves accordingly to avoid conflicts #if defined(_WIN32) diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index 8d06864e..7dc86d3e 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -13,11 +13,10 @@ // https://refspecs.linuxbase.org/elf/x86_64-abi-0.99.pdf [System V Application Binary Interface (AMD64 Architecture Processor Supplement)] // Interaction between Dwarf2 and System V ABI can be found in sections '3.6.2 DWARF Register Number Mapping' and '4.2.4 EH_FRAME sections' -// Call frame instruction opcodes +// Call frame instruction opcodes (Dwarf2, page 78, ch. 7.23 figure 37) #define DW_CFA_advance_loc 0x40 #define DW_CFA_offset 0x80 #define DW_CFA_restore 0xc0 -#define DW_CFA_nop 0x00 #define DW_CFA_set_loc 0x01 #define DW_CFA_advance_loc1 0x02 #define DW_CFA_advance_loc2 0x03 @@ -33,17 +32,11 @@ #define DW_CFA_def_cfa_register 0x0d #define DW_CFA_def_cfa_offset 0x0e #define DW_CFA_def_cfa_expression 0x0f -#define DW_CFA_expression 0x10 -#define DW_CFA_offset_extended_sf 0x11 -#define DW_CFA_def_cfa_sf 0x12 -#define DW_CFA_def_cfa_offset_sf 0x13 -#define DW_CFA_val_offset 0x14 -#define DW_CFA_val_offset_sf 0x15 -#define DW_CFA_val_expression 0x16 +#define DW_CFA_nop 0x00 #define DW_CFA_lo_user 0x1c #define DW_CFA_hi_user 0x3f -// Register numbers for x64 +// Register numbers for x64 (System V ABI, page 57, ch. 3.7, figure 3.36) #define DW_REG_RAX 0 #define DW_REG_RDX 1 #define DW_REG_RCX 2 @@ -197,7 +190,12 @@ void UnwindBuilderDwarf2::allocStack(int size) void UnwindBuilderDwarf2::setupFrameReg(RegisterX64 reg, int espOffset) { - // Not required for unwinding + if (espOffset != 0) + pos = advanceLocation(pos, 5); // REX.W lea rbp, [rsp + imm8] + else + pos = advanceLocation(pos, 3); // REX.W mov rbp, rsp + + // Cfa is based on rsp, so no additonal commands are required } void UnwindBuilderDwarf2::finish() diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index 1b3279e8..13e92ab0 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -77,7 +77,11 @@ void UnwindBuilderWin::setupFrameReg(RegisterX64 reg, int espOffset) frameReg = reg; frameRegOffset = uint8_t(espOffset / 16); - prologSize += 5; // REX.W lea rbp, [rsp + imm8] + if (espOffset != 0) + prologSize += 5; // REX.W lea rbp, [rsp + imm8] + else + prologSize += 3; // REX.W mov rbp, rsp + unwindCodes.push_back({prologSize, UWOP_SET_FPREG, frameRegOffset}); } diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 17fe26b1..15db9ea3 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,8 @@ inline bool isFlagExperimental(const char* flag) static const char* kList[] = { "LuauInterpolatedStringBaseSupport", "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code + "LuauOptionalNextKey", // waiting for a fix to land in lua-apps + "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps // makes sure we always have at least one entry nullptr, }; diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index d1f29c23..d808ac49 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -69,6 +69,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Unary") { SINGLE_COMPARE(neg(x0, x1), 0xCB0103E0); SINGLE_COMPARE(neg(w0, w1), 0x4B0103E0); + SINGLE_COMPARE(mvn(x0, x1), 0xAA2103E0); SINGLE_COMPARE(clz(x0, x1), 0xDAC01020); SINGLE_COMPARE(clz(w0, w1), 0x5AC01020); @@ -91,19 +92,22 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(lsr(x0, x1, x2), 0x9AC22420); SINGLE_COMPARE(asr(x0, x1, x2), 0x9AC22820); SINGLE_COMPARE(ror(x0, x1, x2), 0x9AC22C20); + SINGLE_COMPARE(cmp(x0, x1), 0xEB01001F); // reg, imm SINGLE_COMPARE(add(x3, x7, 78), 0x910138E3); SINGLE_COMPARE(add(w3, w7, 78), 0x110138E3); SINGLE_COMPARE(sub(w3, w7, 78), 0x510138E3); + SINGLE_COMPARE(cmp(w0, 42), 0x7100A81F); } TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Loads") { // address forms SINGLE_COMPARE(ldr(x0, x1), 0xF9400020); - SINGLE_COMPARE(ldr(x0, AddressA64(x1, 8)), 0xF9400420); - SINGLE_COMPARE(ldr(x0, AddressA64(x1, x7)), 0xF8676820); + SINGLE_COMPARE(ldr(x0, mem(x1, 8)), 0xF9400420); + SINGLE_COMPARE(ldr(x0, mem(x1, x7)), 0xF8676820); + SINGLE_COMPARE(ldr(x0, mem(x1, -7)), 0xF85F9020); // load sizes SINGLE_COMPARE(ldr(x0, x1), 0xF9400020); @@ -121,8 +125,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Stores") { // address forms SINGLE_COMPARE(str(x0, x1), 0xF9000020); - SINGLE_COMPARE(str(x0, AddressA64(x1, 8)), 0xF9000420); - SINGLE_COMPARE(str(x0, AddressA64(x1, x7)), 0xF8276820); + SINGLE_COMPARE(str(x0, mem(x1, 8)), 0xF9000420); + SINGLE_COMPARE(str(x0, mem(x1, x7)), 0xF8276820); + SINGLE_COMPARE(strh(w0, mem(x1, -7)), 0x781F9020); // store sizes SINGLE_COMPARE(str(x0, x1), 0xF9000020); @@ -169,26 +174,69 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "ControlFlow") build.cbz(x0, skip); build.cbnz(x0, skip); build.setLabel(skip); + build.b(skip); }, - {0x54000060, 0xB4000040, 0xB5000020})); + {0x54000060, 0xB4000040, 0xB5000020, 0x5400000E})); // Basic control flow + SINGLE_COMPARE(br(x0), 0xD61F0000); + SINGLE_COMPARE(blr(x0), 0xD63F0000); SINGLE_COMPARE(ret(), 0xD65F03C0); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "StackOps") +{ + SINGLE_COMPARE(mov(x0, sp), 0x910003E0); + SINGLE_COMPARE(mov(sp, x0), 0x9100001F); + + SINGLE_COMPARE(add(sp, sp, 4), 0x910013FF); + SINGLE_COMPARE(sub(sp, sp, 4), 0xD10013FF); + + SINGLE_COMPARE(add(x0, sp, 4), 0x910013E0); + SINGLE_COMPARE(sub(sp, x0, 4), 0xD100101F); + + SINGLE_COMPARE(ldr(x0, mem(sp, 8)), 0xF94007E0); + SINGLE_COMPARE(str(x0, mem(sp, 8)), 0xF90007E0); +} + +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Constants") +{ + // clang-format off + CHECK(check( + [](AssemblyBuilderA64& build) { + char arr[12] = "hello world"; + build.adr(x0, arr, 12); + build.adr(x0, uint64_t(0x1234567887654321)); + build.adr(x0, 1.0); + }, + { + 0x10ffffa0, 0x10ffff20, 0x10fffec0 + }, + { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x3f, + 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, + 0x00, 0x00, 0x00, 0x00, // 4b padding to align double + 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', 0x0, + })); + // clang-format on +} + TEST_CASE("LogTest") { AssemblyBuilderA64 build(/* logText= */ true); + build.add(sp, sp, 4); build.add(w0, w1, w2); build.add(x0, x1, x2, 2); build.add(w7, w8, 5); build.add(x7, x8, 5); build.ldr(x7, x8); - build.ldr(x7, AddressA64(x8, 8)); - build.ldr(x7, AddressA64(x8, x9)); + build.ldr(x7, mem(x8, 8)); + build.ldr(x7, mem(x8, x9)); build.mov(x1, x2); build.movk(x1, 42, 16); + build.cmp(x1, x2); + build.blr(x0); Label l; build.b(ConditionA64::Plus, l); @@ -200,6 +248,7 @@ TEST_CASE("LogTest") build.finalize(); std::string expected = R"( + add sp,sp,#4 add w0,w1,w2 add x0,x1,x2 LSL #2 add w7,w8,#5 @@ -209,6 +258,8 @@ TEST_CASE("LogTest") ldr x7,[x8,x9] mov x1,x2 movk x1,#42 LSL #16 + cmp x1,x2 + blr x0 b.pl .L1 cbz x7,.L1 .L1: diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 2e9a4b37..65b485a7 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -185,7 +185,7 @@ TEST_CASE("Dwarf2UnwindCodesX64") 0x00, 0x00, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x0e, 0x10, 0x85, 0x02, 0x02, 0x02, 0x0e, 0x18, 0x84, 0x03, 0x02, 0x02, 0x0e, 0x20, 0x83, 0x04, 0x02, 0x02, 0x0e, 0x28, 0x86, 0x05, 0x02, 0x02, 0x0e, 0x30, 0x8c, 0x06, 0x02, 0x02, 0x0e, 0x38, 0x8d, 0x07, 0x02, 0x02, 0x0e, 0x40, - 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + 0x8e, 0x08, 0x02, 0x02, 0x0e, 0x48, 0x8f, 0x09, 0x02, 0x04, 0x0e, 0x90, 0x01, 0x02, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00}; REQUIRE(data.size() == expected.size()); CHECK(memcmp(data.data(), expected.data(), expected.size()) == 0); @@ -446,7 +446,12 @@ TEST_CASE("GeneratedCodeExecutionA64") build.mov(x1, 0); // doesn't execute due to cbnz above build.setLabel(skip); - build.add(x1, x1, 1); + uint8_t one = 1; + build.adr(x2, &one, 1); + build.ldrb(w2, x2); + build.sub(x1, x1, x2); + + build.add(x1, x1, 2); build.add(x0, x0, x1, /* LSL */ 1); build.ret(); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index d011719e..30e1b2e6 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -21,13 +21,13 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; - constraints = Luau::collectConstraints(NotNull{cgb->rootScope}); + constraints = Luau::borrowConstraints(cgb->constraints); } void ConstraintGraphBuilderFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, "MainModule", NotNull(&moduleResolver), {}, &logger}; + ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger}; cs.run(); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index df0abdc9..4e72dd4e 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1054,10 +1054,6 @@ TEST_CASE("check_without_builtin_next") TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_cyclic_type") { - ScopedFastFlag sff[] = { - {"LuauForceExportSurfacesToBeNormal", true}, - }; - fileResolver.source["Module/A"] = R"( type F = (set: G) -> () @@ -1089,10 +1085,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_cyclic_type") TEST_CASE_FIXTURE(BuiltinsFixture, "reexport_type_alias") { - ScopedFastFlag sff[] = { - {"LuauForceExportSurfacesToBeNormal", true}, - }; - fileResolver.source["Module/A"] = R"( type KeyOfTestEvents = "test-file-start" | "test-file-success" | "test-file-failure" | "test-case-result" type MyAny = any diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 77cf6130..c0989a2e 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2823,4 +2823,21 @@ TEST_CASE_FIXTURE(Fixture, "get_a_nice_error_when_there_is_no_comma_after_last_t CHECK(table->items.size == 1); } +TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_type_parameter") +{ + ScopedFastFlag sff{"LuauParserErrorsOnMissingDefaultTypePackArgument", true}; + + ParseResult result = tryParse(R"( + type Foo = nil + )"); + + REQUIRE_EQ(2, result.errors.size()); + + CHECK_EQ(Location{{1, 23}, {1, 25}}, result.errors[0].getLocation()); + CHECK_EQ("Expected type, got '>'", result.errors[0].getMessage()); + + CHECK_EQ(Location{{1, 23}, {1, 24}}, result.errors[1].getLocation()); + CHECK_EQ("Expected type pack after '=', got type", result.errors[1].getMessage()); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index 18c243b0..c22d464e 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -404,3 +404,20 @@ t60 = makeChainedTable(60) } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("RegressionTests"); + +TEST_CASE_FIXTURE(ReplFixture, "InfiniteRecursion") +{ + // If the infinite recrusion is not caught, test will fail + runCode(L, R"( +local NewProxyOne = newproxy(true) +local MetaTableOne = getmetatable(NewProxyOne) +MetaTableOne.__index = function() + return NewProxyOne.Game +end +print(NewProxyOne.HelloICauseACrash) +)"); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index a510f914..8bb1fbaf 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -10,7 +10,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); -LUAU_FASTFLAG(LuauFixNameMaps); LUAU_FASTFLAG(LuauFunctionReturnStringificationFixup); TEST_SUITE_BEGIN("ToString"); @@ -266,11 +265,23 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") ToStringOptions o; o.exhaustive = false; - o.maxTypeLength = 40; - CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + o.maxTypeLength = 30; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> (a) -> () -> ()"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> (b) -> (a) -> (... *TRUNCATED*"); + } + else + { + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") @@ -285,11 +296,22 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") ToStringOptions o; o.exhaustive = true; - o.maxTypeLength = 40; - CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); - CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); - CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); - CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + o.maxTypeLength = 30; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(a) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "(b) -> (a) -> () -> ()"); + CHECK_EQ(toString(requireType("f3"), o), "(c) -> (b) -> (a) -> (... *TRUNCATED*"); + } + else + { + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... *TRUNCATED*"); + } } TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") @@ -423,28 +445,19 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TypeId id3Type = requireType("id3"); ToStringResult nameData = toStringDetailed(id3Type, opts); - if (FFlag::LuauFixNameMaps) - REQUIRE(3 == opts.nameMap.typeVars.size()); - else - REQUIRE_EQ(3, nameData.DEPRECATED_nameMap.typeVars.size()); + REQUIRE(3 == opts.nameMap.typeVars.size()); REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); - ToStringOptions opts2; // TODO: delete opts2 when clipping FFlag::LuauFixNameMaps - if (FFlag::LuauFixNameMaps) - opts2.nameMap = std::move(opts.nameMap); - else - opts2.DEPRECATED_nameMap = std::move(nameData.DEPRECATED_nameMap); - const FunctionTypeVar* ftv = get(follow(id3Type)); REQUIRE(ftv != nullptr); auto params = flatten(ftv->argTypes).first; REQUIRE(3 == params.size()); - CHECK("a" == toString(params[0], opts2)); - CHECK("b" == toString(params[1], opts2)); - CHECK("c" == toString(params[2], opts2)); + CHECK("a" == toString(params[0], opts)); + CHECK("b" == toString(params[1], opts)); + CHECK("c" == toString(params[2], opts)); } TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") @@ -471,13 +484,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType, opts); CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); - if (FFlag::LuauFixNameMaps) - CHECK(0 == opts.nameMap.typeVars.size()); - else - CHECK_EQ(0, r.DEPRECATED_nameMap.typeVars.size()); - - if (!FFlag::LuauFixNameMaps) - opts.DEPRECATED_nameMap = r.DEPRECATED_nameMap; + CHECK(0 == opts.nameMap.typeVars.size()); const MetatableTypeVar* tMeta = get(tType); REQUIRE(tMeta); @@ -502,8 +509,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta6->props.count("two") > 0); ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); - if (!FFlag::LuauFixNameMaps) - opts.DEPRECATED_nameMap = oneResult.DEPRECATED_nameMap; std::string twoResult = toString(tMeta6->props["two"].type, opts); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 8c738b7d..ef40e278 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -8,6 +8,7 @@ using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) +LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) TEST_SUITE_BEGIN("TypeAliases"); @@ -509,11 +510,21 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "general_require_multi_assign") TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") { + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); TypeId ty = getGlobalBinding(frontend, "table"); - CHECK_EQ(toString(ty), "table"); + + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ(toString(ty), "typeof(table)"); + } + else + { + CHECK_EQ(toString(ty), "table"); + } const TableTypeVar* ttv = get(ty); REQUIRE(ttv); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 7c465c53..787aea9a 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -57,7 +57,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "next_iterator_should_infer_types_and_type_ch local s = "foo" local t = { [s] = 1 } - local c: string, d: number = next(t) + local c: string?, d: number = next(t) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -69,7 +69,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "pairs_iterator_should_infer_types_and_type_c type Map = { [K]: V } local map: Map = { ["foo"] = 1, ["bar"] = 2, ["baz"] = 3 } - local it: (Map, string | nil) -> (string, number), t: Map, i: nil = pairs(map) + local it: (Map, string | nil) -> (string?, number), t: Map, i: nil = pairs(map) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -81,7 +81,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "ipairs_iterator_should_infer_types_and_type_ type Map = { [K]: V } local array: Map = { "foo", "bar", "baz" } - local it: (Map, number) -> (number, string), t: Map, i: number = ipairs(array) + local it: (Map, number) -> (number?, string), t: Map, i: number = ipairs(array) )"); LUAU_REQUIRE_NO_ERRORS(result); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index ddf73349..b306515a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -532,7 +532,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") REQUIRE_EQ(2, argVec.size()); const FunctionTypeVar* fType = get(follow(argVec[0])); - REQUIRE(fType != nullptr); + REQUIRE_MESSAGE(fType != nullptr, "Expected a function but got " << toString(argVec[0])); std::vector fArgs = flatten(fType->argTypes).first; diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index b2516f6d..c4bbc2e1 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -50,14 +50,18 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") CHECK_EQ(*requireType("s"), *typeChecker.stringType); } -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") +TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") { + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + CheckResult result = check(R"( local s = "a" and 10 local x:boolean|number = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "boolean | number"); + CHECK_EQ(toString(*requireType("s")), "number"); } TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") @@ -971,4 +975,79 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "mm_comparisons_must_return_a_boolean") CHECK(toString(result.errors[1]) == "Metamethod '__lt' must return type 'boolean'"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_and") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( +local a: number? = 5 +local b: boolean = (a or 1) > 10 +local c -- free + +local x = a and 1 +local y = 'a' and 1 +local z = b and 1 +local w = c and 1 + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("number?" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); + CHECK("false | number" == toString(requireType("z"))); + CHECK("number" == toString(requireType("w"))); // Normalizer considers free & falsy == never + } + else + { + CHECK("number?" == toString(requireType("x"))); + CHECK("number" == toString(requireType("y"))); + CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean + CHECK("(boolean | number)?" == toString(requireType("w"))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "reworked_or") +{ + ScopedFastFlag sff[]{ + {"LuauTryhardAnd", true}, + }; + + CheckResult result = check(R"( +local a: number | false = 5 +local b: number? = 6 +local c: boolean = true +local d: true = true +local e: false = false +local f: nil = false + +local a1 = a or 'a' +local b1 = b or 4 +local c1 = c or 'c' +local d1 = d or 'd' +local e1 = e or 'e' +local f1 = f or 'f' + )"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK("number | string" == toString(requireType("a1"))); + CHECK("number" == toString(requireType("b1"))); + CHECK("string | true" == toString(requireType("c1"))); + CHECK("string | true" == toString(requireType("d1"))); + CHECK("string" == toString(requireType("e1"))); + CHECK("string" == toString(requireType("f1"))); + } + else + { + CHECK("number | string" == toString(requireType("a1"))); + CHECK("number" == toString(requireType("b1"))); + CHECK("boolean | string" == toString(requireType("c1"))); // 'true' widened to boolean + CHECK("boolean | string" == toString(requireType("d1"))); // 'true' widened to boolean + CHECK("string" == toString(requireType("e1"))); + CHECK("string" == toString(requireType("f1"))); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index f6e60cdc..5688eaaa 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -461,23 +461,27 @@ TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") LUAU_REQUIRE_NO_ERRORS(result); - // Solving this requires recognizing that we can partially solve the - // following constraint: + // Solving this requires recognizing that we can't dispatch a constraint + // like this without doing further work: // // (*blocked*) -> () <: (number) -> (b...) // - // The correct thing for us to do is to consider the constraint dispatched, - // but we need to also record a new constraint number <: *blocked* to finish - // the job later. + // We solve this by searching both types for BlockedTypeVars and block the + // constraint on any we find. It also gets the job done, but I'm worried + // about the efficiency of doing so many deep type traversals and it may + // make us more prone to getting stuck on constraint cycles. + // + // If this doesn't pan out, a possible solution is to go further down the + // path of supporting partial constraint dispatch. The way it would work is + // that we'd dispatch the above constraint by binding b... to (), but we + // would append a new constraint number <: *blocked* to the constraint set + // to be solved later. This should be faster and theoretically less prone + // to cyclic constraint dependencies. CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); } TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") { - ScopedFastFlag sff[] = { - {"LuauFixNameMaps", true}, - }; - TypeArena arena; TypeId nilType = singletonTypes->nilType; @@ -522,8 +526,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") // Ideally, we would not try to export a function type with generic types from incorrect scope TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface") { - ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable @@ -563,8 +565,6 @@ return wrapStrictTable(Constants, "Constants") TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_leak_to_module_interface_variadic") { - ScopedFastFlag LuauAnyifyModuleReturnGenerics{"LuauAnyifyModuleReturnGenerics", true}; - fileResolver.source["game/A"] = R"( local wrapStrictTable diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 26f23438..66550be3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -35,7 +35,7 @@ std::optional> magicFunctionInstanceIsA( return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } -struct RefinementClassFixture : Fixture +struct RefinementClassFixture : BuiltinsFixture { RefinementClassFixture() { @@ -320,7 +320,7 @@ TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") } } -TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") { CheckResult result = check(R"( function f(s: any) @@ -332,7 +332,14 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("any & number", toString(requireTypeAtPosition({3, 26}))); + } + else + { + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") @@ -344,10 +351,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") )"); LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("number", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") +TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_typeguard") { CheckResult result = check(R"( local function f(x: number) @@ -362,6 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } @@ -648,7 +657,7 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_narrow_to_vector") { CheckResult result = check(R"( local function f(x) @@ -663,7 +672,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") CHECK_EQ("*error-type*", toString(requireTypeAtPosition({3, 28}))); } -TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") +TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { CheckResult result = check(R"( local t = {"hello"} @@ -690,7 +699,7 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" } -TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_not_to_be_string") { CheckResult result = check(R"( local function f(x: string | number | boolean) @@ -704,11 +713,19 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(boolean | number | string) & ~string", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("(boolean | number | string) & string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + } + else + { + CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" + } } -TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_table") { CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) @@ -726,7 +743,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "table" } -TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_narrows_for_functions") { CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) @@ -740,11 +757,19 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(((number) -> string) | string) & function", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("(((number) -> string) | string) & ~function", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + } + else + { + CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" + } } -TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_intersection_of_tables") { CheckResult result = check(R"( type XYCoord = {x: number} & {y: number} @@ -763,7 +788,7 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); } -TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_can_filter_for_overloaded_function") { CheckResult result = check(R"( type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) @@ -778,8 +803,16 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("((((number) -> string) & ((string) -> number))?) & function", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("((((number) -> string) & ((string) -> number))?) & ~function", toString(requireTypeAtPosition({6, 28}))); + } + else + { + CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "type_guard_narrowed_into_nothingness") @@ -884,7 +917,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") } } -TEST_CASE_FIXTURE(Fixture, "either_number_or_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "either_number_or_string") { CheckResult result = check(R"( local function f(x: any) @@ -896,7 +929,14 @@ TEST_CASE_FIXTURE(Fixture, "either_number_or_string") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(number | string) & any", toString(requireTypeAtPosition({3, 28}))); + } + else + { + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); + } } TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") @@ -946,10 +986,17 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "merge_should_be_fully_agnostic_of_hashmap_or LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(string | {| x: string |}) & string", toString(requireTypeAtPosition({6, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); + } } -TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") +TEST_CASE_FIXTURE(BuiltinsFixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { CheckResult result = check(R"( local function f(a: string | number | boolean) @@ -963,8 +1010,16 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(boolean | number | string) & ~number & ~string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(boolean | number | string) & (number | string)", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "is_truthy_constraint_ifelse_expression") @@ -995,7 +1050,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "invert_is_truthy_constraint_ifelse_expressio CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); } -TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") +TEST_CASE_FIXTURE(BuiltinsFixture, "type_comparison_ifelse_expression") { CheckResult result = check(R"( function returnOne(x) @@ -1027,7 +1082,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_a_shadowed_local_that_which CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") +TEST_CASE_FIXTURE(BuiltinsFixture, "correctly_lookup_property_whose_base_was_previously_refined") { CheckResult result = check(R"( type T = {x: string | number} @@ -1246,8 +1301,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Instance | Vector3) & Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Instance | Vector3) & ~Vector3", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") @@ -1282,14 +1345,22 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part | string) & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); + } } -TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_from_subclasses_of_instance_or_string_or_vector3") { CheckResult result = check(R"( - local function f(x: Part | Folder | Instance | string | Vector3 | any) + local function f(x: Part | Folder | string | Vector3) if typeof(x) == "Instance" then local foo = x else @@ -1300,8 +1371,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("(Folder | Part | Vector3 | string) & Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("(Folder | Part | Vector3 | string) & ~Instance", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | string", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") @@ -1342,7 +1421,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); } -TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") +TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_doesnt_leak_to_elseif") { CheckResult result = check(R"( function f(a) @@ -1373,8 +1452,16 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("unknown & string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown & ~string", toString(requireTypeAtPosition({5, 28}))); + } + else + { + CHECK_EQ("string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("unknown", toString(requireTypeAtPosition({5, 28}))); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") @@ -1408,7 +1495,35 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "what_nonsensical_condition") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + CHECK_EQ("a & number & string", toString(requireTypeAtPosition({3, 28}))); + } + else + { + CHECK_EQ("never", toString(requireTypeAtPosition({3, 28}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "else_with_no_explicit_expression_should_also_refine_the_tagged_union") +{ + ScopedFastFlag sff{"LuauImplicitElseRefinement", true}; + + CheckResult result = check(R"( + type Ok = { tag: "ok", value: T } + type Err = { tag: "err", err: E } + type Result = Ok | Err + + function and_then(r: Result, f: (T) -> U): Result + if r.tag == "ok" then + return { tag = "ok", value = f(r.value) } + else + return r + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 68757fef..bf0a0af6 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -17,6 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauNoMoreGlobalSingletonTypes) TEST_SUITE_BEGIN("TableTests"); @@ -1721,6 +1722,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") { + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + CheckResult result = check(R"( os.h = 2 string.k = 3 @@ -1728,19 +1731,36 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_table_names") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); - CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ("Cannot add property 'h' to table 'typeof(os)'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'typeof(string)'", toString(result.errors[1])); + } + else + { + CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); + } } TEST_CASE_FIXTURE(BuiltinsFixture, "persistent_sealed_table_is_immutable") { + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; + CheckResult result = check(R"( --!nonstrict function os:bad() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ("Cannot add property 'bad' to table 'typeof(os)'", toString(result.errors[0])); + } + else + { + CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + } const TableTypeVar* osType = get(requireType("os")); REQUIRE(osType != nullptr); @@ -3188,6 +3208,7 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shap TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; CheckResult result = check(R"( local function f(s) @@ -3200,25 +3221,47 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_ )"); LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[2])); + } + else + { + CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + toString(result.errors[0])); + CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[1])); - CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + toString(result.errors[1])); + CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[2])); + toString(result.errors[2])); + } } TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner", true}; // Changes argument from table type to primitive CheckResult result = check(R"( local function f(s): string @@ -3234,6 +3277,7 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { ScopedFastFlag sff{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauNewLibraryTypeNames{"LuauNewLibraryTypeNames", true}; CheckResult result = check(R"( local function f(s): string @@ -3243,11 +3287,42 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' + if (FFlag::LuauNoMoreGlobalSingletonTypes) + { + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' +caused by: + The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + } + else + { + CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' caused by: The former's metatable does not satisfy the requirements. Table type 'string' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); - CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + toString(result.errors[0])); + CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") +{ + ScopedFastFlag luauScalarShapeSubtyping{"LuauScalarShapeSubtyping", true}; + ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner", true}; + + CheckResult result = check(R"( + local function stringByteList(str) + local out = {} + for i = 1, #str do + table.insert(out, string.byte(str, i)) + end + return table.concat(out, ",") + end + + local x = stringByteList("xoo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 6c7201a6..04b8bf57 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1125,43 +1125,6 @@ TEST_CASE_FIXTURE(Fixture, "bidirectional_checking_of_higher_order_function") CHECK(location.end.line == 4); } -TEST_CASE_FIXTURE(Fixture, "dcr_can_partially_dispatch_a_constraint") -{ - ScopedFastFlag sff[] = { - {"DebugLuauDeferredConstraintResolution", true}, - }; - - CheckResult result = check(R"( - local function hasDivisors(value: number) - end - - function prime_iter(state, index) - hasDivisors(index) - index += 1 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Solving this requires recognizing that we can't dispatch a constraint - // like this without doing further work: - // - // (*blocked*) -> () <: (number) -> (b...) - // - // We solve this by searching both types for BlockedTypeVars and block the - // constraint on any we find. It also gets the job done, but I'm worried - // about the efficiency of doing so many deep type traversals and it may - // make us more prone to getting stuck on constraint cycles. - // - // If this doesn't pan out, a possible solution is to go further down the - // path of supporting partial constraint dispatch. The way it would work is - // that we'd dispatch the above constraint by binding b... to (), but we - // would append a new constraint number <: *blocked* to the constraint set - // to be solved later. This should be faster and theoretically less prone - // to cyclic constraint dependencies. - CHECK("(a, number) -> ()" == toString(requireType("prime_iter"))); -} - TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_return_values_in_nonstrict") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 4c8eeac6..cde651df 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1002,8 +1002,6 @@ TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments_free") TEST_CASE_FIXTURE(BuiltinsFixture, "type_packs_with_tails_in_vararg_adjustment") { - ScopedFastFlag luauFixVarargExprHeadType{"LuauFixVarargExprHeadType", true}; - CheckResult result = check(R"( local function wrapReject(fn: (self: any, ...TArg) -> ...TResult): (self: any, ...TArg) -> ...TResult return function(self, ...) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 0c25386f..627fbb56 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -395,6 +395,23 @@ local e = a.z CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); } +TEST_CASE_FIXTURE(Fixture, "optional_iteration") +{ + ScopedFastFlag luauNilIterator{"LuauNilIterator", true}; + + CheckResult result = check(R"( +function foo(values: {number}?) + local s = 0 + for _, value in values do + s += value + end +end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{number}?' could be nil", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unknownnever.test.cpp b/tests/TypeInfer.unknownnever.test.cpp index 2288db4e..9d1e46c5 100644 --- a/tests/TypeInfer.unknownnever.test.cpp +++ b/tests/TypeInfer.unknownnever.test.cpp @@ -284,6 +284,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i ScopedFastFlag sff[]{ {"LuauUnknownAndNeverType", true}, {"LuauNeverTypesAndOperatorsInference", true}, + {"LuauTryhardAnd", true}, }; CheckResult result = check(R"( @@ -293,7 +294,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_unify_operands_if_one_of_the_operand_is_never_i )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(nil, a) -> boolean", toString(requireType("ord"))); + // Widening doesn't normalize yet, so the result is a bit strange + CHECK_EQ("(nil, a) -> boolean | boolean", toString(requireType("ord"))); } TEST_CASE_FIXTURE(Fixture, "math_operators_and_never") diff --git a/tools/faillist.txt b/tools/faillist.txt index 4ac2b357..a228c171 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -32,10 +32,8 @@ AutocompleteTest.type_correct_expected_argument_type_suggestion_optional AutocompleteTest.type_correct_expected_argument_type_suggestion_self AutocompleteTest.type_correct_expected_return_type_pack_suggestion AutocompleteTest.type_correct_expected_return_type_suggestion -AutocompleteTest.type_correct_full_type_suggestion AutocompleteTest.type_correct_function_no_parenthesis AutocompleteTest.type_correct_function_return_types -AutocompleteTest.type_correct_function_type_suggestion AutocompleteTest.type_correct_keywords AutocompleteTest.type_correct_suggestion_for_overloads AutocompleteTest.type_correct_suggestion_in_argument @@ -53,12 +51,10 @@ BuiltinTests.dont_add_definitions_to_persistent_types BuiltinTests.find_capture_types BuiltinTests.find_capture_types2 BuiltinTests.find_capture_types3 -BuiltinTests.gmatch_capture_types BuiltinTests.gmatch_capture_types2 BuiltinTests.gmatch_capture_types_balanced_escaped_parens BuiltinTests.gmatch_capture_types_default_capture BuiltinTests.gmatch_capture_types_parens_in_sets_are_ignored -BuiltinTests.gmatch_capture_types_set_containing_lbracket BuiltinTests.gmatch_definition BuiltinTests.ipairs_iterator_should_infer_types_and_type_check BuiltinTests.match_capture_types @@ -74,13 +70,12 @@ BuiltinTests.set_metatable_needs_arguments BuiltinTests.setmetatable_should_not_mutate_persisted_types BuiltinTests.sort_with_bad_predicate BuiltinTests.string_format_arg_count_mismatch -BuiltinTests.string_format_arg_types_inference BuiltinTests.string_format_as_method BuiltinTests.string_format_correctly_ordered_types BuiltinTests.string_format_report_all_type_errors_at_correct_positions BuiltinTests.string_format_use_correct_argument BuiltinTests.string_format_use_correct_argument2 -BuiltinTests.string_format_use_correct_argument3 +BuiltinTests.strings_have_methods BuiltinTests.table_freeze_is_generic BuiltinTests.table_insert_correctly_infers_type_of_array_2_args_overload BuiltinTests.table_insert_correctly_infers_type_of_array_3_args_overload @@ -114,7 +109,6 @@ GenericsTests.generic_factories GenericsTests.generic_functions_should_be_memory_safe GenericsTests.generic_table_method GenericsTests.generic_type_pack_parentheses -GenericsTests.generic_type_pack_unification1 GenericsTests.generic_type_pack_unification2 GenericsTests.higher_rank_polymorphism_should_not_accept_instantiated_arguments GenericsTests.infer_generic_function_function_argument @@ -174,46 +168,36 @@ ProvisionalTests.while_body_are_also_refined RefinementTest.apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string RefinementTest.assert_a_to_be_truthy_then_assert_a_to_be_number RefinementTest.assert_non_binary_expressions_actually_resolve_constraints -RefinementTest.call_a_more_specific_function_using_typeguard +RefinementTest.call_an_incompatible_function_after_using_typeguard RefinementTest.correctly_lookup_property_whose_base_was_previously_refined RefinementTest.correctly_lookup_property_whose_base_was_previously_refined2 RefinementTest.discriminate_from_isa_of_x RefinementTest.discriminate_from_truthiness_of_x RefinementTest.discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false RefinementTest.discriminate_tag -RefinementTest.either_number_or_string -RefinementTest.eliminate_subclasses_of_instance +RefinementTest.else_with_no_explicit_expression_should_also_refine_the_tagged_union RefinementTest.falsiness_of_TruthyPredicate_narrows_into_nil RefinementTest.index_on_a_refined_property RefinementTest.invert_is_truthy_constraint_ifelse_expression RefinementTest.is_truthy_constraint_ifelse_expression -RefinementTest.merge_should_be_fully_agnostic_of_hashmap_ordering RefinementTest.narrow_property_of_a_bounded_variable -RefinementTest.narrow_this_large_union RefinementTest.nonoptional_type_can_narrow_to_nil_if_sense_is_true RefinementTest.not_t_or_some_prop_of_t RefinementTest.refine_a_property_not_to_be_nil_through_an_intersection_table -RefinementTest.refine_the_correct_types_opposite_of_when_a_is_not_number_or_string RefinementTest.refine_unknowns RefinementTest.truthy_constraint_on_properties RefinementTest.type_comparison_ifelse_expression RefinementTest.type_guard_can_filter_for_intersection_of_tables -RefinementTest.type_guard_can_filter_for_overloaded_function RefinementTest.type_guard_narrowed_into_nothingness RefinementTest.type_narrow_for_all_the_userdata RefinementTest.type_narrow_to_vector RefinementTest.typeguard_cast_free_table_to_vector -RefinementTest.typeguard_cast_instance_or_vector3_to_vector -RefinementTest.typeguard_doesnt_leak_to_elseif RefinementTest.typeguard_in_assert_position -RefinementTest.typeguard_in_if_condition_position -RefinementTest.typeguard_narrows_for_functions RefinementTest.typeguard_narrows_for_table -RefinementTest.typeguard_not_to_be_string -RefinementTest.what_nonsensical_condition RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table RefinementTest.x_is_not_instance_or_else_not_part RuntimeLimits.typescript_port_of_Result_type +TableTests.a_free_shape_can_turn_into_a_scalar_directly TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.access_index_metamethod_that_returns_variadic @@ -249,7 +233,6 @@ TableTests.generic_table_instantiation_potential_regression TableTests.getmetatable_returns_pointer_to_metatable TableTests.give_up_after_one_metatable_index_look_up TableTests.hide_table_error_properties -TableTests.indexer_fn TableTests.indexer_on_sealed_table_must_unify_with_free_table TableTests.indexing_from_a_table_should_prefer_properties_when_possible TableTests.inequality_operators_imply_exactly_matching_types @@ -262,7 +245,6 @@ TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_i TableTests.invariant_table_properties_means_instantiating_tables_in_call_is_unsound TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please -TableTests.meta_add TableTests.meta_add_both_ways TableTests.meta_add_inferred TableTests.metatable_mismatch_should_fail @@ -389,7 +371,6 @@ TypeInferFunctions.improved_function_arg_mismatch_error_nonstrict TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_return_type_from_selected_overload -TypeInferFunctions.infer_return_value_type TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.list_all_overloads_if_no_overload_takes_given_argument_count TypeInferFunctions.list_only_alternative_overloads_that_match_argument_count @@ -409,10 +390,12 @@ TypeInferFunctions.too_many_return_values TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function TypeInferFunctions.vararg_function_is_quantified +TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values TypeInferLoops.for_in_loop_with_next TypeInferLoops.for_in_with_generic_next TypeInferLoops.for_in_with_just_one_iterator_is_ok +TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_no_indexer_nonstrict TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop @@ -430,8 +413,6 @@ TypeInferOOP.dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted TypeInferOOP.object_constructor_can_refer_to_method_of_self -TypeInferOperators.and_or_ternary -TypeInferOperators.CallAndOrOfFunctions TypeInferOperators.cannot_compare_tables_that_do_not_have_the_same_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_have_a_metatable TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators @@ -517,6 +498,7 @@ UnionTypes.optional_assignment_errors UnionTypes.optional_call_error UnionTypes.optional_field_access_error UnionTypes.optional_index_error +UnionTypes.optional_iteration UnionTypes.optional_length_error UnionTypes.optional_missing_key_error_details UnionTypes.optional_union_follow From aa7c64517ca550eff7087a197e667f8676db7c17 Mon Sep 17 00:00:00 2001 From: boyned//Kampfkarren Date: Wed, 16 Nov 2022 10:15:01 -0800 Subject: [PATCH 2/2] Fix string interpolation autocomplete and location (#748) String interpolation autocomplete was not working and was just using default autocomplete. This fixes it so that inside a string it is treated as normal, and inside an expression it suggests expression level autocomplete. Also fixes the range, which was previously just the first lexeme. --- Analysis/src/Autocomplete.cpp | 38 ++++++++++++++++++-- Ast/src/Parser.cpp | 50 +++++++++++++++++++------- tests/AstJsonEncoder.test.cpp | 2 +- tests/Autocomplete.test.cpp | 67 ++++++++++++++++++++++++++++++++++- 4 files changed, 140 insertions(+), 17 deletions(-) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 224e9440..50dc254f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1219,6 +1219,31 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + for (const AstExpr* expression : interpString->expressions) + { + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; +} + +static bool isSimpleInterpolatedString(const AstNode* node) +{ + const AstExprInterpString* interpString = node->as(); + return interpString != nullptr && interpString->expressions.size == 0; +} + static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, const std::vector& nodes, Position position, StringCompletionCallback callback) { @@ -1227,7 +1252,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is() && !nodes.back()->is()) + if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) { return std::nullopt; } @@ -1432,7 +1457,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) + else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) { @@ -1471,7 +1496,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { return {*ret, ancestry, AutocompleteContext::String}; } - else if (node->is()) + else if (node->is() || isSimpleInterpolatedString(node)) { AutocompleteEntryMap result; @@ -1497,6 +1522,13 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return {result, ancestry, AutocompleteContext::String}; } + else if (stringPartOfInterpString(node, position)) + { + // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we + // can't know what to format to + AutocompleteEntryMap map; + return {map, ancestry, AutocompleteContext::String}; + } if (node->is()) return {}; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8338a04a..85b0d31a 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -2661,6 +2661,7 @@ AstExpr* Parser::parseInterpString() TempVector expressions(scratchExpr); Location startLocation = lexer.current().location; + Location endLocation; do { @@ -2668,16 +2669,16 @@ AstExpr* Parser::parseInterpString() LUAU_ASSERT(currentLexeme.type == Lexeme::InterpStringBegin || currentLexeme.type == Lexeme::InterpStringMid || currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple); - Location location = currentLexeme.location; + endLocation = currentLexeme.location; - Location startOfBrace = Location(location.end, 1); + Location startOfBrace = Location(endLocation.end, 1); scratchData.assign(currentLexeme.data, currentLexeme.length); if (!Lexer::fixupQuotedString(scratchData)) { nextLexeme(); - return reportExprError(startLocation, {}, "Interpolated string literal contains malformed escape sequence"); + return reportExprError(Location{startLocation, endLocation}, {}, "Interpolated string literal contains malformed escape sequence"); } AstArray chars = copy(scratchData); @@ -2688,15 +2689,36 @@ AstExpr* Parser::parseInterpString() if (currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple) { - AstArray> stringsArray = copy(strings); - AstArray expressionsArray = copy(expressions); - - return allocator.alloc(startLocation, stringsArray, expressionsArray); + break; } - AstExpr* expression = parseExpr(); + bool errorWhileChecking = false; - expressions.push_back(expression); + switch (lexer.current().type) + { + case Lexeme::InterpStringMid: + case Lexeme::InterpStringEnd: + { + errorWhileChecking = true; + nextLexeme(); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, expected expression inside '{}'")); + break; + } + case Lexeme::BrokenString: + { + errorWhileChecking = true; + nextLexeme(); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + break; + } + default: + expressions.push_back(parseExpr()); + } + + if (errorWhileChecking) + { + break; + } switch (lexer.current().type) { @@ -2706,14 +2728,18 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); - return reportExprError(location, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); case Lexeme::BrokenString: nextLexeme(); - return reportExprError(location, {}, "Malformed interpolated string, did you forget to add a '}'?"); + return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); default: - return reportExprError(location, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); + return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } } while (true); + + AstArray> stringsArray = copy(strings); + AstArray expressionsArray = copy(expressions); + return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); } AstExpr* Parser::parseNumber() diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 81e74941..a14d5f59 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -183,7 +183,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprInterpString") AstStat* statement = expectParseStatement("local a = `var = {x}`"); std::string_view expected = - R"({"type":"AstStatLocal","location":"0,0 - 0,18","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprInterpString","location":"0,10 - 0,18","strings":["var = ",""],"expressions":[{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"x"}]}]})"; + R"({"type":"AstStatLocal","location":"0,0 - 0,21","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprInterpString","location":"0,10 - 0,21","strings":["var = ",""],"expressions":[{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"x"}]}]})"; CHECK(toJson(statement) == expected); } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9a5c3411..45baec2c 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -7,6 +7,7 @@ #include "Luau/StringUtils.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -2708,13 +2709,77 @@ a = if temp then even else abc@3 CHECK(ac.entryMap.count("abcdef")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_constant") { + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`@1`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`@1 {"a"}`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1 {"b"}`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + check(R"(f(`expression = {@1}`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression_with_comments") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`expression = {--[[ bla bla bla ]]@1`))"); auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK_EQ(ac.context, AutocompleteContext::Expression); + + check(R"(f(`expression = {@1 --[[ bla bla bla ]]`))"); + ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_as_singleton") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"( + --!strict + local function f(a: "cat" | "dog") end + + f(`@1`) + f(`uhhh{'try'}@2`) + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("cat")); + CHECK_EQ(ac.context, AutocompleteContext::String); + + ac = autocomplete('2'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack")