From eef309f3867528061c16549a808102e0e2efb793 Mon Sep 17 00:00:00 2001 From: JohnnyMorganz Date: Sat, 22 Oct 2022 14:38:11 +0100 Subject: [PATCH] Use the correct overload symbol --- Analysis/src/AstQuery.cpp | 60 +++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 50299704..de6db423 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -427,6 +427,36 @@ ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) return findVisitor.result; } +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const Luau::TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -436,31 +466,7 @@ std::optional getDocumentationSymbolAtPosition(const Source if (std::optional binding = findBindingAtPosition(module, source, position)) { - if (binding->documentationSymbol) - { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) - { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) - { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) - { - matchingOverload = *it; - } - } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } - } - } - - return binding->documentationSymbol; + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, *binding->documentationSymbol); } if (targetExpr) @@ -474,14 +480,14 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) { - return propIt->second.documentationSymbol; + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } } else if (const ClassTypeVar* ctv = get(parentTy)) { if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { - return propIt->second.documentationSymbol; + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); } } }