diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 4eaa5969..2eea191a 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -95,6 +95,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); + void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); @@ -399,6 +400,11 @@ private: */ DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + /** + * A set of incorrect class definitions which is used to avoid a second-pass analysis. + */ + DenseHashSet incorrectClassDefinitions{nullptr}; + std::vector> deferredQuantification; }; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 30eb7ea2..dfd08e54 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -49,6 +49,7 @@ LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) +LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false) LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false) namespace Luau @@ -357,6 +358,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.skipCacheForType.clear(); duplicateTypeAliases.clear(); + incorrectClassDefinitions.clear(); return std::move(currentModule); } @@ -524,6 +526,10 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A prototype(scope, *typealias, subLevel); ++subLevel; } + else if (const auto& declaredClass = stat->as(); FFlag::LuauDeclareClassPrototype && declaredClass) + { + prototype(scope, *declaredClass); + } } auto protoIter = sorted.begin(); @@ -1671,8 +1677,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { + LUAU_ASSERT(FFlag::LuauDeclareClassPrototype); + std::optional superTy = std::nullopt; if (declaredClass.superName) { @@ -1682,6 +1690,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar if (!lookupType) { reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + incorrectClassDefinitions.insert(&declaredClass); return; } @@ -1693,7 +1702,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); - + incorrectClassDefinitions.insert(&declaredClass); return; } } @@ -1702,61 +1711,174 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); - TableTypeVar* metatable = getMutable(metaTy); ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; +} - for (const AstDeclaredClassProp& prop : declaredClass.props) +void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +{ + if (FFlag::LuauDeclareClassPrototype) { - Name propName(prop.name.value); - TypeId propTy = resolveType(scope, *prop.ty); + Name className(declaredClass.name.value); - bool assignToMetatable = isMetamethod(propName); - Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + // Don't bother checking if the class definition was incorrect + if (incorrectClassDefinitions.find(&declaredClass)) + return; - // Function types always take 'self', but this isn't reflected in the - // parsed annotation. Add it here. - if (prop.isMethod) + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) + binding = it->second; + + // This class definition must have been `prototype()`d first. + if (!binding) + ice("Class not predeclared"); + + TypeId classTy = binding->type; + ClassTypeVar* ctv = getMutable(classTy); + + if (!ctv->metatable) + ice("No metatable for declared class"); + + TableTypeVar* metatable = getMutable(*ctv->metatable); + for (const AstDeclaredClassProp& prop : declaredClass.props) { - if (FunctionTypeVar* ftv = getMutable(propTy)) + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) { - ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); - ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); - ftv->hasSelf = true; + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; + } } - } - if (assignTo.count(propName) == 0) - { - assignTo[propName] = {propTy}; - } - else - { - TypeId currentTy = assignTo[propName].type; - - // We special-case this logic to keep the intersection flat; otherwise we - // would create a ton of nested intersection types. - if (const IntersectionTypeVar* itv = get(currentTy)) + if (assignTo.count(propName) == 0) { - std::vector options = itv->parts; - options.push_back(propTy); - TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); - - assignTo[propName] = {newItv}; - } - else if (get(currentTy)) - { - TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); - - assignTo[propName] = {intersection}; + assignTo[propName] = {propTy}; } else { - reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + TypeId currentTy = assignTo[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + assignTo[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + assignTo[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + } + } + else + { + std::optional superTy = std::nullopt; + if (declaredClass.superName) + { + Name superName = Name(declaredClass.superName->value); + std::optional lookupType = scope->lookupType(superName); + + if (!lookupType) + { + reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + return; + } + + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); + superTy = lookupType->type; + + if (!get(follow(*superTy))) + { + reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", + superName.c_str(), declaredClass.name.value)}); + return; + } + } + + Name className(declaredClass.name.value); + + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); + + ClassTypeVar* ctv = getMutable(classTy); + TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); + TableTypeVar* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + for (const AstDeclaredClassProp& prop : declaredClass.props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + ftv->hasSelf = true; + } + } + + if (assignTo.count(propName) == 0) + { + assignTo[propName] = {propTy}; + } + else + { + TypeId currentTy = assignTo[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + assignTo[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + assignTo[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } } } } diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 684b47e9..3556f0f1 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -396,4 +396,26 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") CHECK_EQ(toString(requireType("y")), "string"); } +TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") +{ + ScopedFastFlag LuauDeclareClassPrototype("LuauDeclareClassPrototype", true); + + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + declare class Channel + Messages: { Message } + OnMessage: (message: Message) -> () + end + + declare class Message + Text: string + Channel: Channel + end + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(result.success); +} + TEST_SUITE_END();