Merge branch 'master' into merge

This commit is contained in:
Vyacheslav Egorov 2022-12-02 12:56:19 +02:00
commit 471ec75a60
4 changed files with 233 additions and 55 deletions

View file

@ -95,6 +95,7 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0); void prototype(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0);
void prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlock(const ScopePtr& scope, const AstStatBlock& statement);
void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockWithoutRecursionCheck(const ScopePtr& scope, const AstStatBlock& statement);
@ -399,6 +400,11 @@ private:
*/ */
DenseHashSet<std::pair<bool, Name>, HashBoolNamePair> duplicateTypeAliases; DenseHashSet<std::pair<bool, Name>, HashBoolNamePair> duplicateTypeAliases;
/**
* A set of incorrect class definitions which is used to avoid a second-pass analysis.
*/
DenseHashSet<const AstStatDeclareClass*> incorrectClassDefinitions{nullptr};
std::vector<std::pair<TypeId, ScopePtr>> deferredQuantification; std::vector<std::pair<TypeId, ScopePtr>> deferredQuantification;
}; };

View file

@ -49,6 +49,8 @@ LUAU_FASTFLAGVARIABLE(LuauReportShadowedTypeAlias, false)
LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false) LUAU_FASTFLAGVARIABLE(LuauBetterMessagingOnCountMismatch, false)
LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false) LUAU_FASTFLAGVARIABLE(LuauArgMismatchReportFunctionLocation, false)
LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false) LUAU_FASTFLAGVARIABLE(LuauImplicitElseRefinement, false)
LUAU_FASTFLAGVARIABLE(LuauDeclareClassPrototype, false)
LUAU_FASTFLAGVARIABLE(LuauCallableClasses, false)
namespace Luau namespace Luau
{ {
@ -356,6 +358,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo
unifierState.skipCacheForType.clear(); unifierState.skipCacheForType.clear();
duplicateTypeAliases.clear(); duplicateTypeAliases.clear();
incorrectClassDefinitions.clear();
return std::move(currentModule); return std::move(currentModule);
} }
@ -523,6 +526,10 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A
prototype(scope, *typealias, subLevel); prototype(scope, *typealias, subLevel);
++subLevel; ++subLevel;
} }
else if (const auto& declaredClass = stat->as<AstStatDeclareClass>(); FFlag::LuauDeclareClassPrototype && declaredClass)
{
prototype(scope, *declaredClass);
}
} }
auto protoIter = sorted.begin(); auto protoIter = sorted.begin();
@ -1670,8 +1677,10 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea
} }
} }
void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) void TypeChecker::prototype(const ScopePtr& scope, const AstStatDeclareClass& declaredClass)
{ {
LUAU_ASSERT(FFlag::LuauDeclareClassPrototype);
std::optional<TypeId> superTy = std::nullopt; std::optional<TypeId> superTy = std::nullopt;
if (declaredClass.superName) if (declaredClass.superName)
{ {
@ -1681,6 +1690,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
if (!lookupType) if (!lookupType)
{ {
reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type});
incorrectClassDefinitions.insert(&declaredClass);
return; return;
} }
@ -1692,7 +1702,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
{ {
reportError(declaredClass.location, reportError(declaredClass.location,
GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)});
incorrectClassDefinitions.insert(&declaredClass);
return; return;
} }
} }
@ -1701,61 +1711,174 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName)); TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName));
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy); ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level});
TableTypeVar* metatable = getMutable<TableTypeVar>(metaTy);
ctv->metatable = metaTy; ctv->metatable = metaTy;
scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; scope->exportedTypeBindings[className] = TypeFun{{}, classTy};
}
for (const AstDeclaredClassProp& prop : declaredClass.props) void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass)
{
if (FFlag::LuauDeclareClassPrototype)
{ {
Name propName(prop.name.value); Name className(declaredClass.name.value);
TypeId propTy = resolveType(scope, *prop.ty);
bool assignToMetatable = isMetamethod(propName); // Don't bother checking if the class definition was incorrect
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props; if (incorrectClassDefinitions.find(&declaredClass))
return;
// Function types always take 'self', but this isn't reflected in the std::optional<TypeFun> binding;
// parsed annotation. Add it here. if (auto it = scope->exportedTypeBindings.find(className); it != scope->exportedTypeBindings.end())
if (prop.isMethod) binding = it->second;
// This class definition must have been `prototype()`d first.
if (!binding)
ice("Class not predeclared");
TypeId classTy = binding->type;
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
if (!ctv->metatable)
ice("No metatable for declared class");
TableTypeVar* metatable = getMutable<TableTypeVar>(*ctv->metatable);
for (const AstDeclaredClassProp& prop : declaredClass.props)
{ {
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(propTy)) Name propName(prop.name.value);
TypeId propTy = resolveType(scope, *prop.ty);
bool assignToMetatable = isMetamethod(propName);
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props;
// Function types always take 'self', but this isn't reflected in the
// parsed annotation. Add it here.
if (prop.isMethod)
{ {
ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(propTy))
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); {
ftv->hasSelf = true; ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
ftv->hasSelf = true;
}
} }
}
if (assignTo.count(propName) == 0) if (assignTo.count(propName) == 0)
{
assignTo[propName] = {propTy};
}
else
{
TypeId currentTy = assignTo[propName].type;
// We special-case this logic to keep the intersection flat; otherwise we
// would create a ton of nested intersection types.
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(currentTy))
{ {
std::vector<TypeId> options = itv->parts; assignTo[propName] = {propTy};
options.push_back(propTy);
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
assignTo[propName] = {newItv};
}
else if (get<FunctionTypeVar>(currentTy))
{
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
assignTo[propName] = {intersection};
} }
else else
{ {
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); TypeId currentTy = assignTo[propName].type;
// We special-case this logic to keep the intersection flat; otherwise we
// would create a ton of nested intersection types.
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(currentTy))
{
std::vector<TypeId> options = itv->parts;
options.push_back(propTy);
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
assignTo[propName] = {newItv};
}
else if (get<FunctionTypeVar>(currentTy))
{
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
assignTo[propName] = {intersection};
}
else
{
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
}
}
}
}
else
{
std::optional<TypeId> superTy = std::nullopt;
if (declaredClass.superName)
{
Name superName = Name(declaredClass.superName->value);
std::optional<TypeFun> lookupType = scope->lookupType(superName);
if (!lookupType)
{
reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type});
return;
}
// We don't have generic classes, so this assertion _should_ never be hit.
LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0);
superTy = lookupType->type;
if (!get<ClassTypeVar>(follow(*superTy)))
{
reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'",
superName.c_str(), declaredClass.name.value)});
return;
}
}
Name className(declaredClass.name.value);
TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {}, currentModuleName));
ClassTypeVar* ctv = getMutable<ClassTypeVar>(classTy);
TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level});
TableTypeVar* metatable = getMutable<TableTypeVar>(metaTy);
ctv->metatable = metaTy;
scope->exportedTypeBindings[className] = TypeFun{{}, classTy};
for (const AstDeclaredClassProp& prop : declaredClass.props)
{
Name propName(prop.name.value);
TypeId propTy = resolveType(scope, *prop.ty);
bool assignToMetatable = isMetamethod(propName);
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props;
// Function types always take 'self', but this isn't reflected in the
// parsed annotation. Add it here.
if (prop.isMethod)
{
if (FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(propTy))
{
ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}});
ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes});
ftv->hasSelf = true;
}
}
if (assignTo.count(propName) == 0)
{
assignTo[propName] = {propTy};
}
else
{
TypeId currentTy = assignTo[propName].type;
// We special-case this logic to keep the intersection flat; otherwise we
// would create a ton of nested intersection types.
if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(currentTy))
{
std::vector<TypeId> options = itv->parts;
options.push_back(propTy);
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
assignTo[propName] = {newItv};
}
else if (get<FunctionTypeVar>(currentTy))
{
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
assignTo[propName] = {intersection};
}
else
{
reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())});
}
} }
} }
} }
@ -4135,26 +4258,33 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
std::vector<Location> metaArgLocations; std::vector<Location> metaArgLocations;
// Might be a callable table // Might be a callable table or class
std::optional<TypeId> callTy = std::nullopt;
if (const MetatableTypeVar* mttv = get<MetatableTypeVar>(fn)) if (const MetatableTypeVar* mttv = get<MetatableTypeVar>(fn))
{ {
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false)) callTy = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, /* addErrors= */ false);
{ }
// Construct arguments with 'self' added in front else if (const ClassTypeVar* ctv = get<ClassTypeVar>(fn); FFlag::LuauCallableClasses && ctv && ctv->metatable)
TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); {
callTy = getIndexTypeFromType(scope, *ctv->metatable, "__call", expr.func->location, /* addErrors= */ false);
}
TypePack* metaCallArgs = getMutable<TypePack>(metaCallArgPack); if (callTy)
metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); {
// Construct arguments with 'self' added in front
TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail}));
metaArgLocations = *argLocations; TypePack* metaCallArgs = getMutable<TypePack>(metaCallArgPack);
metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); metaCallArgs->head.insert(metaCallArgs->head.begin(), fn);
fn = instantiate(scope, *ty, expr.func->location); metaArgLocations = *argLocations;
metaArgLocations.insert(metaArgLocations.begin(), expr.func->location);
argPack = metaCallArgPack; fn = instantiate(scope, *callTy, expr.func->location);
args = metaCallArgs;
argLocations = &metaArgLocations; argPack = metaCallArgPack;
} args = metaCallArgs;
argLocations = &metaArgLocations;
} }
const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn); const FunctionTypeVar* ftv = get<FunctionTypeVar>(fn);

View file

@ -91,6 +91,13 @@ struct ClassFixture : BuiltinsFixture
typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType};
addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); addGlobalBinding(frontend, "Vector2", vector2Type, "@test");
TypeId callableClassMetaType = arena.addType(TableTypeVar{});
TypeId callableClassType = arena.addType(ClassTypeVar{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"});
getMutable<TableTypeVar>(callableClassMetaType)->props = {
{"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}},
};
typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType};
for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings)
persist(tf.type); persist(tf.type);
@ -514,4 +521,17 @@ TEST_CASE_FIXTURE(ClassFixture, "unions_of_intersections_of_classes")
LUAU_REQUIRE_NO_ERRORS(result); LUAU_REQUIRE_NO_ERRORS(result);
} }
TEST_CASE_FIXTURE(ClassFixture, "callable_classes")
{
ScopedFastFlag luauCallableClasses{"LuauCallableClasses", true};
CheckResult result = check(R"(
local x : CallableClass
local y = x("testing")
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ("number", toString(requireType("y")));
}
TEST_SUITE_END(); TEST_SUITE_END();

View file

@ -396,4 +396,26 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props")
CHECK_EQ(toString(requireType("y")), "string"); CHECK_EQ(toString(requireType("y")), "string");
} }
TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes")
{
ScopedFastFlag LuauDeclareClassPrototype("LuauDeclareClassPrototype", true);
unfreeze(typeChecker.globalTypes);
LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"(
declare class Channel
Messages: { Message }
OnMessage: (message: Message) -> ()
end
declare class Message
Text: string
Channel: Channel
end
)",
"@test");
freeze(typeChecker.globalTypes);
REQUIRE(result.success);
}
TEST_SUITE_END(); TEST_SUITE_END();