diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index c5d7501d..4f61eaa5 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -95,6 +95,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); + void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); @@ -399,6 +400,11 @@ private: */ DenseHashSet, HashBoolNamePair> duplicateTypeAliases; + /** + * A set of incorrect class definitions which is used to avoid a second-pass analysis. + */ + DenseHashSet incorrectClassDefinitions{nullptr}; + std::vector> deferredQuantification; }; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index ccb1490a..4d34d56e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -355,6 +355,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.skipCacheForType.clear(); duplicateTypeAliases.clear(); + incorrectClassDefinitions.clear(); return std::move(currentModule); } @@ -522,6 +523,10 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A prototype(scope, *typealias, subLevel); ++subLevel; } + else if (const auto& declaredClass = stat->as()) + { + prototype(scope, *declaredClass); + } } auto protoIter = sorted.begin(); @@ -1626,8 +1631,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } } -void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) { + Name className(declaredClass.name.value); + std::optional superTy = std::nullopt; if (declaredClass.superName) { @@ -1637,6 +1644,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar if (!lookupType) { reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + incorrectClassDefinitions.insert(&declaredClass); return; } @@ -1648,22 +1656,42 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); - + incorrectClassDefinitions.insert(&declaredClass); return; } } - Name className(declaredClass.name.value); - TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); ClassTypeVar* ctv = getMutable(classTy); - TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); - TableTypeVar* metatable = getMutable(metaTy); ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +{ + Name className(declaredClass.name.value); + + // Don't bother checking if the class definition was incorrect + if (incorrectClassDefinitions.find(&declaredClass)) + return; + + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end()) + binding = it->second; + + // This class definition must have been `prototype()`d first. + if (!binding) + ice("Not predeclared"); + + TypeId classTy = binding->type; + ClassTypeVar* ctv = getMutable(classTy); + + if (!ctv->metatable) + ice("No metatable for declared class"); + + TableTypeVar* metatable = getMutable(*ctv->metatable); for (const AstDeclaredClassProp& prop : declaredClass.props) {