mirror of
https://github.com/luau-lang/luau.git
synced 2025-04-04 10:50:54 +01:00
Handle cyclically referenced declared classes (#729)
Closes #631 Performs a "two pass" approach, where we first call `prototype` on the class definition. The prototype phase checks for errors and creates a base CTV for the class, which other types can then reference. The second `check` phase fills it out with all the defined properties/methods Co-authored-by: vegorov-rbx <75688451+vegorov-rbx@users.noreply.github.com>
This commit is contained in:
parent
95d9c6d194
commit
7dbe47f4dd
3 changed files with 191 additions and 41 deletions
|
@ -95,6 +95,7 @@ struct TypeChecker
|
||||||
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
|
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
|
||||||
|
|
||||||
void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0);
|
void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0);
|
||||||
|
void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
|
||||||
|
|
||||||
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
|
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
|
||||||
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
|
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
|
||||||
|
@ -399,6 +400,11 @@ private:
|
||||||
*/
|
*/
|
||||||
DenseHashSet<std::pair<bool, Name>, HashBoolNamePair> duplicateTypeAliases;
|
DenseHashSet<std::pair<bool, Name>, HashBoolNamePair> duplicateTypeAliases;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A set of incorrect class definitions which is used to avoid a second-pass analysis.
|
||||||
|
*/
|
||||||
|
DenseHashSet<const AstStatDeclareClass*> incorrectClassDefinitions{nullptr};
|
||||||
|
|
||||||
std::vector<std::pair<TypeId, ScopePtr>> deferredQuantification;
|
std::vector<std::pair<TypeId, ScopePtr>> deferredQuantification;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false)
|
||||||
LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false)
|
LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false)
|
||||||
LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false)
|
LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false)
|
||||||
LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false)
|
LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false)
|
||||||
|
LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false)
|
||||||
|
|
||||||
namespace Luau
|
namespace Luau
|
||||||
{
|
{
|
||||||
|
@ -356,6 +357,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
|
||||||
unifierState.skipCacheForType.clear();
|
unifierState.skipCacheForType.clear();
|
||||||
|
|
||||||
duplicateTypeAliases.clear();
|
duplicateTypeAliases.clear();
|
||||||
|
incorrectClassDefinitions.clear();
|
||||||
|
|
||||||
return std::move(currentModule);
|
return std::move(currentModule);
|
||||||
}
|
}
|
||||||
|
@ -523,6 +525,10 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A
|
||||||
prototype(scope, *typealias, subLevel);
|
prototype(scope, *typealias, subLevel);
|
||||||
++subLevel;
|
++subLevel;
|
||||||
}
|
}
|
||||||
|
else if (const auto& declaredClass = stat->as<AstStatDeclareClass>(); FFlag::LuauDeclareClassPrototype && declaredClass)
|
||||||
|
{
|
||||||
|
prototype(scope, *declaredClass);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto protoIter = sorted.begin();
|
auto protoIter = sorted.begin();
|
||||||
|
@ -1670,8 +1676,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass)
|
void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass)
|
||||||
{
|
{
|
||||||
|
LUAU_ASSERT(FFlag::LuauDeclareClassPrototype);
|
||||||
|
|
||||||
std::optional<TypeId> superTy = std::nullopt;
|
std::optional<TypeId> superTy = std::nullopt;
|
||||||
if (declaredClass.superName)
|
if (declaredClass.superName)
|
||||||
{
|
{
|
||||||
|
@ -1681,6 +1689,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
|
||||||
if (!lookupType)
|
if (!lookupType)
|
||||||
{
|
{
|
||||||
reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type});
|
reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type});
|
||||||
|
incorrectClassDefinitions.insert(&declaredClass);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1692,7 +1701,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
|
||||||
{
|
{
|
||||||
reportError(declaredClass.location,
|
reportError(declaredClass.location,
|
||||||
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)});
|
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)});
|
||||||
|
incorrectClassDefinitions.insert(&declaredClass);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1701,61 +1710,174 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
|
||||||
|
|
||||||
TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName));
|
TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName));
|
||||||
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
|
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
|
||||||
|
|
||||||
TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level});
|
TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level});
|
||||||
TableTypeVar* metatable = getMutable<TableTypeVar>(metaTy);
|
|
||||||
|
|
||||||
ctv->metatable = metaTy;
|
ctv->metatable = metaTy;
|
||||||
|
|
||||||
scope->exportedTypeBindings[className] = TypeFun{{}, classTy};
|
scope->exportedTypeBindings[className] = TypeFun{{}, classTy};
|
||||||
|
}
|
||||||
|
|
||||||
for (const AstDeclaredClassProp& prop : declaredClass.props)
|
void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass)
|
||||||
|
{
|
||||||
|
if (FFlag::LuauDeclareClassPrototype)
|
||||||
{
|
{
|
||||||
Name propName(prop.name.value);
|
Name className(declaredClass.name.value);
|
||||||
TypeId propTy = resolveType(scope, *prop.ty);
|
|
||||||
|
|
||||||
bool assignToMetatable = isMetamethod(propName);
|
// Don't bother checking if the class definition was incorrect
|
||||||
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props;
|
if (incorrectClassDefinitions.find(&declaredClass))
|
||||||
|
return;
|
||||||
|
|
||||||
// Function types always take 'self', but this isn't reflected in the
|
std::optional<TypeFun> binding;
|
||||||
// parsed annotation. Add it here.
|
if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end())
|
||||||
if (prop.isMethod)
|
binding = it->second;
|
||||||
|
|
||||||
|
// This class definition must have been `prototype()`d first.
|
||||||
|
if (!binding)
|
||||||
|
ice("Class not predeclared");
|
||||||
|
|
||||||
|
TypeId classTy = binding->type;
|
||||||
|
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
|
||||||
|
|
||||||
|
if (!ctv->metatable)
|
||||||
|
ice("No metatable for declared class");
|
||||||
|
|
||||||
|
TableTypeVar* metatable = getMutable<TableTypeVar>(*ctv->metatable);
|
||||||
|
for (const AstDeclaredClassProp& prop : declaredClass.props)
|
||||||
{
|
{
|
||||||
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(propTy))
|
Name propName(prop.name.value);
|
||||||
|
TypeId propTy = resolveType(scope, *prop.ty);
|
||||||
|
|
||||||
|
bool assignToMetatable = isMetamethod(propName);
|
||||||
|
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props;
|
||||||
|
|
||||||
|
// Function types always take 'self', but this isn't reflected in the
|
||||||
|
// parsed annotation. Add it here.
|
||||||
|
if (prop.isMethod)
|
||||||
{
|
{
|
||||||
ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
|
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(propTy))
|
||||||
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
|
{
|
||||||
ftv->hasSelf = true;
|
ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
|
||||||
|
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
|
||||||
|
ftv->hasSelf = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (assignTo.count(propName) == 0)
|
if (assignTo.count(propName) == 0)
|
||||||
{
|
|
||||||
assignTo[propName] = {propTy};
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
TypeId currentTy = assignTo[propName].type;
|
|
||||||
|
|
||||||
// We special-case this logic to keep the intersection flat; otherwise we
|
|
||||||
// would create a ton of nested intersection types.
|
|
||||||
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(currentTy))
|
|
||||||
{
|
{
|
||||||
std::vector<TypeId> options = itv->parts;
|
assignTo[propName] = {propTy};
|
||||||
options.push_back(propTy);
|
|
||||||
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
|
|
||||||
|
|
||||||
assignTo[propName] = {newItv};
|
|
||||||
}
|
|
||||||
else if (get<FunctionTypeVar>(currentTy))
|
|
||||||
{
|
|
||||||
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
|
|
||||||
|
|
||||||
assignTo[propName] = {intersection};
|
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
|
TypeId currentTy = assignTo[propName].type;
|
||||||
|
|
||||||
|
// We special-case this logic to keep the intersection flat; otherwise we
|
||||||
|
// would create a ton of nested intersection types.
|
||||||
|
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(currentTy))
|
||||||
|
{
|
||||||
|
std::vector<TypeId> options = itv->parts;
|
||||||
|
options.push_back(propTy);
|
||||||
|
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
|
||||||
|
|
||||||
|
assignTo[propName] = {newItv};
|
||||||
|
}
|
||||||
|
else if (get<FunctionTypeVar>(currentTy))
|
||||||
|
{
|
||||||
|
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
|
||||||
|
|
||||||
|
assignTo[propName] = {intersection};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
std::optional<TypeId> superTy = std::nullopt;
|
||||||
|
if (declaredClass.superName)
|
||||||
|
{
|
||||||
|
Name superName = Name(declaredClass.superName->value);
|
||||||
|
std::optional<TypeFun> lookupType = scope->lookupType(superName);
|
||||||
|
|
||||||
|
if (!lookupType)
|
||||||
|
{
|
||||||
|
reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// We don't have generic classes, so this assertion _should_ never be hit.
|
||||||
|
LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0);
|
||||||
|
superTy = lookupType->type;
|
||||||
|
|
||||||
|
if (!get<ClassTypeVar>(follow(*superTy)))
|
||||||
|
{
|
||||||
|
reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'",
|
||||||
|
superName.c_str(), declaredClass.name.value)});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Name className(declaredClass.name.value);
|
||||||
|
|
||||||
|
TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName));
|
||||||
|
|
||||||
|
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
|
||||||
|
TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level});
|
||||||
|
TableTypeVar* metatable = getMutable<TableTypeVar>(metaTy);
|
||||||
|
|
||||||
|
ctv->metatable = metaTy;
|
||||||
|
|
||||||
|
scope->exportedTypeBindings[className] = TypeFun{{}, classTy};
|
||||||
|
|
||||||
|
for (const AstDeclaredClassProp& prop : declaredClass.props)
|
||||||
|
{
|
||||||
|
Name propName(prop.name.value);
|
||||||
|
TypeId propTy = resolveType(scope, *prop.ty);
|
||||||
|
|
||||||
|
bool assignToMetatable = isMetamethod(propName);
|
||||||
|
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props;
|
||||||
|
|
||||||
|
// Function types always take 'self', but this isn't reflected in the
|
||||||
|
// parsed annotation. Add it here.
|
||||||
|
if (prop.isMethod)
|
||||||
|
{
|
||||||
|
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(propTy))
|
||||||
|
{
|
||||||
|
ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
|
||||||
|
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
|
||||||
|
ftv->hasSelf = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (assignTo.count(propName) == 0)
|
||||||
|
{
|
||||||
|
assignTo[propName] = {propTy};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
TypeId currentTy = assignTo[propName].type;
|
||||||
|
|
||||||
|
// We special-case this logic to keep the intersection flat; otherwise we
|
||||||
|
// would create a ton of nested intersection types.
|
||||||
|
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(currentTy))
|
||||||
|
{
|
||||||
|
std::vector<TypeId> options = itv->parts;
|
||||||
|
options.push_back(propTy);
|
||||||
|
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
|
||||||
|
|
||||||
|
assignTo[propName] = {newItv};
|
||||||
|
}
|
||||||
|
else if (get<FunctionTypeVar>(currentTy))
|
||||||
|
{
|
||||||
|
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
|
||||||
|
|
||||||
|
assignTo[propName] = {intersection};
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -396,4 +396,26 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props")
|
||||||
CHECK_EQ(toString(requireType("y")), "string");
|
CHECK_EQ(toString(requireType("y")), "string");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes")
|
||||||
|
{
|
||||||
|
ScopedFastFlag LuauDeclareClassPrototype("LuauDeclareClassPrototype", true);
|
||||||
|
|
||||||
|
unfreeze(typeChecker.globalTypes);
|
||||||
|
LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"(
|
||||||
|
declare class Channel
|
||||||
|
Messages: { Message }
|
||||||
|
OnMessage: (message: Message) -> ()
|
||||||
|
end
|
||||||
|
|
||||||
|
declare class Message
|
||||||
|
Text: string
|
||||||
|
Channel: Channel
|
||||||
|
end
|
||||||
|
)",
|
||||||
|
"@test");
|
||||||
|
freeze(typeChecker.globalTypes);
|
||||||
|
|
||||||
|
REQUIRE(result.success);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_SUITE_END();
|
TEST_SUITE_END();
|
||||||
|
|
Loading…
Add table
Reference in a new issue