diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 40ba150a..4176fe2c 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -283,6 +283,25 @@ struct SimplifyConstraint TypeId ty; }; +// push_function_type_constraint expectedFunctionType => functionType +// +// Attempt to "push" the types of `expectedFunctionType` into `functionType`, +// assuming that `expr` is a lambda who's ungeneralized type is `functionType`. +// Similar to `FunctionCheckConstraint`. For example: +// +// local Foo = {} :: { bar : (number) -> () } +// +// function Foo.bar(x) end +// +// This will force `x` to be inferred as `number`. +struct PushFunctionTypeConstraint +{ + TypeId expectedFunctionType; + TypeId functionType; + NotNull expr; + bool isSelf; +}; + using ConstraintV = Variant< SubtypeConstraint, PackSubtypeConstraint, @@ -302,7 +321,8 @@ using ConstraintV = Variant< ReducePackConstraint, EqualityConstraint, TableCheckConstraint, - SimplifyConstraint>; + SimplifyConstraint, + PushFunctionTypeConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index fc97316f..2f1d1887 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -494,6 +494,9 @@ private: ); TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right); + + void updateRValueRefinements(const ScopePtr& scope, DefId def, TypeId ty) const; + void updateRValueRefinements(Scope* scope, DefId def, TypeId ty) const; }; } // namespace Luau diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 331bc3f8..5d281eff 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -256,6 +256,8 @@ public: bool tryDispatch(const SimplifyConstraint& c, NotNull constraint); + bool tryDispatch(const PushFunctionTypeConstraint& c, NotNull constraint); + // for a, ... in some_table do // also handles __iter metamethod bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index 67e52240..31c93494 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -46,6 +46,8 @@ struct DataFlowGraph const RefinementKey* getRefinementKey(const AstExpr* expr) const; + std::optional getSymbolFromDef(const Def* def) const; + private: DataFlowGraph(NotNull defArena, NotNull keyArena); @@ -63,6 +65,7 @@ private: // There's no AstStatDeclaration, and it feels useless to introduce it just to enforce an invariant in one place. // All keys in this maps are really only statements that ambiently declares a symbol. DenseHashMap declaredDefs{nullptr}; + DenseHashMap defToSymbol{nullptr}; DenseHashMap astRefinementKeys{nullptr}; friend struct DataFlowGraphBuilder; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index d94956b7..cf61e021 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -152,7 +152,9 @@ struct Module // Once a module has been typechecked, we clone its public interface into a // separate arena. This helps us to force Type ownership into a DAG rather // than a DCG. - void clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice); + void clonePublicInterface_DEPRECATED(NotNull builtinTypes, InternalErrorReporter& ice); + + void clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice, SolverMode mode); }; } // namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index ad3c1e28..ee8aa04f 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -4,6 +4,7 @@ #include "Luau/VisitType.h" LUAU_FASTFLAG(LuauEagerGeneralization4) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) namespace Luau { @@ -222,6 +223,14 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes_DEPRECATED() const rci.traverse(tcc->exprType); } + if (FFlag::LuauPushFunctionTypesInFunctionStatement) + { + if (auto pftc = get(*this)) + { + rci.traverse(pftc->functionType); + } + } + return types; } @@ -318,6 +327,14 @@ TypeIds Constraint::getMaybeMutatedFreeTypes() const rci.traverse(tcc->exprType); } + if (FFlag::LuauPushFunctionTypesInFunctionStatement) + { + if (auto pftc = get(*this)) + { + rci.traverse(pftc->functionType); + } + } + return types; } diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index b904fe7f..f938c615 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -52,6 +52,8 @@ LUAU_FASTFLAG(LuauRemoveTypeCallsForReadWriteProps) LUAU_FASTFLAGVARIABLE(LuauFollowTypeAlias) LUAU_FASTFLAGVARIABLE(LuauFollowExistingTypeFunction) LUAU_FASTFLAGVARIABLE(LuauRefineTablesWithReadType) +LUAU_FASTFLAGVARIABLE(LuauFragmentAutocompleteTracksRValueRefinements) +LUAU_FASTFLAGVARIABLE(LuauPushFunctionTypesInFunctionStatement) namespace Luau { @@ -776,8 +778,10 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat if (partition.shouldAppendNilType) ty = createTypeFunctionInstance(builtinTypeFunctions().weakoptionalFunc, {ty}, {}, scope, location); - - scope->rvalueRefinements[def] = ty; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, def, ty); + else + scope->rvalueRefinements[def] = ty; } } @@ -1283,7 +1287,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFor* for_) DefId def = dfg->getDef(for_->var); forScope->lvalueTypes[def] = annotationTy; - forScope->rvalueRefinements[def] = annotationTy; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(forScope, def, annotationTy); + else + forScope->rvalueRefinements[def] = annotationTy; visit(forScope, for_->body); @@ -1435,9 +1442,15 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocalFuncti DefId def = dfg->getDef(function->name); scope->lvalueTypes[def] = functionType; - scope->rvalueRefinements[def] = functionType; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, def, functionType); + else + scope->rvalueRefinements[def] = functionType; sig.bodyScope->lvalueTypes[def] = sig.signature; - sig.bodyScope->rvalueRefinements[def] = sig.signature; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(sig.bodyScope, def, sig.signature); + else + sig.bodyScope->rvalueRefinements[def] = sig.signature; Checkpoint start = checkpoint(this); checkFunctionBody(sig.bodyScope, function->func); @@ -1497,20 +1510,77 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f { sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; sig.bodyScope->lvalueTypes[def] = sig.signature; - sig.bodyScope->rvalueRefinements[def] = sig.signature; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(sig.bodyScope, def, sig.signature); + else + sig.bodyScope->rvalueRefinements[def] = sig.signature; } else if (AstExprGlobal* globalName = function->name->as()) { sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; sig.bodyScope->lvalueTypes[def] = sig.signature; - sig.bodyScope->rvalueRefinements[def] = sig.signature; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(sig.bodyScope, def, sig.signature); + else + sig.bodyScope->rvalueRefinements[def] = sig.signature; } else if (AstExprIndexName* indexName = function->name->as()) { - sig.bodyScope->rvalueRefinements[def] = sig.signature; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(sig.bodyScope, def, sig.signature); + else + sig.bodyScope->rvalueRefinements[def] = sig.signature; + } + + if (FFlag::LuauPushFunctionTypesInFunctionStatement) + { + if (auto indexName = function->name->as()) + { + auto beginProp = checkpoint(this); + auto [fn, _] = check(scope, indexName); + auto endProp = checkpoint(this); + auto pftc = addConstraint( + sig.signatureScope, + function->func->location, + PushFunctionTypeConstraint{ + fn, + sig.signature, + NotNull{function->func}, + /* isSelf */ indexName->op == ':', + } + ); + forEachConstraint( + beginProp, + endProp, + this, + [pftc](const ConstraintPtr& c) + { + pftc->dependencies.emplace_back(c.get()); + } + ); + auto beginBody = checkpoint(this); + checkFunctionBody(sig.bodyScope, function->func); + auto endBody = checkpoint(this); + forEachConstraint( + beginBody, + endBody, + this, + [pftc](const ConstraintPtr& c) + { + c->dependencies.push_back(pftc); + } + ); + } + else + { + checkFunctionBody(sig.bodyScope, function->func); + } + } + else + { + checkFunctionBody(sig.bodyScope, function->func); } - checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); TypeId generalizedType = arena->addType(BlockedType{}); @@ -1582,7 +1652,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f if (generalizedType == nullptr) ice->ice("generalizedType == nullptr", function->location); - scope->rvalueRefinements[def] = generalizedType; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, def, generalizedType); + else + scope->rvalueRefinements[def] = generalizedType; return ControlFlow::None; } @@ -1890,7 +1963,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareGlob DefId def = dfg->getDef(global); rootScope->lvalueTypes[def] = globalTy; - rootScope->rvalueRefinements[def] = globalTy; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(rootScope, def, globalTy); + else + rootScope->rvalueRefinements[def] = globalTy; return ControlFlow::None; } @@ -2149,7 +2225,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareFunc DefId def = dfg->getDef(global); rootScope->lvalueTypes[def] = fnType; - rootScope->rvalueRefinements[def] = fnType; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(rootScope, def, fnType); + else + rootScope->rvalueRefinements[def] = fnType; return ControlFlow::None; } @@ -2397,7 +2476,10 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* DefId def = dfg->getDef(targetLocal); scope->lvalueTypes[def] = resultTy; // TODO: typestates: track this as an assignment - scope->rvalueRefinements[def] = resultTy; // TODO: typestates: track this as an assignment + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, def, resultTy); // TODO: typestates: track this as an assignment + else + scope->rvalueRefinements[def] = resultTy; // TODO: typestates: track this as an assignment // HACK: If we have a targetLocal, it has already been added to the // inferredBindings table. We want to replace it so that we don't @@ -2419,7 +2501,10 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* if (auto def = dfg->getDefOptional(targetExpr)) { scope->lvalueTypes[*def] = resultTy; - scope->rvalueRefinements[*def] = resultTy; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, *def, resultTy); + else + scope->rvalueRefinements[*def] = resultTy; } } @@ -2729,7 +2814,10 @@ Inference ConstraintGenerator::checkIndexName( if (auto ty = lookup(scope, indexLocation, key->def)) return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; - scope->rvalueRefinements[key->def] = result; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, key->def, result); + else + scope->rvalueRefinements[key->def] = result; } if (key) @@ -2763,8 +2851,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprIndexExpr* in { if (auto ty = lookup(scope, indexExpr->location, key->def)) return Inference{*ty, refinementArena.proposition(key, builtinTypes->truthyType)}; - - scope->rvalueRefinements[key->def] = result; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(scope, key->def, result); + else + scope->rvalueRefinements[key->def] = result; } auto c = addConstraint(scope, indexExpr->expr->location, HasIndexerConstraint{result, obj, indexType}); @@ -3533,7 +3623,10 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu DefId def = dfg->getDef(fn->self); signatureScope->lvalueTypes[def] = selfType; - signatureScope->rvalueRefinements[def] = selfType; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(signatureScope, def, selfType); + else + signatureScope->rvalueRefinements[def] = selfType; } for (size_t i = 0; i < fn->args.size; ++i) @@ -3558,7 +3651,10 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu DefId def = dfg->getDef(local); signatureScope->lvalueTypes[def] = argTy; - signatureScope->rvalueRefinements[def] = argTy; + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + updateRValueRefinements(signatureScope, def, argTy); + else + signatureScope->rvalueRefinements[def] = argTy; } TypePackId varargPack = nullptr; @@ -4499,4 +4595,17 @@ TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location locati return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; } +void ConstraintGenerator::updateRValueRefinements(const ScopePtr& scope, DefId def, TypeId ty) const +{ + updateRValueRefinements(scope.get(), def, ty); +} + +void ConstraintGenerator::updateRValueRefinements(Scope* scope, DefId def, TypeId ty) const +{ + scope->rvalueRefinements[def] = ty; + if (auto sym = dfg->getSymbolFromDef(def)) + scope->refinements[*sym] = ty; +} + + } // namespace Luau diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index adfedacd..5ba5def2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -42,6 +42,8 @@ LUAU_FASTFLAGVARIABLE(LuauMissingFollowInAssignIndexConstraint) LUAU_FASTFLAGVARIABLE(LuauRemoveTypeCallsForReadWriteProps) LUAU_FASTFLAGVARIABLE(LuauTableLiteralSubtypeCheckFunctionCalls) LUAU_FASTFLAGVARIABLE(LuauUseOrderedTypeSetsInConstraints) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) +LUAU_FASTFLAG(LuauAvoidExcessiveTypeCopying) namespace Luau { @@ -776,7 +778,7 @@ void ConstraintSolver::generalizeOneType(TypeId ty) ty = follow(ty); const FreeType* freeTy = get(ty); - std::string saveme = toString(ty, opts); + std::string saveme = FFlag::DebugLuauLogSolver ? toString(ty, opts) : "[FFlag::DebugLuauLogSolver Off]"; // Some constraints (like prim) will also replace a free type with something // concrete. If so, our work is already done. @@ -904,6 +906,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*eqc, constraint); else if (auto sc = get(*constraint)) success = tryDispatch(*sc, constraint); + else if (auto pftc = get(*constraint)) + success = tryDispatch(*pftc, constraint); else LUAU_ASSERT(false); @@ -1845,7 +1849,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.size && i + typeOffset < expectedArgs.size() && i + typeOffset < argPackHead.size(); ++i) { - const TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]); + TypeId expectedArgTy = follow(expectedArgs[i + typeOffset]); const TypeId actualArgTy = follow(argPackHead[i + typeOffset]); AstExpr* expr = unwrapGroup(c.callSite->args.data[i]); @@ -1875,21 +1879,38 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullis() || expr->is() || expr->is() || expr->is() || (FFlag::LuauTableLiteralSubtypeCheckFunctionCalls && expr->is())) { - ReferentialReplacer replacer{arena, NotNull{&replacements}, NotNull{&replacementPacks}}; - if (auto res = replacer.substitute(expectedArgTy)) + if (FFlag::LuauAvoidExcessiveTypeCopying) { - if (FFlag::LuauTableLiteralSubtypeCheckFunctionCalls) + if (ContainsGenerics::hasGeneric(expectedArgTy, NotNull{&genericTypesAndPacks})) { - // If we do this replacement and there are type - // functions in the final type, then we need to - // ensure those get reduced. - InstantiationQueuer queuer{constraint->scope, constraint->location, this}; - queuer.traverse(*res); + ReferentialReplacer replacer{arena, NotNull{&replacements}, NotNull{&replacementPacks}}; + if (auto res = replacer.substitute(expectedArgTy)) + { + InstantiationQueuer queuer{constraint->scope, constraint->location, this}; + queuer.traverse(*res); + expectedArgTy = *res; + } } - u2.unify(actualArgTy, *res); + u2.unify(actualArgTy, expectedArgTy); } else - u2.unify(actualArgTy, expectedArgTy); + { + ReferentialReplacer replacer{arena, NotNull{&replacements}, NotNull{&replacementPacks}}; + if (auto res = replacer.substitute(expectedArgTy)) + { + if (FFlag::LuauTableLiteralSubtypeCheckFunctionCalls) + { + // If we do this replacement and there are type + // functions in the final type, then we need to + // ensure those get reduced. + InstantiationQueuer queuer{constraint->scope, constraint->location, this}; + queuer.traverse(*res); + } + u2.unify(actualArgTy, *res); + } + else + u2.unify(actualArgTy, expectedArgTy); + } } else if (!FFlag::LuauTableLiteralSubtypeCheckFunctionCalls && expr->is() && !ContainsGenerics::hasGeneric(expectedArgTy, NotNull{&genericTypesAndPacks})) @@ -2507,6 +2528,23 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullremainingProps > 0); lhsTable->remainingProps -= 1; + + // For some code like: + // + // local T = {} + // function T:foo() + // return T:bar(5) + // end + // function T:bar(i) + // return i + // end + // + // We need to wake up an unsealed table if it previously + // was blocked on missing a member. In the above, we may + // try to solve for `hasProp T "bar"`, block, then never + // wake up without forcing a constraint. + if (FFlag::LuauPushFunctionTypesInFunctionStatement) + unblock(lhsType, constraint->location); } return true; @@ -2905,6 +2943,104 @@ bool ConstraintSolver::tryDispatch(const SimplifyConstraint& c, NotNull(ty); + return !found; + } + + bool visit(TypePackId ty) override + { + found = found || is(ty); + return !found; + } + + static bool hasAnyGeneric(TypeId ty) + { + ContainsAnyGeneric cg; + cg.traverse(ty); + return cg.found; + } + + static bool hasAnyGeneric(TypePackId tp) + { + ContainsAnyGeneric cg; + cg.traverse(tp); + return cg.found; + } +}; + +} // namespace + +bool ConstraintSolver::tryDispatch(const PushFunctionTypeConstraint& c, NotNull constraint) +{ + // NOTE: This logic could probably be combined with that of + // `FunctionCheckConstraint`, but that constraint currently does a few + // different things. + + auto expectedFn = get(follow(c.expectedFunctionType)); + auto fn = get(follow(c.functionType)); + + // If either the expected type or given type aren't functions, then bail. + if (!expectedFn || !fn) + return true; + + auto expectedParams = begin(expectedFn->argTypes); + auto params = begin(fn->argTypes); + + if (expectedParams == end(expectedFn->argTypes) || params == end(fn->argTypes)) + return true; + + if (c.isSelf) + { + if (is(follow(*params))) + { + shiftReferences(*params, *expectedParams); + bind(constraint, *params, *expectedParams); + } + expectedParams++; + params++; + } + + // `idx` is an index into the arguments of the attached `AstExprFunction`, + // we don't need to increment it with respect to arguments in case of a + // `self` type. + size_t idx = 0; + while (idx < c.expr->args.size && expectedParams != end(expectedFn->argTypes) && params != end(fn->argTypes)) + { + // If we have an explicitly annotated parameter, a non-free type for + // the parameter, or the expected type contains a generic, bail. + // - Annotations should be respected above all else; + // - a non-free-type is unexpected, so just bail; + // - a generic in the expected type might cause us to leak a generic, so bail. + if (!c.expr->args.data[idx]->annotation && get(*params) && !ContainsAnyGeneric::hasAnyGeneric(*expectedParams)) + { + shiftReferences(*params, *expectedParams); + bind(constraint, *params, *expectedParams); + } + expectedParams++; + params++; + idx++; + } + + if (!c.expr->returnAnnotation && get(fn->retTypes) && !ContainsAnyGeneric::hasAnyGeneric(expectedFn->retTypes)) + bind(constraint, fn->retTypes, expectedFn->retTypes); + + return true; +} + bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force) { iteratorTy = follow(iteratorTy); diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 91f4d93e..1de9d743 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -17,6 +17,8 @@ LUAU_FASTFLAGVARIABLE(LuauDfgScopeStackNotNull) LUAU_FASTFLAGVARIABLE(LuauDoNotAddUpvalueTypesToLocalType) LUAU_FASTFLAGVARIABLE(LuauDfgIfBlocksShouldRespectControlFlow) LUAU_FASTFLAGVARIABLE(LuauDfgAllowUpdatesInLoops) +LUAU_FASTFLAG(LuauFragmentAutocompleteTracksRValueRefinements) +LUAU_FASTFLAGVARIABLE(LuauDfgForwardNilFromAndOr) namespace Luau { @@ -107,6 +109,14 @@ const RefinementKey* DataFlowGraph::getRefinementKey(const AstExpr* expr) const return nullptr; } +std::optional DataFlowGraph::getSymbolFromDef(const Def* def) const +{ + if (auto ref = defToSymbol.find(def)) + return *ref; + + return std::nullopt; +} + std::optional DfgScope::lookup(Symbol symbol) const { for (const DfgScope* current = this; current; current = current->parent) @@ -1051,12 +1061,16 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprLocal* l) { DefId def = lookup(l->local, l->local->location); const RefinementKey* key = keyArena->leaf(def); + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + graph.defToSymbol[def] = l->local; return {def, key}; } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGlobal* g) { DefId def = lookup(g->name, g->location); + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + graph.defToSymbol[def] = g->name; return {def, keyArena->leaf(def)}; } @@ -1216,10 +1230,23 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprUnary* u) DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprBinary* b) { - visitExpr(b->left); - visitExpr(b->right); + if (FFlag::LuauDfgForwardNilFromAndOr) + { + auto left = visitExpr(b->left); + auto right = visitExpr(b->right); + // I think there's some subtlety here. There are probably cases where + // X or Y / X and Y can _never_ "be subscripted." + auto subscripted = (b->op == AstExprBinary::And || b->op == AstExprBinary::Or) && + (containsSubscriptedDefinition(left.def) || containsSubscriptedDefinition(right.def)); + return {defArena->freshCell(Symbol{}, b->location, subscripted), nullptr}; + } + else + { + visitExpr(b->left); + visitExpr(b->right); - return {defArena->freshCell(Symbol{}, b->location), nullptr}; + return {defArena->freshCell(Symbol{}, b->location), nullptr}; + } } DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTypeAssertion* t) diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index 8a07f040..d1d0f7a1 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -40,6 +40,8 @@ LUAU_FASTFLAG(LuauExpectedTypeVisitor) LUAU_FASTFLAGVARIABLE(LuauPopulateRefinedTypesInFragmentFromOldSolver) LUAU_FASTFLAG(LuauUseWorkspacePropToChooseSolver) LUAU_FASTFLAGVARIABLE(LuauFragmentRequiresCanBeResolvedToAModule) +LUAU_FASTFLAG(LuauFragmentAutocompleteTracksRValueRefinements) +LUAU_FASTFLAGVARIABLE(LuauPopulateSelfTypesInFragment) namespace Luau { @@ -421,6 +423,15 @@ FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* st { if (globFun->location.contains(cursorPos)) { + if (FFlag::LuauPopulateSelfTypesInFragment) + { + if (auto local = globFun->func->self) + { + localStack.push_back(local); + localMap[local->name] = local; + } + } + for (AstLocal* loc : globFun->func->args) { localStack.push_back(loc); @@ -603,7 +614,17 @@ struct UsageFinder : public AstVisitor if (auto ref = dfg->getRefinementKey(expr)) mentionedDefs.insert(ref->def); if (auto local = expr->as()) - localBindingsReferenced.emplace_back(dfg->getDef(local), local->local); + { + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + { + auto def = dfg->getDef(local); + localBindingsReferenced.emplace_back(def, local->local); + symbolsToRefine.emplace_back(def, Symbol(local->local)); + } + else + localBindingsReferenced.emplace_back(dfg->getDef(local), local->local); + + } return true; } @@ -611,6 +632,11 @@ struct UsageFinder : public AstVisitor { if (FFlag::LuauGlobalVariableModuleIsolation) globalDefsToPrePopulate.emplace_back(global->name, dfg->getDef(global)); + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + { + auto def = dfg->getDef(global); + symbolsToRefine.emplace_back(def, Symbol(global->name)); + } return true; } @@ -633,6 +659,7 @@ struct UsageFinder : public AstVisitor std::vector> referencedImportedBindings{{"", ""}}; std::vector> globalDefsToPrePopulate; std::vector globalFunctionsReferenced; + std::vector> symbolsToRefine; }; // Runs the `UsageFinder` traversal on the fragment and grabs all of the types that are @@ -685,7 +712,20 @@ void cloneTypesFromFragment( } } - if (FFlag::LuauPopulateRefinedTypesInFragmentFromOldSolver && !staleModule->checkedInNewSolver) + if (FFlag::LuauFragmentAutocompleteTracksRValueRefinements) + { + for (const auto& [d, syms] : f.symbolsToRefine) + { + for (const Scope* stale = staleScope; stale; stale = stale->parent.get()) + { + if (auto res = stale->refinements.find(syms); res != stale->refinements.end()) + { + destScope->rvalueRefinements[d] = Luau::cloneIncremental(res->second, *destArena, cloneState, destScope); + } + } + } + } + else if (FFlag::LuauPopulateRefinedTypesInFragmentFromOldSolver && !staleModule->checkedInNewSolver) { for (const auto& [d, loc] : f.localBindingsReferenced) { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 3adde176..f00c2700 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1672,7 +1672,10 @@ ModulePtr check( } unfreeze(result->interfaceTypes); - result->clonePublicInterface(builtinTypes, *iceHandler); + if (FFlag::LuauUseWorkspacePropToChooseSolver) + result->clonePublicInterface(builtinTypes, *iceHandler, SolverMode::New); + else + result->clonePublicInterface_DEPRECATED(builtinTypes, *iceHandler); if (FFlag::DebugLuauForbidInternalTypes) { diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 2758edff..e7c09bf2 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -275,7 +275,7 @@ Module::~Module() unfreeze(internalTypes); } -void Module::clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice) +void Module::clonePublicInterface_DEPRECATED(NotNull builtinTypes, InternalErrorReporter& ice) { CloneState cloneState{builtinTypes}; @@ -319,6 +319,50 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr this->exportedTypeBindings = moduleScope->exportedTypeBindings; } +void Module::clonePublicInterface(NotNull builtinTypes, InternalErrorReporter& ice, SolverMode mode) +{ + CloneState cloneState{builtinTypes}; + + ScopePtr moduleScope = getModuleScope(); + + TypePackId returnType = moduleScope->returnType; + std::optional varargPack = mode == SolverMode::New ? std::nullopt : moduleScope->varargPack; + + TxnLog log; + ClonePublicInterface clonePublicInterface{&log, builtinTypes, this}; + + returnType = clonePublicInterface.cloneTypePack(returnType); + + moduleScope->returnType = returnType; + if (varargPack) + { + varargPack = clonePublicInterface.cloneTypePack(*varargPack); + moduleScope->varargPack = varargPack; + } + + for (auto& [name, tf] : moduleScope->exportedTypeBindings) + { + tf = clonePublicInterface.cloneTypeFun(tf); + } + + for (auto& [name, ty] : declaredGlobals) + { + ty = clonePublicInterface.cloneType(ty); + } + + if (FFlag::LuauUserTypeFunctionAliases) + { + for (auto& tf : typeFunctionAliases) + { + *tf = clonePublicInterface.cloneTypeFun(*tf); + } + } + + // Copy external stuff over to Module itself + this->returnType = moduleScope->returnType; + this->exportedTypeBindings = moduleScope->exportedTypeBindings; +} + bool Module::hasModuleScope() const { return !scopes.empty(); diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 9de1dc3c..b1afca8e 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -24,6 +24,8 @@ LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAGVARIABLE(LuauNewNonStrictVisitTypes2) LUAU_FASTFLAGVARIABLE(LuauNewNonStrictFixGenericTypePacks) +LUAU_FASTFLAGVARIABLE(LuauNewNonStrictMoreUnknownSymbols) +LUAU_FASTFLAGVARIABLE(LuauNewNonStrictNoErrorsPassingNever) namespace Luau { @@ -353,12 +355,24 @@ struct NonStrictTypeChecker NonStrictContext condB = visit(ifStatement->condition, ValueContext::RValue); NonStrictContext branchContext; - // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error - if (ifStatement->elsebody) + if (FFlag::LuauNewNonStrictMoreUnknownSymbols) { NonStrictContext thenBody = visit(ifStatement->thenbody); - NonStrictContext elseBody = visit(ifStatement->elsebody); - branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); + if (ifStatement->elsebody) + { + NonStrictContext elseBody = visit(ifStatement->elsebody); + branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); + } + } + else + { + // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error + if (ifStatement->elsebody) + { + NonStrictContext thenBody = visit(ifStatement->thenbody); + NonStrictContext elseBody = visit(ifStatement->elsebody); + branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); + } } return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); @@ -629,6 +643,13 @@ struct NonStrictTypeChecker NonStrictContext visit(AstExprCall* call) { + if (FFlag::LuauNewNonStrictMoreUnknownSymbols) + { + visit(call->func, ValueContext::RValue); + for (auto arg : call->args) + visit(arg, ValueContext::RValue); + } + NonStrictContext fresh{}; TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func); if (!originalCallTy) @@ -715,7 +736,15 @@ struct NonStrictTypeChecker { AstExpr* arg = arguments[i]; if (auto runTimeFailureType = willRunTimeError(arg, fresh)) - reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); + { + if (FFlag::LuauNewNonStrictNoErrorsPassingNever) + { + if (!get(follow(*runTimeFailureType))) + reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); + } + else + reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); + } } if (arguments.size() < argTypes.size()) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 34e50bf0..5b730eb8 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -2023,6 +2023,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) return "table_check " + tos(c.expectedType) + " :> " + tos(c.exprType); else if constexpr (std::is_same_v) return "simplify " + tos(c.ty); + else if constexpr (std::is_same_v) + return "push_function_type " + tos(c.expectedFunctionType) + " => " + tos(c.functionType); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 7bd03a09..291f170f 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -60,6 +60,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursCheckForRefinement) LUAU_FASTFLAGVARIABLE(LuauStuckTypeFunctionsStillDispatch) LUAU_FASTFLAG(LuauRefineTablesWithReadType) LUAU_FASTFLAGVARIABLE(LuauEmptyStringInKeyOf) +LUAU_FASTFLAGVARIABLE(LuauAvoidExcessiveTypeCopying) namespace Luau { @@ -2455,6 +2456,39 @@ struct RefineTypeScrubber : public Substitution }; +bool occurs(TypeId haystack, TypeId needle, DenseHashSet& seen) +{ + if (needle == haystack) + return true; + + if (seen.contains(haystack)) + return false; + + seen.insert(haystack); + + if (auto ut = get(haystack)) + { + for (auto option : ut) + if (occurs(option, needle, seen)) + return true; + } + + if (auto it = get(haystack)) + { + for (auto part : it) + if (occurs(part, needle, seen)) + return true; + } + + return false; +} + +bool occurs(TypeId haystack, TypeId needle) +{ + DenseHashSet seen{nullptr}; + return occurs(haystack, needle, seen); +} + } // namespace TypeFunctionReductionResult refineTypeFunction( @@ -2485,9 +2519,12 @@ TypeFunctionReductionResult refineTypeFunction( // Instead, we can clip the recursive part: // // t1 where t1 = refine => refine - RefineTypeScrubber rts{ctx, instance}; - if (auto result = rts.substitute(targetTy)) - targetTy = *result; + if (!FFlag::LuauAvoidExcessiveTypeCopying || occurs(targetTy, instance)) + { + RefineTypeScrubber rts{ctx, instance}; + if (auto result = rts.substitute(targetTy)) + targetTy = *result; + } } std::vector discriminantTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e283e5dd..35a66b9b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,6 +32,7 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauUseWorkspacePropToChooseSolver) LUAU_FASTFLAGVARIABLE(LuauReduceCheckBinaryExprStackPressure) @@ -303,7 +304,11 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo normalizer.clearCaches(); normalizer.arena = nullptr; - currentModule->clonePublicInterface(builtinTypes, *iceHandler); + if (FFlag::LuauUseWorkspacePropToChooseSolver) + currentModule->clonePublicInterface(builtinTypes, *iceHandler, SolverMode::Old); + else + currentModule->clonePublicInterface_DEPRECATED(builtinTypes, *iceHandler); + freeze(currentModule->internalTypes); freeze(currentModule->interfaceTypes); diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index cf41c81c..2c91264c 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauCodeGenAllocationCheck) + #if defined(_WIN32) #ifndef WIN32_LEAN_AND_MEAN @@ -52,7 +54,7 @@ static void freePagesImpl(uint8_t* mem, size_t size) CODEGEN_ASSERT(!"failed to deallocate block memory"); } -static void makePagesExecutable(uint8_t* mem, size_t size) +static void makePagesExecutable_DEPRECATED(uint8_t* mem, size_t size) { CODEGEN_ASSERT((uintptr_t(mem) & (kPageSize - 1)) == 0); CODEGEN_ASSERT(size == alignToPageSize(size)); @@ -62,6 +64,15 @@ static void makePagesExecutable(uint8_t* mem, size_t size) CODEGEN_ASSERT(!"Failed to change page protection"); } +[[nodiscard]] static bool makePagesExecutable(uint8_t* mem, size_t size) +{ + CODEGEN_ASSERT((uintptr_t(mem) & (kPageSize - 1)) == 0); + CODEGEN_ASSERT(size == alignToPageSize(size)); + + DWORD oldProtect; + return VirtualProtect(mem, size, PAGE_EXECUTE_READ, &oldProtect) != 0; +} + static void flushInstructionCache(uint8_t* mem, size_t size) { #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) @@ -91,7 +102,7 @@ static void freePagesImpl(uint8_t* mem, size_t size) CODEGEN_ASSERT(!"Failed to deallocate block memory"); } -static void makePagesExecutable(uint8_t* mem, size_t size) +static void makePagesExecutable_DEPRECATED(uint8_t* mem, size_t size) { CODEGEN_ASSERT((uintptr_t(mem) & (kPageSize - 1)) == 0); CODEGEN_ASSERT(size == alignToPageSize(size)); @@ -100,6 +111,14 @@ static void makePagesExecutable(uint8_t* mem, size_t size) CODEGEN_ASSERT(!"Failed to change page protection"); } +[[nodiscard]] static bool makePagesExecutable(uint8_t* mem, size_t size) +{ + CODEGEN_ASSERT((uintptr_t(mem) & (kPageSize - 1)) == 0); + CODEGEN_ASSERT(size == alignToPageSize(size)); + + return mprotect(mem, size, PROT_READ | PROT_EXEC) == 0; +} + static void flushInstructionCache(uint8_t* mem, size_t size) { #ifdef __APPLE__ @@ -184,7 +203,16 @@ bool CodeAllocator::allocate( size_t pageAlignedSize = alignToPageSize(startOffset + totalSize); - makePagesExecutable(blockPos, pageAlignedSize); + if (FFlag::LuauCodeGenAllocationCheck) + { + if (!makePagesExecutable(blockPos, pageAlignedSize)) + return false; + } + else + { + makePagesExecutable_DEPRECATED(blockPos, pageAlignedSize); + } + flushInstructionCache(blockPos + codeOffset, codeSize); result = blockPos + startOffset; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index bbe733dd..5178a9b4 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -22,6 +22,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauEagerGeneralization4) LUAU_FASTFLAG(LuauExpectedTypeVisitor) LUAU_FASTFLAG(LuauImplicitTableIndexerKeys2) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) using namespace Luau; @@ -4694,4 +4695,43 @@ TEST_CASE_FIXTURE(ACFixture, "bidirectional_autocomplete_in_function_call") CHECK_EQ(ac.entryMap.count("right"), 1); } +TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_via_bidirectional_self") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + }; + + check(R"( + type IAccount = { + __index: IAccount, + new : (string, number) -> Account, + report: (self: Account) -> (), + } + + export type Account = setmetatable<{ + name: string, + balance: number + }, IAccount>; + + local Account = {} :: IAccount + Account.__index = Account + + function Account.new(name, balance): Account + local self = {} + self.name = name + self.balance = balance + return setmetatable(self, Account) + end + + function Account:report() + print("My balance is: " .. self.@1) + end + )"); + + auto ac = autocomplete('1'); + CHECK_EQ(ac.entryMap.count("name"), 1); + CHECK_EQ(ac.entryMap.count("balance"), 1); +} + TEST_SUITE_END(); diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index b0bff82e..b94083bc 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -35,6 +35,9 @@ LUAU_FASTFLAG(LuauFragmentAutocompleteIfRecommendations) LUAU_FASTFLAG(LuauPopulateRefinedTypesInFragmentFromOldSolver) LUAU_FASTFLAG(LuauSolverAgnosticStringification) LUAU_FASTFLAG(LuauFragmentRequiresCanBeResolvedToAModule) +LUAU_FASTFLAG(LuauFragmentAutocompleteTracksRValueRefinements) +LUAU_FASTFLAG(LuauPopulateSelfTypesInFragment) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) { @@ -70,6 +73,8 @@ struct FragmentAutocompleteFixtureImpl : BaseType ScopedFastFlag luauGlobalVariableModuleIsolation{FFlag::LuauGlobalVariableModuleIsolation, true}; ScopedFastFlag luauFragmentAutocompleteIfRecommendations{FFlag::LuauFragmentAutocompleteIfRecommendations, true}; ScopedFastFlag luauPopulateRefinedTypesInFragmentFromOldSolver{FFlag::LuauPopulateRefinedTypesInFragmentFromOldSolver, true}; + ScopedFastFlag sffLuauFragmentAutocompleteTracksRValueRefinement{FFlag::LuauFragmentAutocompleteTracksRValueRefinements, true}; + ScopedFastFlag sffLuauPopulateSelfTypesInFragment{FFlag::LuauPopulateSelfTypesInFragment, true}; FragmentAutocompleteFixtureImpl() : BaseType(true) @@ -263,6 +268,15 @@ struct FragmentAutocompleteBuiltinsFixture : FragmentAutocompleteFixtureImpl() { + } + + Frontend& getFrontend() override + { + if (frontend) + return *frontend; + Frontend& f = BuiltinsFixture::getFrontend(); + Luau::unfreeze(f.globals.globalTypes); + Luau::unfreeze(f.globalsForAutocomplete.globalTypes); const std::string fakeVecDecl = R"( declare class FakeVec function dot(self, x: FakeVec) : FakeVec @@ -281,6 +295,10 @@ end addGlobalBinding(getFrontend().globals, "game", Binding{getBuiltins()->anyType}); addGlobalBinding(getFrontend().globalsForAutocomplete, "game", Binding{getBuiltins()->anyType}); + Luau::freeze(f.globals.globalTypes); + Luau::freeze(f.globalsForAutocomplete.globalTypes); + + return *frontend; } }; @@ -3814,7 +3832,7 @@ end }); } -TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "tagged_union_completion_first_branch_of_union_new_solver" * doctest::skip(true)) +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "tagged_union_completion_first_branch_of_union_new_solver") { // TODO: CLI-155619 - Fragment autocomplete needs to use stale refinement information for modules typechecked in the new solver as well const std::string source = R"( @@ -3847,9 +3865,8 @@ end }); } -TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "tagged_union_completion_second_branch_of_union_new_solver" * doctest::skip(true)) +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "tagged_union_completion_second_branch_of_union_new_solver") { - // TODO: CLI-155619 - Fragment autocomplete needs to use stale refinement information for modules typechecked in the new solver as well const std::string source = R"( type Ok = { type: "ok", value: T} type Err = { type : "err", error : E} @@ -3913,6 +3930,162 @@ require(script.A). CHECK(result.result->acResults.entryMap.count("foo")); } +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "self_types_provide_rich_autocomplete") +{ + ScopedFastFlag sff{FFlag::LuauPushFunctionTypesInFunctionStatement, true}; + + const std::string source = R"( +type Service = { + Start: (self: Service) -> (), + Prop: number +} + +local Service: Service = {} + +function Service:Start() + +end +)"; + const std::string dest = R"( +type Service = { + Start: (self: Service) -> (), + Prop: number +} + +local Service: Service = {} + +function Service:Start() + self. +end +)"; + + autocompleteFragmentInBothSolvers( + source, + dest, + Position{9, 9}, + [](auto& result) + { + CHECK(!result.result->acResults.entryMap.empty()); + CHECK(result.result->acResults.entryMap.count("Prop")); + CHECK(result.result->acResults.entryMap.count("Start")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "self_with_fancy_metatable_setting_new_solver") +{ + ScopedFastFlag sff{FFlag::LuauPushFunctionTypesInFunctionStatement, true}; + const std::string source = R"( + type IAccount = { + __index: IAccount, + new : (string, number) -> Account, + report: (self: Account) -> (), + } + + export type Account = setmetatable<{ + name: string, + balance: number + }, IAccount>; + + local Account = {} :: IAccount + Account.__index = Account + + function Account.new(name, balance): Account + local self = {} + self.name = name + self.balance = balance + return setmetatable(self, Account) + end + + function Account:report() + print("My balance is: " .. ) + end +)"; + + const std::string dest = R"( + type IAccount = { + __index: IAccount, + new : (string, number) -> Account, + report: (self: Account) -> (), + } + + export type Account = setmetatable<{ + name: string, + balance: number + }, IAccount>; + + local Account = {} :: IAccount + Account.__index = Account + + function Account.new(name, balance): Account + local self = {} + self.name = name + self.balance = balance + return setmetatable(self, Account) + end + + function Account:report() + print("My balance is: " .. self. ) + end +)"; + + autocompleteFragmentInNewSolver( + source, + dest, + Position{23, 44}, + [](auto& result) + { + CHECK(!result.result->acResults.entryMap.empty()); + CHECK(result.result->acResults.entryMap.count("new")); + CHECK(result.result->acResults.entryMap.count("report")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "self_with_colon_good_recommendations") +{ + ScopedFastFlag sff{FFlag::LuauPushFunctionTypesInFunctionStatement, true}; + + const std::string source = R"( +type Service = { + Start: (self: Service) -> (), + Prop: number +} + +local Service: Service = {} + +function Service:Start() + +end +)"; + const std::string dest = R"( +type Service = { + Start: (self: Service) -> (), + Prop: number +} + +local Service: Service = {} + +function Service:Start() + self: +end +)"; + + autocompleteFragmentInBothSolvers( + source, + dest, + Position{9, 9}, + [](auto& result) + { + CHECK(!result.result->acResults.entryMap.empty()); + CHECK(result.result->acResults.entryMap.count("Prop")); + CHECK(result.result->acResults.entryMap["Prop"].wrongIndexType); + CHECK(result.result->acResults.entryMap.count("Start")); + CHECK(!result.result->acResults.entryMap["Start"].wrongIndexType); + } + ); +} + // NOLINTEND(bugprone-unchecked-optional-access) TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index ca7041ef..1d9eec9f 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -17,6 +17,8 @@ LUAU_FASTFLAG(LuauNewNonStrictVisitTypes2) LUAU_FASTFLAG(LuauNewNonStrictFixGenericTypePacks) +LUAU_FASTFLAG(LuauNewNonStrictMoreUnknownSymbols) +LUAU_FASTFLAG(LuauNewNonStrictNoErrorsPassingNever) using namespace Luau; @@ -190,12 +192,22 @@ local x abs(lower(x)) )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 4), "abs", result); + if (FFlag::LuauNewNonStrictMoreUnknownSymbols) + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 4), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 10), "lower", result); + } + else + { + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 4), "abs", result); + } } -TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_warns_with_never_local") +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_does_not_warn_with_never_local") { CheckResult result = checkNonStrict(R"( local x : never @@ -205,10 +217,14 @@ else lower(x) end )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 8), "abs", result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 10), "lower", result); + if (FFlag::LuauNewNonStrictNoErrorsPassingNever) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 8), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 10), "lower", result); + } } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_warns_nil_branches") @@ -251,7 +267,13 @@ if cond() then end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauNewNonStrictMoreUnknownSymbols) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 8), "abs", result); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_no_else_err_in_cond") @@ -267,14 +289,31 @@ end } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_expr_should_warn") +{ + CheckResult result = checkNonStrict(R"( +local x = 42 +local y = if cond() then abs(x) else lower(x) +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 43), "lower", result); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_expr_should_not_warn_for_never") { CheckResult result = checkNonStrict(R"( local x : never local y = if cond() then abs(x) else lower(x) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 29), "abs", result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 43), "lower", result); + + if (FFlag::LuauNewNonStrictNoErrorsPassingNever) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 29), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 43), "lower", result); + } } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "if_then_else_expr_doesnt_warn_else_branch") @@ -355,10 +394,20 @@ function f(x) lower(x) end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 10), "lower", result); - NONSTRICT_REQUIRE_FUNC_DEFINITION_ERR(Position(1, 11), "x", result); + + + if (FFlag::LuauNewNonStrictNoErrorsPassingNever) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + NONSTRICT_REQUIRE_FUNC_DEFINITION_ERR(Position(1, 11), "x", result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(3, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 10), "lower", result); + NONSTRICT_REQUIRE_FUNC_DEFINITION_ERR(Position(1, 11), "x", result); + } } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_sequencing_errors_2") @@ -369,10 +418,19 @@ local t = {function(x) lower(x) end} )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 10), "lower", result); - CHECK(toString(result.errors[2]) == "Argument x with type 'unknown' is used in a way that will run time error"); + + if (FFlag::LuauNewNonStrictNoErrorsPassingNever) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == "Argument x with type 'unknown' is used in a way that will run time error"); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(3, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(2, 8), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 10), "lower", result); + CHECK(toString(result.errors[2]) == "Argument x with type 'unknown' is used in a way that will run time error"); + } } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "local_fn_produces_error") @@ -401,7 +459,7 @@ local y = function() lower(x) end TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_if_warns_never") { CheckResult result = checkNonStrict(R"( -function f(x) +function f(x: never) if cond() then abs(x) else @@ -409,9 +467,15 @@ function f(x) end end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 12), "abs", result); - NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 14), "lower", result); + + if (FFlag::LuauNewNonStrictNoErrorsPassingNever) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(3, 12), "abs", result); + NONSTRICT_REQUIRE_CHECKED_ERR(Position(5, 14), "lower", result); + } } TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "function_def_if_no_else") @@ -721,4 +785,39 @@ TEST_CASE_FIXTURE(Fixture, "incomplete_function_annotation") LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unknown_globals_in_function_calls") +{ + ScopedFastFlag sff{FFlag::LuauNewNonStrictMoreUnknownSymbols, true}; + + CheckResult result = check(Mode::Nonstrict, R"( + local function foo() : () + bar() + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + const UnknownSymbol* err = get(result.errors[0]); + CHECK_EQ(err->name, "bar"); + CHECK_EQ(err->context, UnknownSymbol::Context::Binding); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_globals_in_one_sided_conditionals") +{ + ScopedFastFlag sff{FFlag::LuauNewNonStrictMoreUnknownSymbols, true}; + + CheckResult result = check(Mode::Nonstrict, R"( + local function foo(cond) : () + if cond then + bar() + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + const UnknownSymbol* err = get(result.errors[0]); + CHECK_EQ(err->name, "bar"); + CHECK_EQ(err->context, UnknownSymbol::Context::Binding); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 599dfed3..b04c1984 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAG(LuauAvoidGenericsLeakingDuringFunctionCallCheck) LUAU_FASTFLAG(LuauRemoveTypeCallsForReadWriteProps) LUAU_FASTFLAG(LuauSolverAgnosticStringification) LUAU_FASTFLAG(LuauSuppressErrorsForMultipleNonviableOverloads) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -3196,4 +3197,114 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_pack_variadic") )")); } +TEST_CASE_FIXTURE(Fixture, "table_annotated_explicit_self") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + }; + + CheckResult results = check(R"( + type MyObject = { + fn: (self: MyObject) -> number, + field: number + } + + local Foo = {} :: MyObject + + function Foo:fn() + local _ = self + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, results); + LUAU_REQUIRE_ERROR(results, FunctionExitsWithoutReturning); // `Foo:fn` should return a `number` + CHECK_EQ("MyObject", toString(requireTypeAtPosition({9, 24}))); +} + + +TEST_CASE_FIXTURE(Fixture, "oss_1871") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + }; + + + LUAU_REQUIRE_NO_ERRORS(check(R"( + export type Test = { + [string]: (string) -> () + } + + local TestTbl: Test = {} + + function TestTbl.Hello(Param) + local _ = Param + end + )")); + + CHECK_EQ("string", toString(requireTypeAtPosition({8, 25}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "io_manager_oop_ish") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + }; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + type IIOManager = { + __index: IIOManager, + write: (self: IOManager, text: string, label: string?) -> number, + } + + export type IOManager = setmetatable<{ + buffer: {string}, + memory: { [string]: number } + }, IIOManager>; + + local IO = {} :: IIOManager + IO.__index = IO + + function IO:write(text, label) + local _ = self + local _ = text + local _ = label + return 42 + end + + return IO + )")); + CHECK_EQ("IOManager", toString(requireTypeAtPosition({15, 25}))); + CHECK_EQ("string", toString(requireTypeAtPosition({16, 25}))); + CHECK_EQ("string?", toString(requireTypeAtPosition({17, 25}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "generic_function_statement") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + }; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + type Object = { + foobar: (number, string, T) -> T + } + + local Obj = {} :: Object + function Obj.foobar(bing, quxx, dunno) + local _ = bing + local _ = quxx + return dunno + end + )")); + + CHECK_EQ("number", toString(requireTypeAtPosition({7, 24}))); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 24}))); + // NOTE: This specifically _isn't_ `T` as defined by `Object.foobar` + CHECK_EQ("a", toString(requireTypeAtPosition({9, 21}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index c015319a..d923e13e 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -18,6 +18,8 @@ LUAU_FASTFLAG(LuauReportSubtypingErrors) LUAU_FASTFLAG(LuauEagerGeneralization4) LUAU_FASTFLAG(LuauStuckTypeFunctionsStillDispatch) LUAU_FASTFLAG(LuauRemoveTypeCallsForReadWriteProps) +LUAU_FASTFLAG(DebugLuauAssertOnForcedConstraint) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) using namespace Luau; @@ -1234,6 +1236,13 @@ TEST_CASE_FIXTURE(Fixture, "generic_table_method") TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") { + // Prior to `LuauPushFunctionTypesInFunctionStatement`, we _always_ forced + // a constraint when solving this block. + ScopedFastFlag sffs[] = { + {FFlag::DebugLuauAssertOnForcedConstraint, true}, + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + }; + CheckResult result = check(R"( local T = {} diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index c94b8f77..701f1deb 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -13,6 +13,8 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauTableLiteralSubtypeSpecificCheck2) LUAU_FASTFLAG(LuauRefineTablesWithReadType) LUAU_FASTFLAG(LuauReturnMappedGenericPacksFromSubtyping) +LUAU_FASTFLAG(LuauPushFunctionTypesInFunctionStatement) +LUAU_FASTFLAG(LuauTableLiteralSubtypeSpecificCheck2) TEST_SUITE_BEGIN("IntersectionTypes"); @@ -335,6 +337,11 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { + ScopedFastFlag sffs[] = { + {FFlag::LuauPushFunctionTypesInFunctionStatement, true}, + {FFlag::LuauTableLiteralSubtypeSpecificCheck2, true}, + }; + CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } @@ -350,9 +357,14 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(3, result); CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); - CHECK_EQ(toString(result.errors[1]), "Cannot add property 'w' to table 'X & Y'"); + // I'm not writing this as a `toString` check, those are awful. + auto err1 = get(result.errors[1]); + REQUIRE(err1); + CHECK_EQ("number", toString(err1->givenType)); + CHECK_EQ("string", toString(err1->wantedType)); + CHECK_EQ(toString(result.errors[2]), "Cannot add property 'w' to table 'X & Y'"); } else { diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 3a10fa5c..ec67a9d3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -38,6 +38,7 @@ LUAU_FASTFLAG(LuauTableLiteralSubtypeCheckFunctionCalls) LUAU_FASTFLAG(LuauAvoidGenericsLeakingDuringFunctionCallCheck) LUAU_FASTFLAG(LuauRefineTablesWithReadType) LUAU_FASTFLAG(LuauSolverAgnosticStringification) +LUAU_FASTFLAG(LuauDfgForwardNilFromAndOr) TEST_SUITE_BEGIN("TableTests"); @@ -6087,4 +6088,30 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oss_1450") CHECK_EQ(R"("Alt" | "Ctrl" | "Space" | "Tab")", toString(err->givenType)); } +TEST_CASE_FIXTURE(Fixture, "oss_1888_and_or_subscriptable") +{ + ScopedFastFlag _{FFlag::LuauDfgForwardNilFromAndOr, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + export type CachedValue = { + future: any, + timestamp: number, + ttl: number?, + } + + type Cache = { [string]: CachedValue } + type CacheMap = { [string]: Cache } + + local _caches: CacheMap = {} + + local CacheManager = {} + + function CacheManager:has(cacheName: string, id: string): boolean + local cache = _caches[cacheName] + local entry = cache and cache[id] + return entry ~= nil + end + )")); +} + TEST_SUITE_END();