diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index b93c2cc2..75388760 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -1,580 +1,580 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/AstQuery.h" - -#include "Luau/Module.h" -#include "Luau/Scope.h" -#include "Luau/TypeInfer.h" -#include "Luau/TypeVar.h" -#include "Luau/ToString.h" - -#include "Luau/Common.h" - -#include - -LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false) - -namespace Luau -{ - -namespace -{ - - -struct AutocompleteNodeFinder : public AstVisitor -{ - const Position pos; - std::vector ancestry; - - explicit AutocompleteNodeFinder(Position pos, AstNode* root) - : pos(pos) - { - } - - bool visit(AstExpr* expr) override - { - if (expr->location.begin < pos && pos <= expr->location.end) - { - ancestry.push_back(expr); - return true; - } - return false; - } - - bool visit(AstStat* stat) override - { - if (stat->location.begin < pos && pos <= stat->location.end) - { - ancestry.push_back(stat); - return true; - } - return false; - } - - bool visit(AstType* type) override - { - if (type->location.begin < pos && pos <= type->location.end) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(AstTypeError* type) override - { - // For a missing type, match the whole range including the start position - if (type->isMissing && type->location.containsClosed(pos)) - { - ancestry.push_back(type); - return true; - } - return false; - } - - bool visit(class AstTypePack* typePack) override - { - return true; - } - - bool visit(AstStatBlock* block) override - { - // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. - if (ancestry.empty()) - { - ancestry.push_back(block); - return true; - } - - // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. - // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // Type annotation error might intersect the block statement when the function header is being written, - // annotation takes priority - if (!ancestry.empty() && ancestry.back()->is()) - return false; - - // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, - // the expression or type wins out. - // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to - // be within the block. - if (block->location.begin == pos && !ancestry.empty()) - { - if (ancestry.back()->asExpr() && !ancestry.back()->is()) - return false; - - if (ancestry.back()->asType()) - return false; - } - - if (block->location.begin <= pos && pos <= block->location.end) - { - ancestry.push_back(block); - return true; - } - return false; - } -}; - -struct FindNode : public AstVisitor -{ - const Position pos; - const Position documentEnd; - AstNode* best = nullptr; - - explicit FindNode(Position pos, Position documentEnd) - : pos(pos) - , documentEnd(documentEnd) - { - } - - bool visit(AstNode* node) override - { - if (node->location.contains(pos)) - { - best = node; - return true; - } - - // Edge case: If we ask for the node at the position that is the very end of the document - // return the innermost AST element that ends at that position. - - if (node->location.end == documentEnd && pos >= documentEnd) - { - best = node; - return true; - } - - return false; - } - - bool visit(AstStatBlock* block) override - { - visit(static_cast(block)); - - for (AstStat* stat : block->body) - { - if (stat->location.end < pos) - continue; - if (stat->location.begin > pos) - break; - - stat->visit(this); - } - - return false; - } -}; - -struct FindFullAncestry final : public AstVisitor -{ - std::vector nodes; - Position pos; - Position documentEnd; - - explicit FindFullAncestry(Position pos, Position documentEnd) - : pos(pos) - , documentEnd(documentEnd) - { - } - - bool visit(AstNode* node) - { - if (node->location.contains(pos)) - { - nodes.push_back(node); - return true; - } - - // Edge case: If we ask for the node at the position that is the very end of the document - // return the innermost AST element that ends at that position. - - if (node->location.end == documentEnd && pos >= documentEnd) - { - nodes.push_back(node); - return true; - } - - return false; - } -}; - -} // namespace - -std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) -{ - AutocompleteNodeFinder finder{pos, source.root}; - source.root->visit(&finder); - return finder.ancestry; -} - -std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) -{ - const Position end = source.root->location.end; - if (pos > end) - pos = end; - - FindFullAncestry finder(pos, end); - source.root->visit(&finder); - return finder.nodes; -} - -AstNode* findNodeAtPosition(const SourceModule& source, Position pos) -{ - const Position end = source.root->location.end; - if (pos < source.root->location.begin) - return source.root; - - if (pos > end) - pos = end; - - FindNode findNode{pos, end}; - findNode.visit(source.root); - return findNode.best; -} - -AstExpr* findExprAtPosition(const SourceModule& source, Position pos) -{ - AstNode* node = findNodeAtPosition(source, pos); - if (node) - return node->asExpr(); - else - return nullptr; -} - -ScopePtr findScopeAtPosition(const Module& module, Position pos) -{ - LUAU_ASSERT(!module.scopes.empty()); - - Location scopeLocation = module.scopes.front().first; - ScopePtr scope = module.scopes.front().second; - for (const auto& s : module.scopes) - { - if (s.first.contains(pos)) - { - if (!scope || scopeLocation.encloses(s.first)) - { - scopeLocation = s.first; - scope = s.second; - } - } - } - return scope; -} - -std::optional findTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos) -{ - if (auto expr = findExprAtPosition(sourceModule, pos)) - { - if (auto it = module.astTypes.find(expr)) - return *it; - } - - return std::nullopt; -} - -std::optional findExpectedTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos) -{ - if (auto expr = findExprAtPosition(sourceModule, pos)) - { - if (auto it = module.astExpectedTypes.find(expr)) - return *it; - } - - return std::nullopt; -} - -static std::optional findBindingLocalStatement(const SourceModule& source, const Binding& binding) -{ - std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); - auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { - return node->is(); - }); - return iter != nodes.rend() ? std::make_optional((*iter)->as()) : std::nullopt; -} - -std::optional findBindingAtPosition(const Module& module, const SourceModule& source, Position pos) -{ - AstExpr* expr = findExprAtPosition(source, pos); - if (!expr) - return std::nullopt; - - Symbol name; - if (auto g = expr->as()) - name = g->name; - else if (auto l = expr->as()) - name = l->local; - else - return std::nullopt; - - ScopePtr currentScope = findScopeAtPosition(module, pos); - LUAU_ASSERT(currentScope); - - while (currentScope) - { - auto iter = currentScope->bindings.find(name); - if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos) - { - // Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope - std::optional bindingStatement = findBindingLocalStatement(source, iter->second); - if (!bindingStatement || !(*bindingStatement)->location.contains(pos)) - return iter->second; - } - currentScope = currentScope->parent; - } - - return std::nullopt; -} - -namespace -{ -struct FindExprOrLocal : public AstVisitor -{ - const Position pos; - ExprOrLocal result; - - explicit FindExprOrLocal(Position pos) - : pos(pos) - { - } - - // We want to find the result with the smallest location range. - bool isCloserMatch(Location newLocation) - { - auto current = result.getLocation(); - return newLocation.contains(pos) && (!current || current->encloses(newLocation)); - } - - bool visit(AstStatBlock* block) override - { - for (AstStat* stat : block->body) - { - if (stat->location.end <= pos) - continue; - if (stat->location.begin > pos) - break; - - stat->visit(this); - } - - return false; - } - - bool visit(AstExpr* expr) override - { - if (isCloserMatch(expr->location)) - { - result.setExpr(expr); - return true; - } - return false; - } - - bool visitLocal(AstLocal* local) - { - if (isCloserMatch(local->location)) - { - result.setLocal(local); - return true; - } - return false; - } - - bool visit(AstStatLocalFunction* function) override - { - visitLocal(function->name); - return true; - } - - bool visit(AstStatLocal* al) override - { - for (size_t i = 0; i < al->vars.size; ++i) - { - visitLocal(al->vars.data[i]); - } - return true; - } - - virtual bool visit(AstExprFunction* fn) override - { - for (size_t i = 0; i < fn->args.size; ++i) - { - visitLocal(fn->args.data[i]); - } - return visit((class AstExpr*)fn); - } - - virtual bool visit(AstStatFor* forStat) override - { - visitLocal(forStat->var); - return true; - } - - virtual bool visit(AstStatForIn* forIn) override - { - for (AstLocal* var : forIn->vars) - { - visitLocal(var); - } - return true; - } -}; -}; // namespace - -ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) -{ - FindExprOrLocal findVisitor{pos}; - findVisitor.visit(source.root); - return findVisitor.result; -} - -static std::optional checkOverloadedDocumentationSymbol( - const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) -{ - LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol); - - if (!documentationSymbol) - return std::nullopt; - - // This might be an overloaded function. - if (get(follow(ty))) - { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) - { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) - { - matchingOverload = *it; - } - } - - if (matchingOverload) - { - std::string overloadSymbol = *documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } - } - - return documentationSymbol; -} - -std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) -{ - std::vector ancestry = findAstAncestryOfPosition(source, position); - - AstExpr* targetExpr = ancestry.size() >= 1 ? ancestry[ancestry.size() - 1]->asExpr() : nullptr; - AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr; - - if (std::optional binding = findBindingAtPosition(module, source, position)) - { - if (FFlag::LuauCheckOverloadedDocSymbol) - { - return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); - } - else - { - if (binding->documentationSymbol) - { - // This might be an overloaded function binding. - if (get(follow(binding->typeId))) - { - TypeId matchingOverload = nullptr; - if (parentExpr && parentExpr->is()) - { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) - { - matchingOverload = *it; - } - } - - if (matchingOverload) - { - std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; - // Default toString options are fine for this purpose. - overloadSymbol += toString(matchingOverload); - return overloadSymbol; - } - } - } - - return binding->documentationSymbol; - } - } - - if (targetExpr) - { - if (AstExprIndexName* indexName = targetExpr->as()) - { - if (auto it = module.astTypes.find(indexName->expr)) - { - TypeId parentTy = follow(*it); - if (const TableTypeVar* ttv = get(parentTy)) - { - if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) - { - if (FFlag::LuauCheckOverloadedDocSymbol) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); - else - return propIt->second.documentationSymbol; - } - } - else if (const ClassTypeVar* ctv = get(parentTy)) - { - if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) - { - if (FFlag::LuauCheckOverloadedDocSymbol) - return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); - else - return propIt->second.documentationSymbol; - } - } - } - } - else if (AstExprFunction* fn = targetExpr->as()) - { - // Handle event connection-like structures where we have - // something:Connect(function(a, b, c) end) - // In this case, we want to ascribe a documentation symbol to 'a' - // based on the documentation symbol of Connect. - if (parentExpr && parentExpr->is()) - { - AstExprCall* call = parentExpr->as(); - if (std::optional parentSymbol = getDocumentationSymbolAtPosition(source, module, call->func->location.begin)) - { - for (size_t i = 0; i < call->args.size; ++i) - { - AstExpr* callArg = call->args.data[i]; - if (callArg == targetExpr) - { - std::string fnSymbol = *parentSymbol + "/param/" + std::to_string(i); - for (size_t j = 0; j < fn->args.size; ++j) - { - AstLocal* fnArg = fn->args.data[j]; - - if (fnArg->location.contains(position)) - { - return fnSymbol + "/param/" + std::to_string(j); - } - } - } - } - } - } - } - } - - if (std::optional ty = findTypeAtPosition(module, source, position)) - { - if ((*ty)->documentationSymbol) - { - return (*ty)->documentationSymbol; - } - } - - return std::nullopt; -} - -} // namespace Luau +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AstQuery.h" + +#include "Luau/Module.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/ToString.h" + +#include "Luau/Common.h" + +#include + +LUAU_FASTFLAGVARIABLE(LuauCheckOverloadedDocSymbol, false) + +namespace Luau +{ + +namespace +{ + + +struct AutocompleteNodeFinder : public AstVisitor +{ + const Position pos; + std::vector ancestry; + + explicit AutocompleteNodeFinder(Position pos, AstNode* root) + : pos(pos) + { + } + + bool visit(AstExpr* expr) override + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + + bool visit(AstStat* stat) override + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + return false; + } + + bool visit(AstType* type) override + { + if (type->location.begin < pos && pos <= type->location.end) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(AstTypeError* type) override + { + // For a missing type, match the whole range including the start position + if (type->isMissing && type->location.containsClosed(pos)) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(class AstTypePack* typePack) override + { + return true; + } + + bool visit(AstStatBlock* block) override + { + // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. + if (ancestry.empty()) + { + ancestry.push_back(block); + return true; + } + + // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. + // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // Type annotation error might intersect the block statement when the function header is being written, + // annotation takes priority + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, + // the expression or type wins out. + // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to + // be within the block. + if (block->location.begin == pos && !ancestry.empty()) + { + if (ancestry.back()->asExpr() && !ancestry.back()->is()) + return false; + + if (ancestry.back()->asType()) + return false; + } + + if (block->location.begin <= pos && pos <= block->location.end) + { + ancestry.push_back(block); + return true; + } + return false; + } +}; + +struct FindNode : public AstVisitor +{ + const Position pos; + const Position documentEnd; + AstNode* best = nullptr; + + explicit FindNode(Position pos, Position documentEnd) + : pos(pos) + , documentEnd(documentEnd) + { + } + + bool visit(AstNode* node) override + { + if (node->location.contains(pos)) + { + best = node; + return true; + } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + best = node; + return true; + } + + return false; + } + + bool visit(AstStatBlock* block) override + { + visit(static_cast(block)); + + for (AstStat* stat : block->body) + { + if (stat->location.end < pos) + continue; + if (stat->location.begin > pos) + break; + + stat->visit(this); + } + + return false; + } +}; + +struct FindFullAncestry final : public AstVisitor +{ + std::vector nodes; + Position pos; + Position documentEnd; + + explicit FindFullAncestry(Position pos, Position documentEnd) + : pos(pos) + , documentEnd(documentEnd) + { + } + + bool visit(AstNode* node) + { + if (node->location.contains(pos)) + { + nodes.push_back(node); + return true; + } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + nodes.push_back(node); + return true; + } + + return false; + } +}; + +} // namespace + +std::vector findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos) +{ + AutocompleteNodeFinder finder{pos, source.root}; + source.root->visit(&finder); + return finder.ancestry; +} + +std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) +{ + const Position end = source.root->location.end; + if (pos > end) + pos = end; + + FindFullAncestry finder(pos, end); + source.root->visit(&finder); + return finder.nodes; +} + +AstNode* findNodeAtPosition(const SourceModule& source, Position pos) +{ + const Position end = source.root->location.end; + if (pos < source.root->location.begin) + return source.root; + + if (pos > end) + pos = end; + + FindNode findNode{pos, end}; + findNode.visit(source.root); + return findNode.best; +} + +AstExpr* findExprAtPosition(const SourceModule& source, Position pos) +{ + AstNode* node = findNodeAtPosition(source, pos); + if (node) + return node->asExpr(); + else + return nullptr; +} + +ScopePtr findScopeAtPosition(const Module& module, Position pos) +{ + LUAU_ASSERT(!module.scopes.empty()); + + Location scopeLocation = module.scopes.front().first; + ScopePtr scope = module.scopes.front().second; + for (const auto& s : module.scopes) + { + if (s.first.contains(pos)) + { + if (!scope || scopeLocation.encloses(s.first)) + { + scopeLocation = s.first; + scope = s.second; + } + } + } + return scope; +} + +std::optional findTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos) +{ + if (auto expr = findExprAtPosition(sourceModule, pos)) + { + if (auto it = module.astTypes.find(expr)) + return *it; + } + + return std::nullopt; +} + +std::optional findExpectedTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos) +{ + if (auto expr = findExprAtPosition(sourceModule, pos)) + { + if (auto it = module.astExpectedTypes.find(expr)) + return *it; + } + + return std::nullopt; +} + +static std::optional findBindingLocalStatement(const SourceModule& source, const Binding& binding) +{ + std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); + auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { + return node->is(); + }); + return iter != nodes.rend() ? std::make_optional((*iter)->as()) : std::nullopt; +} + +std::optional findBindingAtPosition(const Module& module, const SourceModule& source, Position pos) +{ + AstExpr* expr = findExprAtPosition(source, pos); + if (!expr) + return std::nullopt; + + Symbol name; + if (auto g = expr->as()) + name = g->name; + else if (auto l = expr->as()) + name = l->local; + else + return std::nullopt; + + ScopePtr currentScope = findScopeAtPosition(module, pos); + LUAU_ASSERT(currentScope); + + while (currentScope) + { + auto iter = currentScope->bindings.find(name); + if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos) + { + // Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope + std::optional bindingStatement = findBindingLocalStatement(source, iter->second); + if (!bindingStatement || !(*bindingStatement)->location.contains(pos)) + return iter->second; + } + currentScope = currentScope->parent; + } + + return std::nullopt; +} + +namespace +{ +struct FindExprOrLocal : public AstVisitor +{ + const Position pos; + ExprOrLocal result; + + explicit FindExprOrLocal(Position pos) + : pos(pos) + { + } + + // We want to find the result with the smallest location range. + bool isCloserMatch(Location newLocation) + { + auto current = result.getLocation(); + return newLocation.contains(pos) && (!current || current->encloses(newLocation)); + } + + bool visit(AstStatBlock* block) override + { + for (AstStat* stat : block->body) + { + if (stat->location.end <= pos) + continue; + if (stat->location.begin > pos) + break; + + stat->visit(this); + } + + return false; + } + + bool visit(AstExpr* expr) override + { + if (isCloserMatch(expr->location)) + { + result.setExpr(expr); + return true; + } + return false; + } + + bool visitLocal(AstLocal* local) + { + if (isCloserMatch(local->location)) + { + result.setLocal(local); + return true; + } + return false; + } + + bool visit(AstStatLocalFunction* function) override + { + visitLocal(function->name); + return true; + } + + bool visit(AstStatLocal* al) override + { + for (size_t i = 0; i < al->vars.size; ++i) + { + visitLocal(al->vars.data[i]); + } + return true; + } + + virtual bool visit(AstExprFunction* fn) override + { + for (size_t i = 0; i < fn->args.size; ++i) + { + visitLocal(fn->args.data[i]); + } + return visit((class AstExpr*)fn); + } + + virtual bool visit(AstStatFor* forStat) override + { + visitLocal(forStat->var); + return true; + } + + virtual bool visit(AstStatForIn* forIn) override + { + for (AstLocal* var : forIn->vars) + { + visitLocal(var); + } + return true; + } +}; +}; // namespace + +ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) +{ + FindExprOrLocal findVisitor{pos}; + findVisitor.visit(source.root); + return findVisitor.result; +} + +static std::optional checkOverloadedDocumentationSymbol( + const Module& module, const TypeId ty, const AstExpr* parentExpr, const std::optional documentationSymbol) +{ + LUAU_ASSERT(FFlag::LuauCheckOverloadedDocSymbol); + + if (!documentationSymbol) + return std::nullopt; + + // This might be an overloaded function. + if (get(follow(ty))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + + return documentationSymbol; +} + +std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) +{ + std::vector ancestry = findAstAncestryOfPosition(source, position); + + AstExpr* targetExpr = ancestry.size() >= 1 ? ancestry[ancestry.size() - 1]->asExpr() : nullptr; + AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr; + + if (std::optional binding = findBindingAtPosition(module, source, position)) + { + if (FFlag::LuauCheckOverloadedDocSymbol) + { + return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol); + } + else + { + if (binding->documentationSymbol) + { + // This might be an overloaded function binding. + if (get(follow(binding->typeId))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) + { + matchingOverload = *it; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + } + + return binding->documentationSymbol; + } + } + + if (targetExpr) + { + if (AstExprIndexName* indexName = targetExpr->as()) + { + if (auto it = module.astTypes.find(indexName->expr)) + { + TypeId parentTy = follow(*it); + if (const TableTypeVar* ttv = get(parentTy)) + { + if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) + { + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; + } + } + else if (const ClassTypeVar* ctv = get(parentTy)) + { + if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) + { + if (FFlag::LuauCheckOverloadedDocSymbol) + return checkOverloadedDocumentationSymbol(module, propIt->second.type, parentExpr, propIt->second.documentationSymbol); + else + return propIt->second.documentationSymbol; + } + } + } + } + else if (AstExprFunction* fn = targetExpr->as()) + { + // Handle event connection-like structures where we have + // something:Connect(function(a, b, c) end) + // In this case, we want to ascribe a documentation symbol to 'a' + // based on the documentation symbol of Connect. + if (parentExpr && parentExpr->is()) + { + AstExprCall* call = parentExpr->as(); + if (std::optional parentSymbol = getDocumentationSymbolAtPosition(source, module, call->func->location.begin)) + { + for (size_t i = 0; i < call->args.size; ++i) + { + AstExpr* callArg = call->args.data[i]; + if (callArg == targetExpr) + { + std::string fnSymbol = *parentSymbol + "/param/" + std::to_string(i); + for (size_t j = 0; j < fn->args.size; ++j) + { + AstLocal* fnArg = fn->args.data[j]; + + if (fnArg->location.contains(position)) + { + return fnSymbol + "/param/" + std::to_string(j); + } + } + } + } + } + } + } + } + + if (std::optional ty = findTypeAtPosition(module, source, position)) + { + if ((*ty)->documentationSymbol) + { + return (*ty)->documentationSymbol; + } + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 224e9440..1edcef84 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -1219,6 +1219,29 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + printf("expressions.count: %ld\n", interpString->expressions.size); + + for (const AstExpr* expression : interpString->expressions) + { + printf("%s <= %s <= %s -- expression = %d\n", toString(expression->location.begin).c_str(), toString(position).c_str(), toString(expression->location.end).c_str(), expression->classIndex); + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; + // return node->is(); +} + static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, const std::vector& nodes, Position position, StringCompletionCallback callback) { @@ -1227,7 +1250,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is() && !nodes.back()->is()) + if (!nodes.back()->is() && !stringPartOfInterpString(nodes.back(), position) && !nodes.back()->is()) { return std::nullopt; } @@ -1432,7 +1455,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M return autocompleteExpression(sourceModule, *module, singletonTypes, &typeArena, ancestry, position); else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) + else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is() || node->is())) { for (const auto& [kind, key, value] : exprTable->items) { @@ -1471,7 +1494,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { return {*ret, ancestry, AutocompleteContext::String}; } - else if (node->is()) + else if (node->is() || stringPartOfInterpString(node, position)) { AutocompleteEntryMap result; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8338a04a..5cd8c74b 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -2677,6 +2677,7 @@ AstExpr* Parser::parseInterpString() if (!Lexer::fixupQuotedString(scratchData)) { nextLexeme(); + printf("FIXUP QUOTED STRING FAIL\n"); return reportExprError(startLocation, {}, "Interpolated string literal contains malformed escape sequence"); } @@ -2694,9 +2695,30 @@ AstExpr* Parser::parseInterpString() return allocator.alloc(startLocation, stringsArray, expressionsArray); } - AstExpr* expression = parseExpr(); + switch (lexer.current().type) + { + case Lexeme::InterpStringMid: + case Lexeme::InterpStringEnd: + { + nextLexeme(); + expressions.push_back(reportExprError(location, {}, "Malformed interpolated string, expected expression inside '{}'")); + AstArray> stringsArray = copy(strings); + AstArray expressionsArray = copy(expressions); - expressions.push_back(expression); + return allocator.alloc(startLocation, stringsArray, expressionsArray); + } + case Lexeme::BrokenString: + { + nextLexeme(); + expressions.push_back(reportExprError(location, {}, "Malformed interpolated string, did you forget to add a '`'?")); + AstArray> stringsArray = copy(strings); + AstArray expressionsArray = copy(expressions); + + return allocator.alloc(startLocation, stringsArray, expressionsArray); + } + default: + expressions.push_back(parseExpr()); + } switch (lexer.current().type) { @@ -2706,11 +2728,14 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); + printf("BROKEN INTERP DOUBLE BRACE\n"); return reportExprError(location, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); case Lexeme::BrokenString: nextLexeme(); + printf("BROKEN STRING\n"); return reportExprError(location, {}, "Malformed interpolated string, did you forget to add a '}'?"); default: + printf("DEFAULT: %s\n", lexer.current().toString().c_str()); return reportExprError(location, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } } while (true); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9a5c3411..4653b520 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -7,6 +7,7 @@ #include "Luau/StringUtils.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -2708,13 +2709,46 @@ a = if temp then even else abc@3 CHECK(ac.entryMap.count("abcdef")); } -TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string") +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_constant") { + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + printf("======= autocomplete_interpolated_string_constant =======\n"); + + // check(R"(f(`@1`))"); + // auto ac = autocomplete('1'); + // CHECK(ac.entryMap.empty()); + // CHECK_EQ(ac.context, AutocompleteContext::String); + + check(R"(f(`{"a"} @1`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::String); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + check(R"(f(`expression = {@1}`))"); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string_expression_with_comments") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + check(R"(f(`expression = {--[[ bla bla bla ]]@1`))"); auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK_EQ(ac.context, AutocompleteContext::Expression); + + check(R"(f(`expression = {@1--[[ bla bla bla ]]`))"); + ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack")