Fix overloaded metamethod in class definitions (#653)

Fixes #652
This commit is contained in:
JohnnyMorganz 2022-08-29 16:28:04 +01:00 committed by GitHub
parent f2b334a4bb
commit e9e2cba77a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 14 deletions

View file

@ -1659,6 +1659,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
TypeId propTy = resolveType(scope, *prop.ty);
bool assignToMetatable = isMetamethod(propName);
Luau::ClassTypeVar::Props& assignTo = assignToMetatable ? metatable->props : ctv->props;
// Function types always take 'self', but this isn't reflected in the
// parsed annotation. Add it here.
@ -1674,16 +1675,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
}
}
if (ctv->props.count(propName) == 0)
if (assignTo.count(propName) == 0)
{
if (assignToMetatable)
metatable->props[propName] = {propTy};
else
ctv->props[propName] = {propTy};
assignTo[propName] = {propTy};
}
else
{
TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type;
TypeId currentTy = assignTo[propName].type;
// We special-case this logic to keep the intersection flat; otherwise we
// would create a ton of nested intersection types.
@ -1693,19 +1691,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar
options.push_back(propTy);
TypeId newItv = addType(IntersectionTypeVar{std::move(options)});
if (assignToMetatable)
metatable->props[propName] = {newItv};
else
ctv->props[propName] = {newItv};
assignTo[propName] = {newItv};
}
else if (get<FunctionTypeVar>(currentTy))
{
TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}});
if (assignToMetatable)
metatable->props[propName] = {intersection};
else
ctv->props[propName] = {intersection};
assignTo[propName] = {intersection};
}
else
{

View file

@ -336,4 +336,30 @@ local s : Cls = GetCls()
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(Fixture, "class_definition_overload_metamethods")
{
loadDefinition(R"(
declare class Vector3
end
declare class CFrame
function __mul(self, other: CFrame): CFrame
function __mul(self, other: Vector3): Vector3
end
declare function newVector3(): Vector3
declare function newCFrame(): CFrame
)");
CheckResult result = check(R"(
local base = newCFrame()
local shouldBeCFrame = base * newCFrame()
local shouldBeVector = base * newVector3()
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(toString(requireType("shouldBeCFrame")), "CFrame");
CHECK_EQ(toString(requireType("shouldBeVector")), "Vector3");
}
TEST_SUITE_END();