Use the correct overload symbol

This commit is contained in:
JohnnyMorganz 2022-10-22 14:38:11 +01:00
parent cecea7785b
commit eef309f386

View file

@ -427,6 +427,36 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
return findVisitor.result;
}
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
const Module& module, const Luau::TypeId ty, const AstExpr* parentExpr, const std::optional<DocumentationSymbol> documentationSymbol)
{
if (!documentationSymbol)
return std::nullopt;
// This might be an overloaded function.
if (get<IntersectionTypeVar>(follow(ty)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
return documentationSymbol;
}
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
{
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
@ -436,31 +466,7 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
if (std::optional<Binding> binding = findBindingAtPosition(module, source, position))
{
if (binding->documentationSymbol)
{
// This might be an overloaded function binding.
if (get<IntersectionTypeVar>(follow(binding->typeId)))
{
TypeId matchingOverload = nullptr;
if (parentExpr && parentExpr->is<AstExprCall>())
{
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
{
matchingOverload = *it;
}
}
if (matchingOverload)
{
std::string overloadSymbol = *binding->documentationSymbol + "/overload/";
// Default toString options are fine for this purpose.
overloadSymbol += toString(matchingOverload);
return overloadSymbol;
}
}
}
return binding->documentationSymbol;
return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, *binding->documentationSymbol);
}
if (targetExpr)
@ -474,14 +480,14 @@ std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const Source
{
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
{
return propIt->second.documentationSymbol;
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
}
}
else if (const ClassTypeVar* ctv = get<ClassTypeVar>(parentTy))
{
if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end())
{
return propIt->second.documentationSymbol;
return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol);
}
}
}