diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 8153c3d5..d05623a8 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -34,6 +34,8 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) +LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) +LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) namespace Luau { @@ -653,6 +655,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* block) { std::unordered_map aliasDefinitionLocations; + std::unordered_map classDefinitionLocations; // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the @@ -750,6 +753,32 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc scope->privateTypeBindings[function->name.value] = std::move(typeFunction); aliasDefinitionLocations[function->name.value] = function->location; } + else if (auto classDeclaration = stat->as()) + { + if (!FFlag::LuauNewSolverPrePopulateClasses) + continue; + + if (scope->exportedTypeBindings.count(classDeclaration->name.value)) + { + auto it = classDefinitionLocations.find(classDeclaration->name.value); + LUAU_ASSERT(it != classDefinitionLocations.end()); + reportError(classDeclaration->location, DuplicateTypeDefinition{classDeclaration->name.value, it->second}); + continue; + } + + // A class might have no name if the code is syntactically + // illegal. We mustn't prepopulate anything in this case. + if (classDeclaration->name == kParseNameError) + continue; + + ScopePtr defnScope = childScope(classDeclaration, scope); + + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; + scope->exportedTypeBindings[classDeclaration->name.value] = std::move(initialFun); + + classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location; + } } } @@ -1645,6 +1674,11 @@ static bool isMetamethod(const Name& name) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { + // If a class with the same name was already defined, we skip over + auto bindingIt = scope->exportedTypeBindings.find(declaredClass->name.value); + if (FFlag::LuauNewSolverPrePopulateClasses && bindingIt == scope->exportedTypeBindings.end()) + return ControlFlow::None; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass->superName) { @@ -1659,7 +1693,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas // We don't have generic classes, so this assertion _should_ never be hit. LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); - superTy = lookupType->type; + if (FFlag::LuauNewSolverPrePopulateClasses) + superTy = follow(lookupType->type); + else + superTy = lookupType->type; if (!get(follow(*superTy))) { @@ -1682,7 +1719,14 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + if (FFlag::LuauNewSolverPrePopulateClasses) + { + TypeId classBindTy = bindingIt->second.type; + emplaceType(asMutable(classBindTy), classTy); + } + else + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; if (declaredClass->indexer) { @@ -2844,6 +2888,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, ttv->state = TableState::Unsealed; ttv->definitionModuleName = module->name; + if (FFlag::LuauNewSolverPopulateTableLocations) + { + ttv->definitionLocation = expr->location; + } ttv->scope = scope.get(); interiorTypes.back().push_back(ty); @@ -3301,7 +3349,16 @@ TypeId ConstraintGenerator::resolveTableType(const ScopePtr& scope, AstType* ty, ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); } - return arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + TypeId tableTy = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + TableType* ttv = getMutable(tableTy); + + if (FFlag::LuauNewSolverPopulateTableLocations) + { + ttv->definitionModuleName = module->name; + ttv->definitionLocation = tab->location; + } + + return tableTy; } TypeId ConstraintGenerator::resolveFunctionType( diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 34e08fe3..398f0aa5 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) +LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) namespace Luau { @@ -1109,9 +1110,15 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul target = follow(instantiated); } + if (FFlag::LuauNewSolverPopulateTableLocations) + { + // This is a new type - redefine the location. + ttv->definitionLocation = constraint->location; + ttv->definitionModuleName = currentModuleName; + } + ttv->instantiatedTypeParams = typeArguments; ttv->instantiatedTypePackParams = packArguments; - // TODO: Fill in definitionModuleName. } bindResult(target); @@ -1433,7 +1440,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullis() || expr->is() || expr->is() || expr->is()) + else if (expr->is() || expr->is() || expr->is() || + expr->is()) { Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; u2.unify(actualArgTy, expectedArgTy); @@ -2326,12 +2334,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return true; } -bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, - TypeId tableTy, - const IterableConstraint& c, - NotNull constraint -) +bool ConstraintSolver::tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint) { const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 5a530e83..2ab90ab5 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauNewSolverPrePopulateClasses) + TEST_SUITE_BEGIN("DefinitionTests"); TEST_CASE_FIXTURE(Fixture, "definition_file_simple") @@ -492,11 +494,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile( - frontend.globals, - frontend.globals.globalScope, - R"( + ScopedFastFlag _{FFlag::LuauNewSolverPrePopulateClasses, true}; + loadDefinition(R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () @@ -506,13 +505,19 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") Text: string Channel: Channel end - )", - "@test", - /* captureComments */ false - ); - freeze(frontend.globals.globalTypes); + )"); - REQUIRE(result.success); + CheckResult result = check(R"( + local a: Channel + local b = a.Messages[1] + local c = b.Channel + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Channel"); + CHECK_EQ(toString(requireType("b")), "Message"); + CHECK_EQ(toString(requireType("c")), "Channel"); } TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 4f797690..5d5df24a 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauRequireCyclesDontAlwaysReturnAny) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauTypestateBuiltins2) +LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) using namespace Luau; @@ -466,7 +467,15 @@ local b: B.T = a LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + { + if (FFlag::LuauNewSolverPopulateTableLocations) + CHECK( + toString(result.errors.at(0)) == + "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'; at [read \"x\"], number is not exactly string" + ); + else + CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + } else { const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' @@ -507,7 +516,15 @@ local b: B.T = a LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + { + if (FFlag::LuauNewSolverPopulateTableLocations) + CHECK( + toString(result.errors.at(0)) == + "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'; at [read \"x\"], number is not exactly string" + ); + else + CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + } else { const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'