From 9a4bc6aeb8a5178c7e80d80363c6ce2634446ecb Mon Sep 17 00:00:00 2001 From: checkraisefold Date: Tue, 5 Nov 2024 10:33:21 -0800 Subject: [PATCH 1/8] Fix definition module name & location (#1495) Closes #1441 Brings behavior to parity with the old solver by filling in definitionLocation and definitionModuleName for Luau-consuming programs/libraries to use. --- Analysis/src/ConstraintGenerator.cpp | 16 +++++++++++++++- Analysis/src/ConstraintSolver.cpp | 18 ++++++++++-------- tests/TypeInfer.modules.test.cpp | 21 +++++++++++++++++++-- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index 8153c3d5..c8aa8209 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -34,6 +34,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) +LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) namespace Luau { @@ -2844,6 +2845,10 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, ttv->state = TableState::Unsealed; ttv->definitionModuleName = module->name; + if (FFlag::LuauNewSolverPopulateTableLocations) + { + ttv->definitionLocation = expr->location; + } ttv->scope = scope.get(); interiorTypes.back().push_back(ty); @@ -3301,7 +3306,16 @@ TypeId ConstraintGenerator::resolveTableType(const ScopePtr& scope, AstType* ty, ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); } - return arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + TypeId tableTy = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + TableType* ttv = getMutable(tableTy); + + if (FFlag::LuauNewSolverPopulateTableLocations) + { + ttv->definitionModuleName = module->name; + ttv->definitionLocation = tab->location; + } + + return tableTy; } TypeId ConstraintGenerator::resolveFunctionType( diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 34e08fe3..bcda4e23 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) +LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) namespace Luau { @@ -1108,10 +1109,15 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul target = follow(instantiated); } + else if (FFlag::LuauNewSolverPopulateTableLocations) + { + // This is a new type - redefine the location. + ttv->definitionLocation = constraint->location; + ttv->definitionModuleName = currentModuleName; + } ttv->instantiatedTypeParams = typeArguments; ttv->instantiatedTypePackParams = packArguments; - // TODO: Fill in definitionModuleName. } bindResult(target); @@ -1433,7 +1439,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullis() || expr->is() || expr->is() || expr->is()) + else if (expr->is() || expr->is() || expr->is() || + expr->is()) { Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; u2.unify(actualArgTy, expectedArgTy); @@ -2326,12 +2333,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return true; } -bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, - TypeId tableTy, - const IterableConstraint& c, - NotNull constraint -) +bool ConstraintSolver::tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint) { const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 4f797690..5d5df24a 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauRequireCyclesDontAlwaysReturnAny) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauTypestateBuiltins2) +LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) using namespace Luau; @@ -466,7 +467,15 @@ local b: B.T = a LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + { + if (FFlag::LuauNewSolverPopulateTableLocations) + CHECK( + toString(result.errors.at(0)) == + "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'; at [read \"x\"], number is not exactly string" + ); + else + CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + } else { const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' @@ -507,7 +516,15 @@ local b: B.T = a LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + { + if (FFlag::LuauNewSolverPopulateTableLocations) + CHECK( + toString(result.errors.at(0)) == + "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'; at [read \"x\"], number is not exactly string" + ); + else + CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + } else { const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' From f1d4621d591b6c4fadef99f446967479605b60d4 Mon Sep 17 00:00:00 2001 From: checkraisefold Date: Tue, 5 Nov 2024 15:21:18 -0800 Subject: [PATCH 2/8] Pre-populate/duplicate check class definitions (new solver) (#1493) Closes #1492 Tested and working with the test case in the aforementioned issue, along with the full defs of luau-lsp with no issues or type errors In normal Luau files, you can use type aliases and type functions before they are declared. The same extends to declaration files, **except** in the new solver. The old solver perfectly allows this, and in fact intentionally adds it: https://github.com/luau-lang/luau/blob/db809395bf5739c895a24dc73960b9e9ab6468c5/Analysis/src/TypeInfer.cpp#L1711-L1717 This causes *much* headache and pain for external projects that make use of declaration files; namely, luau-lsp generates them from MaximumADHD's API dump, which is not ordered by dependency. This means silent error-types popping up everywhere because types are used before they are declared. The workaround would be to make code to manually reorder class definitions based on their dependencies with a bunch of code, but this is clearly not ideal, and won't work for classes dependent on each other/recursive. The solution used here is the same as is used for type aliases - the name binding for the class is given a blocked type before running the rest of constraint generation on the block. Questions remain: - Should the logic be split off of `checkAliases`? - Should a bound type be used, or should the (blocked) binding type be directly emplaced with the class type? What are the ramifications of emplacing with the bound versus the raw type? One ramification was initially ran into through an assertion because the class `superTy`/`parent` was bound, and several pieces of code assume it is not, so it had to be made followed. - Is folllowing `superTy` to set `parent` the correct workaround for the assertions thrown, or should the code expecting `parent` to be a ClassType without following it be modified instead to follow `parent`? - Should `scope->privateTypeBindings` also be checked for the duplicate error? I would presume so, since having a class with the same name as a private alias or type function should error as well? The extraneous whitespace changes are clang-format ones done automatically that should've been done in the last release - I can remove them if necessary and let another sync or OSS cleanup commit fix it. --- Analysis/src/ConstraintGenerator.cpp | 47 ++++++++++++++++++++++++++-- tests/TypeInfer.definitions.test.cpp | 27 +++++++++------- 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index c8aa8209..d05623a8 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -34,6 +34,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) +LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) namespace Luau @@ -654,6 +655,7 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* block) { std::unordered_map aliasDefinitionLocations; + std::unordered_map classDefinitionLocations; // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the @@ -751,6 +753,32 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc scope->privateTypeBindings[function->name.value] = std::move(typeFunction); aliasDefinitionLocations[function->name.value] = function->location; } + else if (auto classDeclaration = stat->as()) + { + if (!FFlag::LuauNewSolverPrePopulateClasses) + continue; + + if (scope->exportedTypeBindings.count(classDeclaration->name.value)) + { + auto it = classDefinitionLocations.find(classDeclaration->name.value); + LUAU_ASSERT(it != classDefinitionLocations.end()); + reportError(classDeclaration->location, DuplicateTypeDefinition{classDeclaration->name.value, it->second}); + continue; + } + + // A class might have no name if the code is syntactically + // illegal. We mustn't prepopulate anything in this case. + if (classDeclaration->name == kParseNameError) + continue; + + ScopePtr defnScope = childScope(classDeclaration, scope); + + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; + scope->exportedTypeBindings[classDeclaration->name.value] = std::move(initialFun); + + classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location; + } } } @@ -1646,6 +1674,11 @@ static bool isMetamethod(const Name& name) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { + // If a class with the same name was already defined, we skip over + auto bindingIt = scope->exportedTypeBindings.find(declaredClass->name.value); + if (FFlag::LuauNewSolverPrePopulateClasses && bindingIt == scope->exportedTypeBindings.end()) + return ControlFlow::None; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass->superName) { @@ -1660,7 +1693,10 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas // We don't have generic classes, so this assertion _should_ never be hit. LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); - superTy = lookupType->type; + if (FFlag::LuauNewSolverPrePopulateClasses) + superTy = follow(lookupType->type); + else + superTy = lookupType->type; if (!get(follow(*superTy))) { @@ -1683,7 +1719,14 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + if (FFlag::LuauNewSolverPrePopulateClasses) + { + TypeId classBindTy = bindingIt->second.type; + emplaceType(asMutable(classBindTy), classTy); + } + else + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; if (declaredClass->indexer) { diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 5a530e83..2ab90ab5 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauNewSolverPrePopulateClasses) + TEST_SUITE_BEGIN("DefinitionTests"); TEST_CASE_FIXTURE(Fixture, "definition_file_simple") @@ -492,11 +494,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile( - frontend.globals, - frontend.globals.globalScope, - R"( + ScopedFastFlag _{FFlag::LuauNewSolverPrePopulateClasses, true}; + loadDefinition(R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () @@ -506,13 +505,19 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") Text: string Channel: Channel end - )", - "@test", - /* captureComments */ false - ); - freeze(frontend.globals.globalTypes); + )"); - REQUIRE(result.success); + CheckResult result = check(R"( + local a: Channel + local b = a.Messages[1] + local c = b.Channel + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Channel"); + CHECK_EQ(toString(requireType("b")), "Message"); + CHECK_EQ(toString(requireType("c")), "Channel"); } TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") From 47543e5df11bb4e7e9f7b653c65cee54d957b307 Mon Sep 17 00:00:00 2001 From: aaron Date: Tue, 5 Nov 2024 15:25:38 -0800 Subject: [PATCH 3/8] Set the defining module even when the new solver cloned the type. (#1506) Follow up to #1495: a small fixup for the defining module and location to get set even when cloning was required. --- Analysis/src/ConstraintSolver.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index bcda4e23..398f0aa5 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -1109,7 +1109,8 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul target = follow(instantiated); } - else if (FFlag::LuauNewSolverPopulateTableLocations) + + if (FFlag::LuauNewSolverPopulateTableLocations) { // This is a new type - redefine the location. ttv->definitionLocation = constraint->location; From 26b2307a8bc0c3783029da4e2bb34347f8ff9c32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bar=C4=B1=C5=9F?= <34089907+Barocena@users.noreply.github.com> Date: Thu, 7 Nov 2024 02:23:33 +0300 Subject: [PATCH 4/8] Replace old site urls (#1505) this PR replaces all the old site urls from luau-lang.org to luau.org --- CONTRIBUTING.md | 2 +- Config/include/Luau/LinterConfig.h | 2 +- README.md | 8 ++++---- VM/src/lapi.cpp | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 26579740..d5d41c42 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ Some questions help improve the language, implementation or documentation by ins ## Documentation -A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau-lang.org. +A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau.org. Changes to this documentation that improve clarity, fix grammatical issues, explain aspects that haven't been explained before and the like are warmly welcomed. Please feel free to [create a pull request](https://help.github.com/articles/about-pull-requests/) to improve our documentation. Note that at this point the documentation is English-only. diff --git a/Config/include/Luau/LinterConfig.h b/Config/include/Luau/LinterConfig.h index 3a68c0d7..e9305009 100644 --- a/Config/include/Luau/LinterConfig.h +++ b/Config/include/Luau/LinterConfig.h @@ -15,7 +15,7 @@ struct HotComment; struct LintWarning { - // Make sure any new lint codes are documented here: https://luau-lang.org/lint + // Make sure any new lint codes are documented here: https://luau.org/lint // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints enum Code { diff --git a/README.md b/README.md index ba337585..edf4a553 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,11 @@ Luau ![CI](https://github.com/luau-lang/luau/actions/workflows/build.yml/badge.s Luau (lowercase u, /ˈlu.aʊ/) is a fast, small, safe, gradually typed embeddable scripting language derived from [Lua](https://lua.org). -It is designed to be backwards compatible with Lua 5.1, as well as incorporating [some features](https://luau-lang.org/compatibility) from future Lua releases, but also expands the feature set (most notably with type annotations). Luau is largely implemented from scratch, with the language runtime being a very heavily modified version of Lua 5.1 runtime, with completely rewritten interpreter and other [performance innovations](https://luau-lang.org/performance). The runtime mostly preserves Lua 5.1 API, so existing bindings should be more or less compatible with a few caveats. +It is designed to be backwards compatible with Lua 5.1, as well as incorporating [some features](https://luau.org/compatibility) from future Lua releases, but also expands the feature set (most notably with type annotations). Luau is largely implemented from scratch, with the language runtime being a very heavily modified version of Lua 5.1 runtime, with completely rewritten interpreter and other [performance innovations](https://luau.org/performance). The runtime mostly preserves Lua 5.1 API, so existing bindings should be more or less compatible with a few caveats. Luau is used by Roblox game developers to write game code, as well as by Roblox engineers to implement large parts of the user-facing application code as well as portions of the editor (Roblox Studio) as plugins. Roblox chose to open-source Luau to foster collaboration within the Roblox community as well as to allow other companies and communities to benefit from the ongoing language and runtime innovation. As a consequence, Luau is now also used by games like Alan Wake 2 and Warframe. -This repository hosts source code for the language implementation and associated tooling. Documentation for the language is available at https://luau-lang.org/ and accepts contributions via [site repository](https://github.com/luau-lang/site); the language is evolved through RFCs that are located in [rfcs repository](https://github.com/luau-lang/rfcs). +This repository hosts source code for the language implementation and associated tooling. Documentation for the language is available at https://luau.org/ and accepts contributions via [site repository](https://github.com/luau-lang/site); the language is evolved through RFCs that are located in [rfcs repository](https://github.com/luau-lang/rfcs). # Usage @@ -15,7 +15,7 @@ Luau is an embeddable language, but it also comes with two command-line tools by `luau` is a command-line REPL and can also run input files. Note that REPL runs in a sandboxed environment and as such doesn't have access to the underlying file system except for ability to `require` modules. -`luau-analyze` is a command-line type checker and linter; given a set of input files, it produces errors/warnings according to the file configuration, which can be customized by using `--!` comments in the files or [`.luaurc`](https://rfcs.luau-lang.org/config-luaurc) files. For details please refer to [type checking]( https://luau-lang.org/typecheck) and [linting](https://luau-lang.org/lint) documentation. +`luau-analyze` is a command-line type checker and linter; given a set of input files, it produces errors/warnings according to the file configuration, which can be customized by using `--!` comments in the files or [`.luaurc`](https://rfcs.luau.org/config-luaurc) files. For details please refer to [type checking]( https://luau.org/typecheck) and [linting](https://luau.org/lint) documentation. # Installation @@ -28,7 +28,7 @@ Alternatively, you can use one of the packaged distributions (note that these ar - Alpine Linux: [Enable community repositories](https://wiki.alpinelinux.org/w/index.php?title=Enable_Community_Repository) and run `apk add luau` - Gentoo Linux: Luau is [officially packaged by Gentoo](https://packages.gentoo.org/packages/dev-lang/luau) and can be installed using `emerge dev-lang/luau`. You may have to unmask the package first before installing it (which can be done by including the `--autounmask=y` option in the `emerge` command). -After installing, you will want to validate the installation was successful by running the test case [here](https://luau-lang.org/getting-started). +After installing, you will want to validate the installation was successful by running the test case [here](https://luau.org/getting-started). ## Building diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 4c42f8c1..d382a924 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -40,7 +40,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$URL: www.lua.org $\n"; const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" - "$URL: luau-lang.org $\n"; + "$URL: luau.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) From a36a3c41cc58740823a38a17474dc9246ebbbca6 Mon Sep 17 00:00:00 2001 From: Hunter Goldstein Date: Fri, 8 Nov 2024 13:41:45 -0800 Subject: [PATCH 5/8] Sync to `upstream/release/651` (#1513) ### What's New? * Fragment Autocomplete: a new API allows for type checking a small fragment of code against an existing file, significantly speeding up autocomplete performance in large files. ### New Solver * E-Graphs have landed: this is an ongoing approach to make the new type solver simplify types in a more consistent and principled manner, based on similar work (see: https://egraphs-good.github.io/). * Adds support for exporting / local user type functions (previously they were always exported). * Fixes a set of bugs in which the new solver will fail to complete inference for simple expressions with just literals and operators. ### General Updates * Requiring a path with a ".lua" or ".luau" extension will now have a bespoke error suggesting to remove said extension. * Fixes a bug in which whether two `Luau::Symbol`s are equal depends on whether the new solver is enabled. --- Internal Contributors: Co-authored-by: Aaron Weiss Co-authored-by: Andy Friesen Co-authored-by: David Cope Co-authored-by: Hunter Goldstein Co-authored-by: Varun Saini Co-authored-by: Vighnesh Vijay Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/Autocomplete.h | 86 +- Analysis/include/Luau/AutocompleteTypes.h | 92 + Analysis/include/Luau/ConstraintGenerator.h | 9 +- Analysis/include/Luau/ConstraintSolver.h | 7 + Analysis/include/Luau/EqSatSimplification.h | 50 + .../include/Luau/EqSatSimplificationImpl.h | 363 +++ Analysis/include/Luau/FileResolver.h | 6 +- Analysis/include/Luau/FragmentAutocomplete.h | 19 +- Analysis/include/Luau/ToString.h | 1 + Analysis/include/Luau/Type.h | 25 +- Analysis/src/Autocomplete.cpp | 1997 +------------- Analysis/src/AutocompleteCore.cpp | 2002 ++++++++++++++ Analysis/src/AutocompleteCore.h | 27 + Analysis/src/BuiltinDefinitions.cpp | 4 +- Analysis/src/Constraint.cpp | 17 + Analysis/src/ConstraintGenerator.cpp | 153 +- Analysis/src/ConstraintSolver.cpp | 70 +- Analysis/src/EqSatSimplification.cpp | 2449 +++++++++++++++++ Analysis/src/FragmentAutocomplete.cpp | 130 +- Analysis/src/Frontend.cpp | 20 +- Analysis/src/Substitution.cpp | 2 +- Analysis/src/Symbol.cpp | 3 +- Analysis/src/ToString.cpp | 4 +- Analysis/src/TypeFunction.cpp | 119 +- Ast/include/Luau/Allocator.h | 48 + Ast/include/Luau/Ast.h | 11 +- Ast/include/Luau/Lexer.h | 35 +- Ast/src/Allocator.cpp | 66 + Ast/src/Ast.cpp | 14 +- Ast/src/Lexer.cpp | 59 +- Ast/src/Parser.cpp | 27 +- CLI/Analyze.cpp | 10 +- CLI/FileUtils.cpp | 10 + CLI/FileUtils.h | 2 + CLI/Require.cpp | 20 +- Common/include/Luau/Variant.h | 3 +- Config/include/Luau/Config.h | 36 +- Config/src/Config.cpp | 74 +- EqSat/include/Luau/EGraph.h | 130 +- EqSat/include/Luau/Id.h | 9 +- EqSat/include/Luau/Language.h | 209 +- EqSat/include/Luau/LanguageHash.h | 1 + EqSat/include/Luau/UnionFind.h | 4 +- EqSat/src/Id.cpp | 12 +- EqSat/src/UnionFind.cpp | 32 +- Sources.cmake | 9 +- VM/src/lapi.cpp | 2 +- tests/AstJsonEncoder.test.cpp | 4 +- tests/Autocomplete.test.cpp | 34 - tests/Config.test.cpp | 12 +- tests/ConstraintGeneratorFixture.cpp | 15 +- tests/ConstraintGeneratorFixture.h | 4 +- tests/EqSat.language.test.cpp | 14 +- tests/EqSatSimplification.test.cpp | 728 +++++ tests/Fixture.h | 34 + tests/FragmentAutocomplete.test.cpp | 282 +- tests/Parser.test.cpp | 9 +- tests/RequireByString.test.cpp | 7 + tests/ToString.test.cpp | 1 + tests/TypeFunction.user.test.cpp | 90 + tests/TypeInfer.functions.test.cpp | 32 +- tests/TypeInfer.test.cpp | 33 + 62 files changed, 7318 insertions(+), 2459 deletions(-) create mode 100644 Analysis/include/Luau/AutocompleteTypes.h create mode 100644 Analysis/include/Luau/EqSatSimplification.h create mode 100644 Analysis/include/Luau/EqSatSimplificationImpl.h create mode 100644 Analysis/src/AutocompleteCore.cpp create mode 100644 Analysis/src/AutocompleteCore.h create mode 100644 Analysis/src/EqSatSimplification.cpp create mode 100644 Ast/include/Luau/Allocator.h create mode 100644 Ast/src/Allocator.cpp create mode 100644 tests/EqSatSimplification.test.cpp diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 96bac9e4..b54f7a44 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -1,10 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/AutocompleteTypes.h" #include "Luau/Location.h" #include "Luau/Type.h" -#include #include #include #include @@ -16,90 +16,8 @@ struct Frontend; struct SourceModule; struct Module; struct TypeChecker; - -using ModulePtr = std::shared_ptr; - -enum class AutocompleteContext -{ - Unknown, - Expression, - Statement, - Property, - Type, - Keyword, - String, -}; - -enum class AutocompleteEntryKind -{ - Property, - Binding, - Keyword, - String, - Type, - Module, - GeneratedFunction, - RequirePath, -}; - -enum class ParenthesesRecommendation -{ - None, - CursorAfter, - CursorInside, -}; - -enum class TypeCorrectKind -{ - None, - Correct, - CorrectFunctionResult, -}; - -struct AutocompleteEntry -{ - AutocompleteEntryKind kind = AutocompleteEntryKind::Property; - // Nullopt if kind is Keyword - std::optional type = std::nullopt; - bool deprecated = false; - // Only meaningful if kind is Property. - bool wrongIndexType = false; - // Set if this suggestion matches the type expected in the context - TypeCorrectKind typeCorrect = TypeCorrectKind::None; - - std::optional containingClass = std::nullopt; - std::optional prop = std::nullopt; - std::optional documentationSymbol = std::nullopt; - Tags tags; - ParenthesesRecommendation parens = ParenthesesRecommendation::None; - std::optional insertText; - - // Only meaningful if kind is Property. - bool indexedWithSelf = false; -}; - -using AutocompleteEntryMap = std::unordered_map; -struct AutocompleteResult -{ - AutocompleteEntryMap entryMap; - std::vector ancestry; - AutocompleteContext context = AutocompleteContext::Unknown; - - AutocompleteResult() = default; - AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) - : entryMap(std::move(entryMap)) - , ancestry(std::move(ancestry)) - , context(context) - { - } -}; - -using ModuleName = std::string; -using StringCompletionCallback = - std::function(std::string tag, std::optional ctx, std::optional contents)>; +struct FileResolver; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); -constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; - } // namespace Luau diff --git a/Analysis/include/Luau/AutocompleteTypes.h b/Analysis/include/Luau/AutocompleteTypes.h new file mode 100644 index 00000000..37d45244 --- /dev/null +++ b/Analysis/include/Luau/AutocompleteTypes.h @@ -0,0 +1,92 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Type.h" + +#include + +namespace Luau +{ + +enum class AutocompleteContext +{ + Unknown, + Expression, + Statement, + Property, + Type, + Keyword, + String, +}; + +enum class AutocompleteEntryKind +{ + Property, + Binding, + Keyword, + String, + Type, + Module, + GeneratedFunction, + RequirePath, +}; + +enum class ParenthesesRecommendation +{ + None, + CursorAfter, + CursorInside, +}; + +enum class TypeCorrectKind +{ + None, + Correct, + CorrectFunctionResult, +}; + +struct AutocompleteEntry +{ + AutocompleteEntryKind kind = AutocompleteEntryKind::Property; + // Nullopt if kind is Keyword + std::optional type = std::nullopt; + bool deprecated = false; + // Only meaningful if kind is Property. + bool wrongIndexType = false; + // Set if this suggestion matches the type expected in the context + TypeCorrectKind typeCorrect = TypeCorrectKind::None; + + std::optional containingClass = std::nullopt; + std::optional prop = std::nullopt; + std::optional documentationSymbol = std::nullopt; + Tags tags; + ParenthesesRecommendation parens = ParenthesesRecommendation::None; + std::optional insertText; + + // Only meaningful if kind is Property. + bool indexedWithSelf = false; +}; + +using AutocompleteEntryMap = std::unordered_map; +struct AutocompleteResult +{ + AutocompleteEntryMap entryMap; + std::vector ancestry; + AutocompleteContext context = AutocompleteContext::Unknown; + + AutocompleteResult() = default; + AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) + : entryMap(std::move(entryMap)) + , ancestry(std::move(ancestry)) + , context(context) + { + } +}; + +using StringCompletionCallback = + std::function(std::string tag, std::optional ctx, std::optional contents)>; + +constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; + +} // namespace Luau diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index 435c62fb..b3b35fc2 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -5,6 +5,7 @@ #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" +#include "Luau/EqSatSimplification.h" #include "Luau/InsertionOrderedMap.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" @@ -15,7 +16,6 @@ #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Variant.h" -#include "Luau/Normalize.h" #include #include @@ -109,6 +109,9 @@ struct ConstraintGenerator // Needed to be able to enable error-suppression preservation for immediate refinements. NotNull normalizer; + + NotNull simplifier; + // Needed to register all available type functions for execution at later stages. NotNull typeFunctionRuntime; // Needed to resolve modules to make 'require' import types properly. @@ -128,6 +131,7 @@ struct ConstraintGenerator ConstraintGenerator( ModulePtr module, NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, @@ -405,6 +409,7 @@ private: TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); // make an intersect type function of these two types TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program); /** Scan the program for global definitions. * @@ -435,6 +440,8 @@ private: const ScopePtr& scope, Location location ); + + TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index c9336c1d..37042c75 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -5,6 +5,7 @@ #include "Luau/Constraint.h" #include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/Location.h" #include "Luau/Module.h" @@ -64,6 +65,7 @@ struct ConstraintSolver NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull simplifier; NotNull typeFunctionRuntime; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; @@ -117,6 +119,7 @@ struct ConstraintSolver explicit ConstraintSolver( NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, @@ -384,6 +387,10 @@ public: **/ void reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst); + TypeId simplifyIntersection(NotNull scope, Location location, TypeId left, TypeId right); + TypeId simplifyIntersection(NotNull scope, Location location, std::set parts); + TypeId simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right); + TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; diff --git a/Analysis/include/Luau/EqSatSimplification.h b/Analysis/include/Luau/EqSatSimplification.h new file mode 100644 index 00000000..16d00849 --- /dev/null +++ b/Analysis/include/Luau/EqSatSimplification.h @@ -0,0 +1,50 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/TypeFwd.h" +#include "Luau/NotNull.h" +#include "Luau/DenseHash.h" + +#include +#include +#include + +namespace Luau +{ +struct TypeArena; +} + +// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent +// the complexity from leaking outside its implementation sources. +namespace Luau::EqSatSimplification +{ + +struct Simplifier; + +using SimplifierPtr = std::unique_ptr; + +SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes); + +} // namespace Luau::EqSatSimplification + +namespace Luau +{ + +struct EqSatSimplificationResult +{ + TypeId result; + + // New type function applications that were created by the reduction phase. + // We return these so that the ConstraintSolver can know to try to reduce + // them. + std::vector newTypeFunctions; +}; + +using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect. +using Luau::EqSatSimplification::Simplifier; // NOLINT +using Luau::EqSatSimplification::SimplifierPtr; + +std::optional eqSatSimplify(NotNull simplifier, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/EqSatSimplificationImpl.h b/Analysis/include/Luau/EqSatSimplificationImpl.h new file mode 100644 index 00000000..24e8777a --- /dev/null +++ b/Analysis/include/Luau/EqSatSimplificationImpl.h @@ -0,0 +1,363 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" +#include "Luau/Lexer.h" // For Allocator +#include "Luau/NotNull.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ +struct TypeFunction; +} + +namespace Luau::EqSatSimplification +{ + +using StringId = uint32_t; +using Id = Luau::EqSat::Id; + +LUAU_EQSAT_UNIT(TNil); +LUAU_EQSAT_UNIT(TBoolean); +LUAU_EQSAT_UNIT(TNumber); +LUAU_EQSAT_UNIT(TString); +LUAU_EQSAT_UNIT(TThread); +LUAU_EQSAT_UNIT(TTopFunction); +LUAU_EQSAT_UNIT(TTopTable); +LUAU_EQSAT_UNIT(TTopClass); +LUAU_EQSAT_UNIT(TBuffer); + +// Used for any type that eqsat can't do anything interesting with. +LUAU_EQSAT_ATOM(TOpaque, TypeId); + +LUAU_EQSAT_ATOM(SBoolean, bool); +LUAU_EQSAT_ATOM(SString, StringId); + +LUAU_EQSAT_ATOM(TFunction, TypeId); + +LUAU_EQSAT_ATOM(TImportedTable, TypeId); + +LUAU_EQSAT_ATOM(TClass, TypeId); + +LUAU_EQSAT_UNIT(TAny); +LUAU_EQSAT_UNIT(TError); +LUAU_EQSAT_UNIT(TUnknown); +LUAU_EQSAT_UNIT(TNever); + +LUAU_EQSAT_NODE_SET(Union); +LUAU_EQSAT_NODE_SET(Intersection); + +LUAU_EQSAT_NODE_ARRAY(Negation, 1); + +LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, const TypeFunction*); + +LUAU_EQSAT_UNIT(TNoRefine); +LUAU_EQSAT_UNIT(Invalid); + +// enodes are immutable, but types are cyclic. We need a way to tie the knot. +// We handle this by generating TBound nodes at points where we encounter cycles. +// Each TBound has an ordinal that we later map onto the type. +// We use a substitution rule to replace all TBound nodes with their referrent. +LUAU_EQSAT_ATOM(TBound, size_t); + +// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it. +struct TTable +{ + explicit TTable(Id basis); + TTable(Id basis, std::vector propNames_, std::vector propTypes_); + + // All TTables extend some other table. This may be TTopTable. + // + // It will frequently be a TImportedTable, in which case we can reuse things + // like source location and documentation info. + Id getBasis() const; + EqSat::Slice propTypes() const; + // TODO: Also support read-only table props + // TODO: Indexer type, index result type. + + std::vector propNames; + + // The enode interface + EqSat::Slice mutableOperands(); + EqSat::Slice operands() const; + bool operator==(const TTable& rhs) const; + bool operator!=(const TTable& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const TTable& value) const; + }; + +private: + // The first element of this vector is the basis. Subsequent elements are + // property types. As we add other things like read-only properties and + // indexers, the structure of this array is likely to change. + // + // We encode our data in this way so that the operands() method can properly + // return a Slice. + std::vector storage; +}; + +using EType = EqSat::Language< + TNil, + TBoolean, + TNumber, + TString, + TThread, + TTopFunction, + TTopTable, + TTopClass, + TBuffer, + + TOpaque, + + SBoolean, + SString, + + TFunction, + TTable, + TImportedTable, + TClass, + + TAny, + TError, + TUnknown, + TNever, + + Union, + Intersection, + + Negation, + + TTypeFun, + + Invalid, + TNoRefine, + TBound>; + + +struct StringCache +{ + Allocator allocator; + DenseHashMap strings{{}}; + std::vector views; + + StringId add(std::string_view s); + std::string_view asStringView(StringId id) const; + std::string asString(StringId id) const; +}; + +using EGraph = Luau::EqSat::EGraph; + +struct Simplify +{ + using Data = bool; + + template + Data make(const EGraph&, const T&) const; + + void join(Data& left, const Data& right) const; +}; + +struct Subst +{ + Id eclass; + Id newClass; + + std::string desc; + + Subst(Id eclass, Id newClass, std::string desc = ""); +}; + +struct Simplifier +{ + NotNull arena; + NotNull builtinTypes; + EGraph egraph; + StringCache stringCache; + + // enodes are immutable but types can be cyclic, so we need some way to + // encode the cycle. This map is used to connect TBound nodes to the right + // eclass. + // + // The cyclicIntersection rewrite rule uses this to sense when a cycle can + // be deleted from an intersection or union. + std::unordered_map mappingIdToClass; + + std::vector substs; + + using RewriteRuleFn = void (Simplifier::*)(Id id); + + Simplifier(NotNull arena, NotNull builtinTypes); + + // Utilities + const EqSat::EClass& get(Id id) const; + Id find(Id id) const; + Id add(EType enode); + + template + const Tag* isTag(Id id) const; + + template + const Tag* isTag(const EType& enode) const; + + void subst(Id from, Id to); + void subst(Id from, Id to, const std::string& ruleName); + void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); + + void unionClasses(std::vector& hereParts, Id there); + + // Rewrite rules + void simplifyUnion(Id id); + void uninhabitedIntersection(Id id); + void intersectWithNegatedClass(Id id); + void intersectWithNoRefine(Id id); + void cyclicIntersectionOfUnion(Id id); + void cyclicUnionOfIntersection(Id id); + void expandNegation(Id id); + void intersectionOfUnion(Id id); + void intersectTableProperty(Id id); + void uninhabitedTable(Id id); + void unneededTableModification(Id id); + void builtinTypeFunctions(Id id); + void iffyTypeFunctions(Id id); +}; + +template +struct QueryIterator +{ + QueryIterator(); + QueryIterator(EGraph* egraph, Id eclass); + + bool operator==(const QueryIterator& other) const; + bool operator!=(const QueryIterator& other) const; + + std::pair operator*() const; + + QueryIterator& operator++(); + QueryIterator& operator++(int); + +private: + EGraph* egraph = nullptr; + Id eclass; + size_t index = 0; +}; + +template +struct Query +{ + EGraph* egraph; + Id eclass; + + Query(EGraph* egraph, Id eclass) + : egraph(egraph) + , eclass(eclass) + { + } + + QueryIterator begin() + { + return QueryIterator{egraph, eclass}; + } + + QueryIterator end() + { + return QueryIterator{}; + } +}; + +template +QueryIterator::QueryIterator() + : egraph(nullptr) + , eclass(Id{0}) + , index(0) +{ +} + +template +QueryIterator::QueryIterator(EGraph* egraph_, Id eclass) + : egraph(egraph_) + , eclass(eclass) + , index(0) +{ + const auto& ecl = (*egraph)[eclass]; + + static constexpr const int idx = EType::VariantTy::getTypeId(); + + for (const auto& enode : ecl.nodes) + { + if (enode.index() < idx) + ++index; + else + break; + } + + if (index >= ecl.nodes.size() || ecl.nodes[index].index() != idx) + { + egraph = nullptr; + index = 0; + } +} + +template +bool QueryIterator::operator==(const QueryIterator& rhs) const +{ + if (egraph == nullptr && rhs.egraph == nullptr) + return true; + + return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index; +} + +template +bool QueryIterator::operator!=(const QueryIterator& rhs) const +{ + return !(*this == rhs); +} + +template +std::pair QueryIterator::operator*() const +{ + LUAU_ASSERT(egraph != nullptr); + + EGraph::EClassT& ecl = (*egraph)[eclass]; + + LUAU_ASSERT(index < ecl.nodes.size()); + auto& enode = ecl.nodes[index]; + Tag* result = enode.template get(); + LUAU_ASSERT(result); + return {result, index}; +} + +// pre-increment +template +QueryIterator& QueryIterator::operator++() +{ + const auto& ecl = (*egraph)[eclass]; + + ++index; + if (index >= ecl.nodes.size() || ecl.nodes[index].index() != EType::VariantTy::getTypeId()) + { + egraph = nullptr; + index = 0; + } + + return *this; +} + +// post-increment +template +QueryIterator& QueryIterator::operator++(int) +{ + QueryIterator res = *this; + ++res; + return res; +} + +} // namespace Luau::EqSatSimplification diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 2f17e566..d3fc6ad3 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -32,7 +32,11 @@ struct ModuleInfo bool optional = false; }; -using RequireSuggestion = std::string; +struct RequireSuggestion +{ + std::string label; + std::string fullPath; +}; using RequireSuggestions = std::vector; struct FileResolver diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h index 671cbb69..50c456f1 100644 --- a/Analysis/include/Luau/FragmentAutocomplete.h +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -3,9 +3,10 @@ #include "Luau/Ast.h" #include "Luau/Parser.h" -#include "Luau/Autocomplete.h" +#include "Luau/AutocompleteTypes.h" #include "Luau/DenseHash.h" #include "Luau/Module.h" +#include "Luau/Frontend.h" #include #include @@ -27,13 +28,23 @@ struct FragmentParseResult std::string fragmentToParse; AstStatBlock* root = nullptr; std::vector ancestry; + AstStat* nearestStatement = nullptr; std::unique_ptr alloc = std::make_unique(); }; struct FragmentTypeCheckResult { ModulePtr incrementalModule = nullptr; - Scope* freshScope = nullptr; + ScopePtr freshScope; + std::vector ancestry; +}; + +struct FragmentAutocompleteResult +{ + ModulePtr incrementalModule; + Scope* freshScope; + TypeArena arenaForAutocomplete; + AutocompleteResult acResults; }; FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); @@ -48,11 +59,11 @@ FragmentTypeCheckResult typecheckFragment( std::string_view src ); -AutocompleteResult fragmentAutocomplete( +FragmentAutocompleteResult fragmentAutocomplete( Frontend& frontend, std::string_view src, const ModuleName& moduleName, - Position& cursorPosition, + Position cursorPosition, std::optional opts, StringCompletionCallback callback ); diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index f8001e08..4862e3b4 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -44,6 +44,7 @@ struct ToStringOptions bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self + bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index d100fa4d..0005605e 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -31,6 +31,7 @@ namespace Luau struct TypeArena; struct Scope; using ScopePtr = std::shared_ptr; +struct Module; struct TypeFunction; struct Constraint; @@ -598,6 +599,18 @@ struct ClassType } }; +// Data required to initialize a user-defined function and its environment +struct UserDefinedFunctionData +{ + // Store a weak module reference to ensure the lifetime requirements are preserved + std::weak_ptr owner; + + // References to AST elements are owned by the Module allocator which also stores this type + AstStatTypeFunction* definition = nullptr; + + DenseHashMap environment{""}; +}; + /** * An instance of a type function that has not yet been reduced to a more concrete * type. The constraint solver receives a constraint to reduce each @@ -613,17 +626,20 @@ struct TypeFunctionInstanceType std::vector packArguments; std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + UserDefinedFunctionData userFuncData; TypeFunctionInstanceType( NotNull function, std::vector typeArguments, std::vector packArguments, - std::optional userFuncName = std::nullopt + std::optional userFuncName, + UserDefinedFunctionData userFuncData ) : function(function) , typeArguments(typeArguments) , packArguments(packArguments) , userFuncName(userFuncName) + , userFuncData(userFuncData) { } @@ -640,6 +656,13 @@ struct TypeFunctionInstanceType , packArguments(packArguments) { } + + TypeFunctionInstanceType(NotNull function, std::vector typeArguments, std::vector packArguments) + : function{function} + , typeArguments(typeArguments) + , packArguments(packArguments) + { + } }; /** Represents a pending type alias instantiation. diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 829f6bb7..eb7e2298 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -2,2000 +2,17 @@ #include "Luau/Autocomplete.h" #include "Luau/AstQuery.h" -#include "Luau/BuiltinDefinitions.h" -#include "Luau/Common.h" -#include "Luau/FileResolver.h" +#include "Luau/TypeArena.h" +#include "Luau/Module.h" #include "Luau/Frontend.h" -#include "Luau/ToString.h" -#include "Luau/Subtyping.h" -#include "Luau/TypeInfer.h" -#include "Luau/TypePack.h" -#include -#include -#include +#include "AutocompleteCore.h" LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions) - -LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) -LUAU_FASTINT(LuauTypeInferIterationLimit) -LUAU_FASTINT(LuauTypeInferRecursionLimit) - -static const std::unordered_set kStatementStartingKeywords = - {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; namespace Luau { - -static bool alreadyHasParens(const std::vector& nodes) -{ - auto iter = nodes.rbegin(); - while (iter != nodes.rend() && - ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) - { - iter++; - } - - if (iter == nodes.rend() || iter == nodes.rbegin()) - { - return false; - } - - if (AstExprCall* call = (*iter)->as()) - { - return call->func == *(iter - 1); - } - - return false; -} - -static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) -{ - if (alreadyHasParens(nodes)) - { - return ParenthesesRecommendation::None; - } - - auto idxExpr = nodes.back()->as(); - bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); - - if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) - return ParenthesesRecommendation::CursorInside; - - bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); - return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; -} - -static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) -{ - ParenthesesRecommendation rec = ParenthesesRecommendation::None; - for (Luau::TypeId partId : intersect->parts) - { - if (auto partFunc = Luau::get(partId)) - { - rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); - } - else - { - return ParenthesesRecommendation::None; - } - } - return rec; -} - -static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) -{ - // If element is already type-correct, even a function should be inserted without parenthesis - if (typeCorrect == TypeCorrectKind::Correct) - return ParenthesesRecommendation::None; - - id = Luau::follow(id); - if (auto func = get(id)) - { - return getParenRecommendationForFunc(func, nodes); - } - else if (auto intersect = get(id)) - { - return getParenRecommendationForIntersect(intersect, nodes); - } - return ParenthesesRecommendation::None; -} - -static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) -{ - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; - - // Extra care for first function call argument location - // When we don't have anything inside () yet, we also don't have an AST node to base our lookup - if (AstExprCall* exprCall = expr->as()) - { - if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) - { - auto it = module.astTypes.find(exprCall->func); - - if (!it) - return std::nullopt; - - const FunctionType* ftv = get(follow(*it)); - - if (!ftv) - return std::nullopt; - - auto [head, tail] = flatten(ftv->argTypes); - unsigned index = exprCall->self ? 1 : 0; - - if (index < head.size()) - return head[index]; - - return std::nullopt; - } - } - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; - - return *it; -} - -static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull builtinTypes) -{ - InternalErrorReporter iceReporter; - UnifierSharedState unifierState(&iceReporter); - Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; - - if (FFlag::LuauSolverV2) - { - TypeCheckLimits limits; - TypeFunctionRuntime typeFunctionRuntime{ - NotNull{&iceReporter}, NotNull{&limits} - }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime - - unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; - unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; - - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; - - return subtyping.isSubtype(subTy, superTy, scope).isSubtype; - } - else - { - Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); - - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; - - return unifier.canUnify(subTy, superTy).empty(); - } -} - -static TypeCorrectKind checkTypeCorrectKind( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - AstNode* node, - Position position, - TypeId ty -) -{ - ty = follow(ty); - - LUAU_ASSERT(module.hasModuleScope()); - - NotNull moduleScope{module.getModuleScope().get()}; - - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return TypeCorrectKind::None; - - TypeId expectedType = follow(*typeAtPosition); - - auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) - { - if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); - - return false; - }; - - // We also want to suggest functions that return compatible result - if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - else if (const IntersectionType* itv = get(ty)) - { - for (TypeId id : itv->parts) - { - id = follow(id); - - if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - } - } - - return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; -} - -enum class PropIndexType -{ - Point, - Colon, - Key, -}; - -static void autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId rootTy, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes, - AutocompleteEntryMap& result, - std::unordered_set& seen, - std::optional containingClass = std::nullopt -) -{ - rootTy = follow(rootTy); - ty = follow(ty); - - if (seen.count(ty)) - return; - seen.insert(ty); - - auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) - { - if (indexType == PropIndexType::Key) - return false; - - bool calledWithSelf = indexType == PropIndexType::Colon; - - auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) - { - // Strong match with definition is a success - if (calledWithSelf == ftv->hasSelf) - return true; - - // Calls on classes require strict match between how function is declared and how it's called - if (get(rootTy)) - return false; - - // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all - // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible - if (std::optional firstArgTy = first(ftv->argTypes)) - { - if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) - return calledWithSelf; - } - - return !calledWithSelf; - }; - - if (const FunctionType* ftv = get(type)) - return !isCompatibleCall(ftv); - - // For intersections, any part that is successful makes the whole call successful - if (const IntersectionType* itv = get(type)) - { - for (auto subType : itv->parts) - { - if (const FunctionType* ftv = get(Luau::follow(subType))) - { - if (isCompatibleCall(ftv)) - return false; - } - } - } - - return calledWithSelf; - }; - - auto fillProps = [&](const ClassType::Props& props) - { - for (const auto& [name, prop] : props) - { - // We are walking up the class hierarchy, so if we encounter a property that we have - // already populated, it takes precedence over the property we found just now. - if (result.count(name) == 0 && name != kParseNameError) - { - Luau::TypeId type; - - if (FFlag::LuauSolverV2) - { - if (auto ty = prop.readTy) - type = follow(*ty); - else - continue; - } - else - type = follow(prop.type()); - - TypeCorrectKind typeCorrect = indexType == PropIndexType::Key - ? TypeCorrectKind::Correct - : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); - - ParenthesesRecommendation parens = - indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - {}, - indexType == PropIndexType::Colon - }; - } - } - }; - - auto fillMetatableProps = [&](const TableType* mtable) - { - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) - { - TypeId followed = follow(indexIt->second.type()); - if (get(followed) || get(followed)) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); - } - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retTypes); - if (indexFunctionResult) - autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); - } - } - }; - - if (auto cls = get(ty)) - { - containingClass = containingClass.value_or(cls); - fillProps(cls->props); - if (cls->parent) - autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); - } - else if (auto tbl = get(ty)) - fillProps(tbl->props); - else if (auto mt = get(ty)) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); - - if (auto mtable = get(follow(mt->metatable))) - fillMetatableProps(mtable); - } - else if (auto i = get(ty)) - { - // Complete all properties in every variant - for (TypeId ty : i->parts) - { - AutocompleteEntryMap inner; - std::unordered_set innerSeen = seen; - - autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); - - for (auto& pair : inner) - result.insert(pair); - } - } - else if (auto u = get(ty)) - { - // Complete all properties common to all variants - auto iter = begin(u); - auto endIter = end(u); - - while (iter != endIter) - { - if (isNil(*iter)) - ++iter; - else - break; - } - - if (iter == endIter) - return; - - autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); - - ++iter; - - while (iter != endIter) - { - AutocompleteEntryMap inner; - std::unordered_set innerSeen; - - if (isNil(*iter)) - { - ++iter; - continue; - } - - autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); - - std::unordered_set toRemove; - - for (const auto& [k, v] : result) - { - (void)v; - if (!inner.count(k)) - toRemove.insert(k); - } - - for (const std::string& k : toRemove) - result.erase(k); - - ++iter; - } - } - else if (auto pt = get(ty)) - { - if (pt->metatable) - { - if (auto mtable = get(*pt->metatable)) - fillMetatableProps(mtable); - } - } - else if (get(get(ty))) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); - } -} - -static void autocompleteKeywords( - const SourceModule& sourceModule, - const std::vector& ancestry, - Position position, - AutocompleteEntryMap& result -) -{ - LUAU_ASSERT(!ancestry.empty()); - - AstNode* node = ancestry.back(); - - if (!node->is() && node->asExpr()) - { - // This is not strictly correct. We should recommend `and` and `or` only after - // another expression, not at the start of a new one. We should only recommend - // `not` at the start of an expression. Detecting either case reliably is quite - // complex, however; this is good enough for now. - - // These are not context-sensitive keywords, so we can unconditionally assign. - result["and"] = {AutocompleteEntryKind::Keyword}; - result["or"] = {AutocompleteEntryKind::Keyword}; - result["not"] = {AutocompleteEntryKind::Keyword}; - } -} - -static void autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes, - AutocompleteEntryMap& result -) -{ - std::unordered_set seen; - autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); -} - -AutocompleteEntryMap autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes -) -{ - AutocompleteEntryMap result; - autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); - return result; -} - -AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position position, std::string_view moduleName) -{ - AutocompleteEntryMap result; - - for (ScopePtr scope = findScopeAtPosition(module, position); scope; scope = scope->parent) - { - if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) - { - for (const auto& [name, ty] : it->second) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; - - break; - } - } - - return result; -} - -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) -{ - if (position == node->location.begin || position == node->location.end) - { - if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) - return; - else if (node->is()) - return; - } - - auto formatKey = [addQuotes](const std::string& key) - { - if (addQuotes) - return "\"" + escape(key) + "\""; - - return escape(key); - }; - - ty = follow(ty); - - if (auto ss = get(get(ty))) - { - result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; - } - else if (auto uty = get(ty)) - { - for (auto el : uty) - { - if (auto ss = get(get(el))) - result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; - } - } -}; - -static bool canSuggestInferredType(ScopePtr scope, TypeId ty) -{ - ty = follow(ty); - - // No point in suggesting 'any', invalid to suggest others - if (get(ty) || get(ty) || get(ty) || get(ty)) - return false; - - // No syntax for unnamed tables with a metatable - if (get(ty)) - return false; - - if (const TableType* ttv = get(ty)) - { - if (ttv->name) - return true; - - if (ttv->syntheticName) - return false; - } - - // We might still have a type with cycles or one that is too long, we'll check that later - return true; -} - -// Walk complex type trees to find the element that is being edited -static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); - -static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) -{ - for (size_t i = 0; i < astTypeList.types.size; i++) - { - AstType* type = astTypeList.types.data[i]; - - if (type->location.containsClosed(position)) - { - auto [head, _] = flatten(tp); - - if (i < head.size()) - return findTypeElementAt(type, head[i], position); - } - } - - if (AstTypePack* argTp = astTypeList.tailType) - { - if (auto variadic = argTp->as()) - { - if (variadic->location.containsClosed(position)) - { - auto [_, tail] = flatten(tp); - - if (tail) - { - if (const VariadicTypePack* vtp = get(follow(*tail))) - return findTypeElementAt(variadic->variadicType, vtp->ty, position); - } - } - } - } - - return {}; -} - -static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) -{ - ty = follow(ty); - - if (astType->is()) - return ty; - - if (astType->is()) - return ty; - - if (AstTypeFunction* type = astType->as()) - { - const FunctionType* ftv = get(ty); - - if (!ftv) - return {}; - - if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) - return element; - - if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) - return element; - } - - // It's possible to walk through other types like intrsection and unions if we find value in doing that - return {}; -} - -std::optional getLocalTypeInScopeAt(const Module& module, Position position, AstLocal* local) -{ - if (ScopePtr scope = findScopeAtPosition(module, position)) - { - for (const auto& [name, binding] : scope->bindings) - { - if (name == local) - return binding.typeId; - } - } - - return {}; -} - -template -static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) -{ - ToStringOptions opts; - opts.useLineBreaks = false; - opts.hideTableKind = true; - opts.functionTypeArguments = functionTypeArguments; - opts.scope = scope; - ToStringResult name = toStringDetailed(ty, opts); - - if (name.error || name.invalid || name.cycle || name.truncated) - return std::nullopt; - - return name.name; -} - -static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) -{ - if (!canSuggestInferredType(scope, ty)) - return std::nullopt; - - return tryToStringDetailed(scope, ty, functionTypeArguments); -} - -static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) -{ - std::optional ty; - - if (topType) - ty = findTypeElementAt(topType, inferredType, position); - else - ty = inferredType; - - if (!ty) - return false; - - if (auto name = tryGetTypeNameInScope(scope, *ty)) - { - if (auto it = result.find(*name); it != result.end()) - it->second.typeCorrect = TypeCorrectKind::Correct; - else - result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; - - return true; - } - - return false; -} - -static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) -{ - auto [tpHead, tpTail] = flatten(tp); - - if (index < tpHead.size()) - return tpHead[index]; - - // Infinite tail - if (tpTail) - { - if (const VariadicTypePack* vtp = get(follow(*tpTail))) - return vtp->ty; - } - - return {}; -} - -template -std::optional returnFirstNonnullOptionOfType(const UnionType* utv) -{ - std::optional ret; - for (TypeId subTy : utv) - { - if (isNil(subTy)) - continue; - - if (const T* ftv = get(follow(subTy))) - { - if (ret.has_value()) - { - return std::nullopt; - } - ret = ftv; - } - else - { - return std::nullopt; - } - } - return ret; -} - -static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) -{ - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return std::nullopt; - - TypeId expectedType = follow(*typeAtPosition); - - if (get(expectedType)) - return true; - - if (const IntersectionType* itv = get(expectedType)) - { - return std::all_of( - begin(itv->parts), - end(itv->parts), - [](auto&& ty) - { - return get(Luau::follow(ty)) != nullptr; - } - ); - } - - if (const UnionType* utv = get(expectedType)) - return returnFirstNonnullOptionOfType(utv).has_value(); - - return false; -} - -AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position position, const std::vector& ancestry) -{ - AutocompleteEntryMap result; - - ScopePtr startScope = findScopeAtPosition(module, position); - - for (ScopePtr scope = startScope; scope; scope = scope->parent) - { - for (const auto& [name, ty] : scope->exportedTypeBindings) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Type, - ty.type, - false, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - ty.type->documentationSymbol - }; - } - - for (const auto& [name, ty] : scope->privateTypeBindings) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Type, - ty.type, - false, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - ty.type->documentationSymbol - }; - } - - for (const auto& [name, _] : scope->importedTypeBindings) - { - if (auto binding = scope->linearSearchForBinding(name, true)) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; - } - } - } - - AstNode* parent = nullptr; - AstType* topType = nullptr; // TODO: rename? - - for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) - { - if (AstType* asType = (*it)->asType()) - { - topType = asType; - } - else - { - parent = *it; - break; - } - } - - if (!parent) - return result; - - if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local - { - // Look at which of the variable types we are defining - for (size_t i = 0; i < node->vars.size; i++) - { - AstLocal* var = node->vars.data[i]; - - if (var->annotation && var->annotation->location.containsClosed(position)) - { - if (node->values.size == 0) - break; - - unsigned tailPos = 0; - - // For multiple return values we will try to unpack last function call return type pack - if (i >= node->values.size) - { - tailPos = int(i) - int(node->values.size) + 1; - i = int(node->values.size) - 1; - } - - AstExpr* expr = node->values.data[i]->asExpr(); - - if (!expr) - break; - - TypeId inferredType = nullptr; - - if (AstExprCall* exprCall = expr->as()) - { - if (auto it = module.astTypes.find(exprCall->func)) - { - if (const FunctionType* ftv = get(follow(*it))) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) - inferredType = *ty; - } - } - } - else - { - if (tailPos != 0) - break; - - if (auto it = module.astTypes.find(expr)) - inferredType = *it; - } - - if (inferredType) - tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); - - break; - } - } - } - else if (AstExprFunction* node = parent->as()) - { - // For lookup inside expected function type if that's available - auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* - { - auto it = module.astExpectedTypes.find(expr); - - if (!it) - return nullptr; - - TypeId ty = follow(*it); - - if (const FunctionType* ftv = get(ty)) - return ftv; - - // Handle optional function type - if (const UnionType* utv = get(ty)) - { - return returnFirstNonnullOptionOfType(utv).value_or(nullptr); - } - - return nullptr; - }; - - // Find which argument type we are defining - for (size_t i = 0; i < node->args.size; i++) - { - AstLocal* arg = node->args.data[i]; - - if (arg->annotation && arg->annotation->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - // Otherwise, try to use the type inferred by typechecker - else if (auto inferredType = getLocalTypeInScopeAt(module, position, arg)) - { - tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); - } - - break; - } - } - - if (AstTypePack* argTp = node->varargAnnotation) - { - if (auto variadic = argTp->as()) - { - if (variadic->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - } - } - } - - if (!node->returnAnnotation) - return result; - - for (size_t i = 0; i < node->returnAnnotation->types.size; i++) - { - AstType* ret = node->returnAnnotation->types.data[i]; - - if (ret->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - - // TODO: with additional type information, we could suggest inferred return type here - break; - } - } - - if (AstTypePack* retTp = node->returnAnnotation->tailType) - { - if (auto variadic = retTp->as()) - { - if (variadic->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - } - } - } - } - - return result; -} - -static bool isInLocalNames(const std::vector& ancestry, Position position) -{ - for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) - { - if (auto statLocal = (*iter)->as()) - { - for (auto var : statLocal->vars) - { - if (var->location.containsClosed(position)) - { - return true; - } - } - } - else if (auto funcExpr = (*iter)->as()) - { - if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) - { - return true; - } - } - else if (auto localFunc = (*iter)->as()) - { - return localFunc->name->location.containsClosed(position); - } - else if (auto block = (*iter)->as()) - { - if (block->body.size > 0) - { - return false; - } - } - else if ((*iter)->asStat()) - { - return false; - } - } - return false; -} - -static bool isIdentifier(AstNode* node) -{ - return node->is() || node->is(); -} - -static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) -{ - // Current set of rules only check for local binding match - if (!symbol.local) - return false; - - for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) - { - if (auto statLocal = (*iter)->as()) - { - for (auto var : statLocal->vars) - { - if (symbol.local == var) - return true; - } - } - } - - return false; -} - -template -T* extractStat(const std::vector& ancestry) -{ - AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; - if (!node) - return nullptr; - - if (T* t = node->as()) - return t; - - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; - if (!parent) - return nullptr; - - AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; - AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - - if (!grandParent) - return nullptr; - - if (T* t = parent->as(); t && grandParent->is()) - return t; - - if (!greatGrandParent) - return nullptr; - - if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) - return t; - - return nullptr; -} - -static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) -{ - if (symbol.local) - return binding.location.end < pos; - - // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it - return binding.location == Location() || !binding.location.containsClosed(pos); -} - -static AutocompleteEntryMap autocompleteStatement( - const SourceModule& sourceModule, - const Module& module, - const std::vector& ancestry, - Position position -) -{ - // This is inefficient. :( - ScopePtr scope = findScopeAtPosition(module, position); - - AutocompleteEntryMap result; - - if (isInLocalNames(ancestry, position)) - { - autocompleteKeywords(sourceModule, ancestry, position, result); - return result; - } - - while (scope) - { - for (const auto& [name, binding] : scope->bindings) - { - if (!isBindingLegalAtCurrentPosition(name, binding, position)) - continue; - - std::string n = toString(name); - if (!result.count(n)) - result[n] = { - AutocompleteEntryKind::Binding, - binding.typeId, - binding.deprecated, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - binding.documentationSymbol, - {}, - getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) - }; - } - - scope = scope->parent; - } - - for (const auto& kw : kStatementStartingKeywords) - result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); - - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) - { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as()) - { - bool hasEnd = statIf->thenbody->hasEnd; - if (statIf->elsebody) - { - if (AstStatBlock* elseBlock = statIf->elsebody->as()) - hasEnd = elseBlock->hasEnd; - } - - if (!hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - - if (ancestry.size() >= 2) - { - AstNode* parent = ancestry.rbegin()[1]; - if (AstStatIf* statIf = parent->as()) - { - if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) - { - result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - - if (ancestry.size() >= 4) - { - auto iter = ancestry.rbegin(); - if (AstStatIf* statIf = iter[3]->as(); - statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) - { - result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - - return result; -} - -// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) -static bool autocompleteIfElseExpression( - const AstNode* node, - const std::vector& ancestry, - const Position& position, - AutocompleteEntryMap& outResult -) -{ - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; - if (!parent) - return false; - - if (node->is()) - { - // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else - // expression. - return true; - } - - AstExprIfElse* ifElseExpr = parent->as(); - if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) - { - return false; - } - else if (!ifElseExpr->hasThen) - { - outResult["then"] = {AutocompleteEntryKind::Keyword}; - return true; - } - else if (ifElseExpr->trueExpr->location.containsClosed(position)) - { - return false; - } - else if (!ifElseExpr->hasElse) - { - outResult["else"] = {AutocompleteEntryKind::Keyword}; - outResult["elseif"] = {AutocompleteEntryKind::Keyword}; - return true; - } - else - { - return false; - } -} - -static AutocompleteContext autocompleteExpression( - const SourceModule& sourceModule, - const Module& module, - NotNull builtinTypes, - TypeArena* typeArena, - const std::vector& ancestry, - Position position, - AutocompleteEntryMap& result -) -{ - LUAU_ASSERT(!ancestry.empty()); - - AstNode* node = ancestry.rbegin()[0]; - - if (node->is()) - { - if (auto it = module.astTypes.find(node->asExpr())) - autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); - } - else if (autocompleteIfElseExpression(node, ancestry, position, result)) - return AutocompleteContext::Keyword; - else if (node->is()) - return AutocompleteContext::Unknown; - else - { - // This is inefficient. :( - ScopePtr scope = findScopeAtPosition(module, position); - - while (scope) - { - for (const auto& [name, binding] : scope->bindings) - { - if (!isBindingLegalAtCurrentPosition(name, binding, position)) - continue; - - if (isBeingDefined(ancestry, name)) - continue; - - std::string n = toString(name); - if (!result.count(n)) - { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); - - result[n] = { - AutocompleteEntryKind::Binding, - binding.typeId, - binding.deprecated, - false, - typeCorrect, - std::nullopt, - std::nullopt, - binding.documentationSymbol, - {}, - getParenRecommendation(binding.typeId, ancestry, typeCorrect) - }; - } - } - - scope = scope->parent; - } - - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - - if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, node, position, result); - } - - return AutocompleteContext::Expression; -} - -static AutocompleteResult autocompleteExpression( - const SourceModule& sourceModule, - const Module& module, - NotNull builtinTypes, - TypeArena* typeArena, - const std::vector& ancestry, - Position position -) -{ - AutocompleteEntryMap result; - AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); - return {result, ancestry, context}; -} - -static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) -{ - AstExpr* parentExpr = nullptr; - if (auto indexName = funcExpr->as()) - { - parentExpr = indexName->expr; - } - else if (auto indexExpr = funcExpr->as()) - { - parentExpr = indexExpr->expr; - } - else - { - return std::nullopt; - } - - auto parentIt = module->astTypes.find(parentExpr); - if (!parentIt) - { - return std::nullopt; - } - - Luau::TypeId parentType = Luau::follow(*parentIt); - - if (auto parentClass = Luau::get(parentType)) - { - return parentClass; - } - - if (auto parentUnion = Luau::get(parentType)) - { - return returnFirstNonnullOptionOfType(parentUnion); - } - - return std::nullopt; -} - -static bool stringPartOfInterpString(const AstNode* node, Position position) -{ - const AstExprInterpString* interpString = node->as(); - if (!interpString) - { - return false; - } - - for (const AstExpr* expression : interpString->expressions) - { - if (expression->location.containsClosed(position)) - { - return false; - } - } - - return true; -} - -static bool isSimpleInterpolatedString(const AstNode* node) -{ - const AstExprInterpString* interpString = node->as(); - return interpString != nullptr && interpString->expressions.size == 0; -} - -static std::optional getStringContents(const AstNode* node) -{ - if (const AstExprConstantString* string = node->as()) - { - return std::string(string->value.data, string->value.size); - } - else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) - { - LUAU_ASSERT(interpString->strings.size == 1); - return std::string(interpString->strings.data->data, interpString->strings.data->size); - } - else - { - return std::nullopt; - } -} - -static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) -{ - if (!suggestions) - return std::nullopt; - - AutocompleteEntryMap result; - for (const RequireSuggestion& suggestion : *suggestions) - { - result[suggestion] = {AutocompleteEntryKind::RequirePath}; - } - return result; -} - -static std::optional autocompleteStringParams( - const SourceModule& sourceModule, - const ModulePtr& module, - const std::vector& nodes, - Position position, - FileResolver* fileResolver, - StringCompletionCallback callback -) -{ - if (nodes.size() < 2) - { - return std::nullopt; - } - - if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) - { - return std::nullopt; - } - - if (!nodes.back()->is()) - { - if (nodes.back()->location.end == position || nodes.back()->location.begin == position) - { - return std::nullopt; - } - } - - AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); - if (!candidate) - { - return std::nullopt; - } - - // HACK: All current instances of 'magic string' params are the first parameter of their functions, - // so we encode that here rather than putting a useless member on the FunctionType struct. - if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) - { - return std::nullopt; - } - - auto it = module->astTypes.find(candidate->func); - if (!it) - { - return std::nullopt; - } - - std::optional candidateString = getStringContents(nodes.back()); - - auto performCallback = [&](const FunctionType* funcType) -> std::optional - { - for (const std::string& tag : funcType->tags) - { - if (FFlag::AutocompleteRequirePathSuggestions) - { - if (tag == kRequireTagName && fileResolver) - { - return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); - } - } - if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) - { - return ret; - } - } - return std::nullopt; - }; - - auto followedId = Luau::follow(*it); - if (auto functionType = Luau::get(followedId)) - { - return performCallback(functionType); - } - - if (auto intersect = Luau::get(followedId)) - { - for (TypeId part : intersect->parts) - { - if (auto candidateFunctionType = Luau::get(part)) - { - if (std::optional ret = performCallback(candidateFunctionType)) - { - return ret; - } - } - } - } - - return std::nullopt; -} - -static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) -{ - AutocompleteEntryMap ret; - ret["do"] = {AutocompleteEntryKind::Keyword}; - ret["and"] = {AutocompleteEntryKind::Keyword}; - ret["or"] = {AutocompleteEntryKind::Keyword}; - return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; -} - -static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) -{ - std::string result = "function("; - - auto [args, tail] = Luau::flatten(funcTy.argTypes); - - bool first = true; - // Skip the implicit 'self' argument if call is indexed with ':' - for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) - { - if (!first) - result += ", "; - else - first = false; - - std::string name; - if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) - name = funcTy.argNames[argIdx]->name; - else - name = "a" + std::to_string(argIdx); - - if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) - result += name + ": " + *type; - else - result += name; - } - - if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) - { - if (!first) - result += ", "; - - std::optional varArgType; - if (const VariadicTypePack* pack = get(follow(*tail))) - { - if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) - varArgType = std::move(res); - } - - if (varArgType) - result += "...: " + *varArgType; - else - result += "..."; - } - - result += ")"; - - auto [rets, retTail] = Luau::flatten(funcTy.retTypes); - if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) - { - if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) - { - result += ": "; - bool wrap = totalRetSize != 1; - if (wrap) - result += "("; - result += *returnTypes; - if (wrap) - result += ")"; - } - } - result += " end"; - return result; -} - -static std::optional makeAnonymousAutofilled( - const ModulePtr& module, - Position position, - const AstNode* node, - const std::vector& ancestry -) -{ - const AstExprCall* call = node->as(); - if (!call && ancestry.size() > 1) - call = ancestry[ancestry.size() - 2]->as(); - - if (!call) - return std::nullopt; - - if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) - return std::nullopt; - - TypeId* typeIter = module->astTypes.find(call->func); - if (!typeIter) - return std::nullopt; - - const FunctionType* outerFunction = get(follow(*typeIter)); - if (!outerFunction) - return std::nullopt; - - size_t argument = 0; - for (size_t i = 0; i < call->args.size; ++i) - { - if (call->args.data[i]->location.containsClosed(position)) - { - argument = i; - break; - } - } - - if (call->self) - argument++; - - std::optional argType; - auto [args, tail] = flatten(outerFunction->argTypes); - if (argument < args.size()) - argType = args[argument]; - - if (!argType) - return std::nullopt; - - TypeId followed = follow(*argType); - const FunctionType* type = get(followed); - if (!type) - { - if (const UnionType* unionType = get(followed)) - { - if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) - type = *nonnullFunction; - } - } - - if (!type) - return std::nullopt; - - const ScopePtr scope = findScopeAtPosition(*module, position); - if (!scope) - return std::nullopt; - - AutocompleteEntry entry; - entry.kind = AutocompleteEntryKind::GeneratedFunction; - entry.typeCorrect = TypeCorrectKind::Correct; - entry.type = argType; - entry.insertText = makeAnonymous(scope, *type); - return std::make_optional(std::move(entry)); -} - -static AutocompleteResult autocomplete( - const SourceModule& sourceModule, - const ModulePtr& module, - NotNull builtinTypes, - TypeArena* typeArena, - Scope* globalScope, - Position position, - FileResolver* fileResolver, - StringCompletionCallback callback -) -{ - if (isWithinComment(sourceModule, position)) - return {}; - - std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); - LUAU_ASSERT(!ancestry.empty()); - AstNode* node = ancestry.back(); - - AstExprConstantNil dummy{Location{}}; - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; - - // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node - if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) - { - ancestry.pop_back(); - - node = ancestry.back(); - parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; - } - - if (auto indexName = node->as()) - { - auto it = module->astTypes.find(indexName->expr); - if (!it) - return {}; - - TypeId ty = follow(*it); - PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - - return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; - } - else if (auto typeReference = node->as()) - { - if (typeReference->prefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; - else - return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; - } - else if (node->is()) - { - return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; - } - else if (AstStatLocal* statLocal = node->as()) - { - if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; - else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else - return {}; - } - - else if (AstStatFor* statFor = extractStat(ancestry)) - { - if (!statFor->hasDo || position < statFor->doLocation.begin) - { - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - return {}; - } - - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) - { - if (!statForIn->hasIn || position <= statForIn->inLocation.begin) - { - AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; - if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) - { - // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or - // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer - // any suggestions. - return {}; - } - - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - - if (!statForIn->hasDo || position <= statForIn->doLocation.begin) - { - LUAU_ASSERT(statForIn->values.size > 0); - AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; - - if (lastExpr->location.containsClosed(position)) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - return {}; // Not sure what this means - } - } - else if (AstStatForIn* statForIn = extractStat(ancestry)) - { - // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. - // ex "for f in f do" - if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) - { - if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - { - return autocompleteWhileLoopKeywords(ancestry); - } - - if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatWhile* statWhile = extractStat(ancestry); - (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && - !statWhile->condition->location.containsClosed(position))) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) - { - return { - {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - ancestry, - AutocompleteContext::Keyword - }; - } - else if (AstStatIf* statIf = parent->as(); statIf && node->is()) - { - if (statIf->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - else if (AstStatIf* statIf = extractStat(ancestry); statIf && - (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && - (statIf->condition && !statIf->condition->location.containsClosed(position))) - { - AutocompleteEntryMap ret; - ret["then"] = {AutocompleteEntryKind::Keyword}; - ret["and"] = {AutocompleteEntryKind::Keyword}; - ret["or"] = {AutocompleteEntryKind::Keyword}; - return {std::move(ret), ancestry, AutocompleteContext::Keyword}; - } - else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); - exprTable && (node->is() || node->is() || node->is())) - { - for (const auto& [kind, key, value] : exprTable->items) - { - // If item doesn't have a key, maybe the value is actually the key - if (key ? key == node : node->is() && value == node) - { - if (auto it = module->astExpectedTypes.find(exprTable)) - { - auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - - if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); - - if (!key) - { - // If there is "no key," it may be that the user - // intends for the current token to be the key, but - // has yet to type the `=` sign. - // - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); - } - } - - // Remove keys that are already completed - for (const auto& item : exprTable->items) - { - if (!item.key) - continue; - - if (auto stringKey = item.key->as()) - result.erase(std::string(stringKey->value.data, stringKey->value.size)); - } - - // If we know for sure that a key is being written, do not offer general expression suggestions - if (!key) - autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); - - return {result, ancestry, AutocompleteContext::Property}; - } - - break; - } - } - } - else if (AstExprTable* exprTable = node->as()) - { - AutocompleteEntryMap result; - - if (auto it = module->astExpectedTypes.find(exprTable)) - { - result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); - } - - // Remove keys that are already completed - for (const auto& item : exprTable->items) - { - if (!item.key) - continue; - - if (auto stringKey = item.key->as()) - result.erase(std::string(stringKey->value.data, stringKey->value.size)); - } - } - - // Also offer general expression suggestions - autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); - - return {result, ancestry, AutocompleteContext::Property}; - } - else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - - if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, fileResolver, callback)) - { - return {*ret, ancestry, AutocompleteContext::String}; - } - else if (node->is() || isSimpleInterpolatedString(node)) - { - AutocompleteEntryMap result; - - if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, node, position, result); - - if (ancestry.size() >= 2) - { - if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) - { - if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); - } - else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) - { - if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) - { - if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, node, position, result); - } - } - } - - return {result, ancestry, AutocompleteContext::String}; - } - else if (stringPartOfInterpString(node, position)) - { - // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we - // can't know what to format to - AutocompleteEntryMap map; - return {map, ancestry, AutocompleteContext::String}; - } - - if (node->is()) - return {}; - - if (node->asExpr()) - { - AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) - ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); - return ret; - } - else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - - return {}; -} - AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { const SourceModule* sourceModule = frontend.getSourceModule(moduleName); @@ -2019,7 +36,13 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; - return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, frontend.fileResolver, callback); + if (isWithinComment(*sourceModule, position)) + return {}; + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + ScopePtr startScope = findScopeAtPosition(*module, position); + return autocomplete_(module, builtinTypes, &typeArena, ancestry, globalScope, startScope, position, frontend.fileResolver, callback); } } // namespace Luau diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp new file mode 100644 index 00000000..ee045771 --- /dev/null +++ b/Analysis/src/AutocompleteCore.cpp @@ -0,0 +1,2002 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "AutocompleteCore.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/AutocompleteTypes.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/FileResolver.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" +#include "Luau/Subtyping.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" + +#include +#include +#include + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(AutocompleteRequirePathSuggestions2) +LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete) + +static const std::unordered_set kStatementStartingKeywords = + {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; + +namespace Luau +{ + +static bool alreadyHasParens(const std::vector& nodes) +{ + auto iter = nodes.rbegin(); + while (iter != nodes.rend() && + ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) + { + iter++; + } + + if (iter == nodes.rend() || iter == nodes.rbegin()) + { + return false; + } + + if (AstExprCall* call = (*iter)->as()) + { + return call->func == *(iter - 1); + } + + return false; +} + +static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) +{ + if (alreadyHasParens(nodes)) + { + return ParenthesesRecommendation::None; + } + + auto idxExpr = nodes.back()->as(); + bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); + return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; +} + +static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) +{ + ParenthesesRecommendation rec = ParenthesesRecommendation::None; + for (Luau::TypeId partId : intersect->parts) + { + if (auto partFunc = Luau::get(partId)) + { + rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); + } + else + { + return ParenthesesRecommendation::None; + } + } + return rec; +} + +static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) +{ + // If element is already type-correct, even a function should be inserted without parenthesis + if (typeCorrect == TypeCorrectKind::Correct) + return ParenthesesRecommendation::None; + + id = Luau::follow(id); + if (auto func = get(id)) + { + return getParenRecommendationForFunc(func, nodes); + } + else if (auto intersect = get(id)) + { + return getParenRecommendationForIntersect(intersect, nodes); + } + return ParenthesesRecommendation::None; +} + +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionType* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull builtinTypes) +{ + InternalErrorReporter iceReporter; + UnifierSharedState unifierState(&iceReporter); + Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; + + if (FFlag::LuauSolverV2) + { + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&iceReporter}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter}}; + + return subtyping.isSubtype(subTy, superTy, scope).isSubtype; + } + else + { + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); + + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + + return unifier.canUnify(subTy, superTy).empty(); + } +} + +static TypeCorrectKind checkTypeCorrectKind( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + AstNode* node, + Position position, + TypeId ty +) +{ + ty = follow(ty); + + LUAU_ASSERT(module.hasModuleScope()); + + NotNull moduleScope{module.getModuleScope().get()}; + + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return TypeCorrectKind::None; + + TypeId expectedType = follow(*typeAtPosition); + + auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) + { + if (std::optional firstRetTy = first(ftv->retTypes)) + return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionType* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + id = follow(id); + + if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } + } + + return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; +} + +enum class PropIndexType +{ + Point, + Colon, + Key, +}; + +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId rootTy, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result, + std::unordered_set& seen, + std::optional containingClass = std::nullopt +) +{ + rootTy = follow(rootTy); + ty = follow(ty); + + if (seen.count(ty)) + return; + seen.insert(ty); + + auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) + { + if (indexType == PropIndexType::Key) + return false; + + bool calledWithSelf = indexType == PropIndexType::Colon; + + auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) + { + // Strong match with definition is a success + if (calledWithSelf == ftv->hasSelf) + return true; + + // Calls on classes require strict match between how function is declared and how it's called + if (get(rootTy)) + return false; + + // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all + // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) + return calledWithSelf; + } + + return !calledWithSelf; + }; + + if (const FunctionType* ftv = get(type)) + return !isCompatibleCall(ftv); + + // For intersections, any part that is successful makes the whole call successful + if (const IntersectionType* itv = get(type)) + { + for (auto subType : itv->parts) + { + if (const FunctionType* ftv = get(Luau::follow(subType))) + { + if (isCompatibleCall(ftv)) + return false; + } + } + } + + return calledWithSelf; + }; + + auto fillProps = [&](const ClassType::Props& props) + { + for (const auto& [name, prop] : props) + { + // We are walking up the class hierarchy, so if we encounter a property that we have + // already populated, it takes precedence over the property we found just now. + if (result.count(name) == 0 && name != kParseNameError) + { + Luau::TypeId type; + + if (FFlag::LuauSolverV2) + { + if (auto ty = prop.readTy) + type = follow(*ty); + else + continue; + } + else + type = follow(prop.type()); + + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key + ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); + + ParenthesesRecommendation parens = + indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); + + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens, + {}, + indexType == PropIndexType::Colon + }; + } + } + }; + + auto fillMetatableProps = [&](const TableType* mtable) + { + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + TypeId followed = follow(indexIt->second.type()); + if (get(followed) || get(followed)) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); + } + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retTypes); + if (indexFunctionResult) + autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } + } + }; + + if (auto cls = get(ty)) + { + containingClass = containingClass.value_or(cls); + fillProps(cls->props); + if (cls->parent) + autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); + } + else if (auto tbl = get(ty)) + fillProps(tbl->props); + else if (auto mt = get(ty)) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); + + if (auto mtable = get(follow(mt->metatable))) + fillMetatableProps(mtable); + } + else if (auto i = get(ty)) + { + // Complete all properties in every variant + for (TypeId ty : i->parts) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen = seen; + + autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); + + for (auto& pair : inner) + result.insert(pair); + } + } + else if (auto u = get(ty)) + { + // Complete all properties common to all variants + auto iter = begin(u); + auto endIter = end(u); + + while (iter != endIter) + { + if (isNil(*iter)) + ++iter; + else + break; + } + + if (iter == endIter) + return; + + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); + + ++iter; + + while (iter != endIter) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen; + + if (isNil(*iter)) + { + ++iter; + continue; + } + + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); + + std::unordered_set toRemove; + + for (const auto& [k, v] : result) + { + (void)v; + if (!inner.count(k)) + toRemove.insert(k); + } + + for (const std::string& k : toRemove) + result.erase(k); + + ++iter; + } + } + else if (auto pt = get(ty)) + { + if (pt->metatable) + { + if (auto mtable = get(*pt->metatable)) + fillMetatableProps(mtable); + } + } + else if (get(get(ty))) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); + } +} + +static void autocompleteKeywords(const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.back(); + + if (!node->is() && node->asExpr()) + { + // This is not strictly correct. We should recommend `and` and `or` only after + // another expression, not at the start of a new one. We should only recommend + // `not` at the start of an expression. Detecting either case reliably is quite + // complex, however; this is good enough for now. + + // These are not context-sensitive keywords, so we can unconditionally assign. + result["and"] = {AutocompleteEntryKind::Keyword}; + result["or"] = {AutocompleteEntryKind::Keyword}; + result["not"] = {AutocompleteEntryKind::Keyword}; + } +} + +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result +) +{ + std::unordered_set seen; + autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); +} + +AutocompleteEntryMap autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes +) +{ + AutocompleteEntryMap result; + autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); + return result; +} + +AutocompleteEntryMap autocompleteModuleTypes(const Module& module, const ScopePtr& scopeAtPosition, Position position, std::string_view moduleName) +{ + AutocompleteEntryMap result; + ScopePtr startScope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + for (ScopePtr& scope = startScope; scope; scope = scope->parent) + { + if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) + { + for (const auto& [name, ty] : it->second) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; + + break; + } + } + + return result; +} + +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) +{ + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->isQuoted()) + return; + else if (node->is()) + return; + } + + auto formatKey = [addQuotes](const std::string& key) + { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + +static bool canSuggestInferredType(ScopePtr scope, TypeId ty) +{ + ty = follow(ty); + + // No point in suggesting 'any', invalid to suggest others + if (get(ty) || get(ty) || get(ty) || get(ty)) + return false; + + // No syntax for unnamed tables with a metatable + if (get(ty)) + return false; + + if (const TableType* ttv = get(ty)) + { + if (ttv->name) + return true; + + if (ttv->syntheticName) + return false; + } + + // We might still have a type with cycles or one that is too long, we'll check that later + return true; +} + +// Walk complex type trees to find the element that is being edited +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); + +static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) +{ + for (size_t i = 0; i < astTypeList.types.size; i++) + { + AstType* type = astTypeList.types.data[i]; + + if (type->location.containsClosed(position)) + { + auto [head, _] = flatten(tp); + + if (i < head.size()) + return findTypeElementAt(type, head[i], position); + } + } + + if (AstTypePack* argTp = astTypeList.tailType) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + auto [_, tail] = flatten(tp); + + if (tail) + { + if (const VariadicTypePack* vtp = get(follow(*tail))) + return findTypeElementAt(variadic->variadicType, vtp->ty, position); + } + } + } + } + + return {}; +} + +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) +{ + ty = follow(ty); + + if (astType->is()) + return ty; + + if (astType->is()) + return ty; + + if (AstTypeFunction* type = astType->as()) + { + const FunctionType* ftv = get(ty); + + if (!ftv) + return {}; + + if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) + return element; + + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) + return element; + } + + // It's possible to walk through other types like intrsection and unions if we find value in doing that + return {}; +} + +std::optional getLocalTypeInScopeAt(const Module& module, const ScopePtr& scopeAtPosition, Position position, AstLocal* local) +{ + if (ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position)) + { + for (const auto& [name, binding] : scope->bindings) + { + if (name == local) + return binding.typeId; + } + } + + return {}; +} + +template +static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) +{ + ToStringOptions opts; + opts.useLineBreaks = false; + opts.hideTableKind = true; + opts.functionTypeArguments = functionTypeArguments; + opts.scope = scope; + ToStringResult name = toStringDetailed(ty, opts); + + if (name.error || name.invalid || name.cycle || name.truncated) + return std::nullopt; + + return name.name; +} + +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + return tryToStringDetailed(scope, ty, functionTypeArguments); +} + +static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) +{ + std::optional ty; + + if (topType) + ty = findTypeElementAt(topType, inferredType, position); + else + ty = inferredType; + + if (!ty) + return false; + + if (auto name = tryGetTypeNameInScope(scope, *ty)) + { + if (auto it = result.find(*name); it != result.end()) + it->second.typeCorrect = TypeCorrectKind::Correct; + else + result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; + + return true; + } + + return false; +} + +static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) +{ + auto [tpHead, tpTail] = flatten(tp); + + if (index < tpHead.size()) + return tpHead[index]; + + // Infinite tail + if (tpTail) + { + if (const VariadicTypePack* vtp = get(follow(*tpTail))) + return vtp->ty; + } + + return {}; +} + +template +std::optional returnFirstNonnullOptionOfType(const UnionType* utv) +{ + std::optional ret; + for (TypeId subTy : utv) + { + if (isNil(subTy)) + continue; + + if (const T* ftv = get(follow(subTy))) + { + if (ret.has_value()) + { + return std::nullopt; + } + ret = ftv; + } + else + { + return std::nullopt; + } + } + return ret; +} + +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) +{ + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return std::nullopt; + + TypeId expectedType = follow(*typeAtPosition); + + if (get(expectedType)) + return true; + + if (const IntersectionType* itv = get(expectedType)) + { + return std::all_of( + begin(itv->parts), + end(itv->parts), + [](auto&& ty) + { + return get(Luau::follow(ty)) != nullptr; + } + ); + } + + if (const UnionType* utv = get(expectedType)) + return returnFirstNonnullOptionOfType(utv).has_value(); + + return false; +} + +AutocompleteEntryMap autocompleteTypeNames( + const Module& module, + const ScopePtr& scopeAtPosition, + Position& position, + const std::vector& ancestry +) +{ + AutocompleteEntryMap result; + + ScopePtr startScope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + for (ScopePtr scope = startScope; scope; scope = scope->parent) + { + for (const auto& [name, ty] : scope->exportedTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; + } + + for (const auto& [name, ty] : scope->privateTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; + } + + for (const auto& [name, _] : scope->importedTypeBindings) + { + if (auto binding = scope->linearSearchForBinding(name, true)) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; + } + } + } + + AstNode* parent = nullptr; + AstType* topType = nullptr; // TODO: rename? + + for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) + { + if (AstType* asType = (*it)->asType()) + { + topType = asType; + } + else + { + parent = *it; + break; + } + } + + if (!parent) + return result; + + if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local + { + // Look at which of the variable types we are defining + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + if (var->annotation && var->annotation->location.containsClosed(position)) + { + if (node->values.size == 0) + break; + + unsigned tailPos = 0; + + // For multiple return values we will try to unpack last function call return type pack + if (i >= node->values.size) + { + tailPos = int(i) - int(node->values.size) + 1; + i = int(node->values.size) - 1; + } + + AstExpr* expr = node->values.data[i]->asExpr(); + + if (!expr) + break; + + TypeId inferredType = nullptr; + + if (AstExprCall* exprCall = expr->as()) + { + if (auto it = module.astTypes.find(exprCall->func)) + { + if (const FunctionType* ftv = get(follow(*it))) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) + inferredType = *ty; + } + } + } + else + { + if (tailPos != 0) + break; + + if (auto it = module.astTypes.find(expr)) + inferredType = *it; + } + + if (inferredType) + tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); + + break; + } + } + } + else if (AstExprFunction* node = parent->as()) + { + // For lookup inside expected function type if that's available + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* + { + auto it = module.astExpectedTypes.find(expr); + + if (!it) + return nullptr; + + TypeId ty = follow(*it); + + if (const FunctionType* ftv = get(ty)) + return ftv; + + // Handle optional function type + if (const UnionType* utv = get(ty)) + { + return returnFirstNonnullOptionOfType(utv).value_or(nullptr); + } + + return nullptr; + }; + + // Find which argument type we are defining + for (size_t i = 0; i < node->args.size; i++) + { + AstLocal* arg = node->args.data[i]; + + if (arg->annotation && arg->annotation->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + // Otherwise, try to use the type inferred by typechecker + else if (auto inferredType = getLocalTypeInScopeAt(module, scopeAtPosition, position, arg)) + { + tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); + } + + break; + } + } + + if (AstTypePack* argTp = node->varargAnnotation) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + + if (!node->returnAnnotation) + return result; + + for (size_t i = 0; i < node->returnAnnotation->types.size; i++) + { + AstType* ret = node->returnAnnotation->types.data[i]; + + if (ret->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + + // TODO: with additional type information, we could suggest inferred return type here + break; + } + } + + if (AstTypePack* retTp = node->returnAnnotation->tailType) + { + if (auto variadic = retTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + } + + return result; +} + +static bool isInLocalNames(const std::vector& ancestry, Position position) +{ + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (var->location.containsClosed(position)) + { + return true; + } + } + } + else if (auto funcExpr = (*iter)->as()) + { + if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) + { + return true; + } + } + else if (auto localFunc = (*iter)->as()) + { + return localFunc->name->location.containsClosed(position); + } + else if (auto block = (*iter)->as()) + { + if (block->body.size > 0) + { + return false; + } + } + else if ((*iter)->asStat()) + { + return false; + } + } + return false; +} + +static bool isIdentifier(AstNode* node) +{ + return node->is() || node->is(); +} + +static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) +{ + // Current set of rules only check for local binding match + if (!symbol.local) + return false; + + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (symbol.local == var) + return true; + } + } + } + + return false; +} + +template +T* extractStat(const std::vector& ancestry) +{ + AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; + if (!node) + return nullptr; + + if (T* t = node->as()) + return t; + + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return nullptr; + + AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; + AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; + + if (!grandParent) + return nullptr; + + if (T* t = parent->as(); t && grandParent->is()) + return t; + + if (!greatGrandParent) + return nullptr; + + if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) + return t; + + return nullptr; +} + +static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) +{ + if (symbol.local) + return binding.location.end < pos; + + // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it + return binding.location == Location() || !binding.location.containsClosed(pos); +} + +static AutocompleteEntryMap autocompleteStatement( + const Module& module, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position& position +) +{ + // This is inefficient. :( + ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + AutocompleteEntryMap result; + + if (isInLocalNames(ancestry, position)) + { + autocompleteKeywords(ancestry, position, result); + return result; + } + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(name, binding, position)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) + }; + } + + scope = scope->parent; + } + + for (const auto& kw : kStatementStartingKeywords) + result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) + { + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatIf* statIf = (*it)->as()) + { + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) + { + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; + } + + if (!hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 2) + { + AstNode* parent = ancestry.rbegin()[1]; + if (AstStatIf* statIf = parent->as()) + { + if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 4) + { + auto iter = ancestry.rbegin(); + if (AstStatIf* statIf = iter[3]->as(); + statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + return result; +} + +// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) +static bool autocompleteIfElseExpression( + const AstNode* node, + const std::vector& ancestry, + const Position& position, + AutocompleteEntryMap& outResult +) +{ + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return false; + + if (node->is()) + { + // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else + // expression. + return true; + } + + AstExprIfElse* ifElseExpr = parent->as(); + if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasThen) + { + outResult["then"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else if (ifElseExpr->trueExpr->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasElse) + { + outResult["else"] = {AutocompleteEntryKind::Keyword}; + outResult["elseif"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else + { + return false; + } +} + +static AutocompleteContext autocompleteExpression( + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position position, + AutocompleteEntryMap& result +) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.rbegin()[0]; + + if (node->is()) + { + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); + } + else if (autocompleteIfElseExpression(node, ancestry, position, result)) + return AutocompleteContext::Keyword; + else if (node->is()) + return AutocompleteContext::Unknown; + else + { + // This is inefficient. :( + ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(name, binding, position)) + continue; + + if (isBeingDefined(ancestry, name)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + { + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); + + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + typeCorrect, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, typeCorrect) + }; + } + } + + scope = scope->parent; + } + + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, node, position, result); + } + + return AutocompleteContext::Expression; +} + +static AutocompleteResult autocompleteExpression( + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position position +) +{ + AutocompleteEntryMap result; + AutocompleteContext context = autocompleteExpression(module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + return {result, ancestry, context}; +} + +static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) +{ + AstExpr* parentExpr = nullptr; + if (auto indexName = funcExpr->as()) + { + parentExpr = indexName->expr; + } + else if (auto indexExpr = funcExpr->as()) + { + parentExpr = indexExpr->expr; + } + else + { + return std::nullopt; + } + + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) + { + return std::nullopt; + } + + Luau::TypeId parentType = Luau::follow(*parentIt); + + if (auto parentClass = Luau::get(parentType)) + { + return parentClass; + } + + if (auto parentUnion = Luau::get(parentType)) + { + return returnFirstNonnullOptionOfType(parentUnion); + } + + return std::nullopt; +} + +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + for (const AstExpr* expression : interpString->expressions) + { + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; +} + +static bool isSimpleInterpolatedString(const AstNode* node) +{ + const AstExprInterpString* interpString = node->as(); + return interpString != nullptr && interpString->expressions.size == 0; +} + +static std::optional getStringContents(const AstNode* node) +{ + if (const AstExprConstantString* string = node->as()) + { + return std::string(string->value.data, string->value.size); + } + else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) + { + LUAU_ASSERT(interpString->strings.size == 1); + return std::string(interpString->strings.data->data, interpString->strings.data->size); + } + else + { + return std::nullopt; + } +} + +static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) +{ + if (!suggestions) + return std::nullopt; + + AutocompleteEntryMap result; + for (const RequireSuggestion& suggestion : *suggestions) + { + AutocompleteEntry entry = {AutocompleteEntryKind::RequirePath}; + entry.insertText = std::move(suggestion.fullPath); + result[std::move(suggestion.label)] = std::move(entry); + } + return result; +} + +static std::optional autocompleteStringParams( + const ModulePtr& module, + const std::vector& nodes, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) +{ + if (nodes.size() < 2) + { + return std::nullopt; + } + + if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) + { + return std::nullopt; + } + + if (!nodes.back()->is()) + { + if (nodes.back()->location.end == position || nodes.back()->location.begin == position) + { + return std::nullopt; + } + } + + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); + if (!candidate) + { + return std::nullopt; + } + + // HACK: All current instances of 'magic string' params are the first parameter of their functions, + // so we encode that here rather than putting a useless member on the FunctionType struct. + if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) + { + return std::nullopt; + } + + auto it = module->astTypes.find(candidate->func); + if (!it) + { + return std::nullopt; + } + + std::optional candidateString = getStringContents(nodes.back()); + + auto performCallback = [&](const FunctionType* funcType) -> std::optional + { + for (const std::string& tag : funcType->tags) + { + if (FFlag::AutocompleteRequirePathSuggestions2) + { + if (tag == kRequireTagName && fileResolver) + { + return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); + } + } + if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) + { + return ret; + } + } + return std::nullopt; + }; + + auto followedId = Luau::follow(*it); + if (auto functionType = Luau::get(followedId)) + { + return performCallback(functionType); + } + + if (auto intersect = Luau::get(followedId)) + { + for (TypeId part : intersect->parts) + { + if (auto candidateFunctionType = Luau::get(part)) + { + if (std::optional ret = performCallback(candidateFunctionType)) + { + return ret; + } + } + } + } + + return std::nullopt; +} + +static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) +{ + AutocompleteEntryMap ret; + ret["do"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; +} + +static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) +{ + std::string result = "function("; + + auto [args, tail] = Luau::flatten(funcTy.argTypes); + + bool first = true; + // Skip the implicit 'self' argument if call is indexed with ':' + for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) + { + if (!first) + result += ", "; + else + first = false; + + std::string name; + if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) + name = funcTy.argNames[argIdx]->name; + else + name = "a" + std::to_string(argIdx); + + if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) + result += name + ": " + *type; + else + result += name; + } + + if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) + { + if (!first) + result += ", "; + + std::optional varArgType; + if (const VariadicTypePack* pack = get(follow(*tail))) + { + if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) + varArgType = std::move(res); + } + + if (varArgType) + result += "...: " + *varArgType; + else + result += "..."; + } + + result += ")"; + + auto [rets, retTail] = Luau::flatten(funcTy.retTypes); + if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) + { + if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) + { + result += ": "; + bool wrap = totalRetSize != 1; + if (wrap) + result += "("; + result += *returnTypes; + if (wrap) + result += ")"; + } + } + result += " end"; + return result; +} + +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, + const ScopePtr& scopeAtPosition, + Position position, + const AstNode* node, + const std::vector& ancestry +) +{ + const AstExprCall* call = node->as(); + if (!call && ancestry.size() > 1) + call = ancestry[ancestry.size() - 2]->as(); + + if (!call) + return std::nullopt; + + if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) + return std::nullopt; + + TypeId* typeIter = module->astTypes.find(call->func); + if (!typeIter) + return std::nullopt; + + const FunctionType* outerFunction = get(follow(*typeIter)); + if (!outerFunction) + return std::nullopt; + + size_t argument = 0; + for (size_t i = 0; i < call->args.size; ++i) + { + if (call->args.data[i]->location.containsClosed(position)) + { + argument = i; + break; + } + } + + if (call->self) + argument++; + + std::optional argType; + auto [args, tail] = flatten(outerFunction->argTypes); + if (argument < args.size()) + argType = args[argument]; + + if (!argType) + return std::nullopt; + + TypeId followed = follow(*argType); + const FunctionType* type = get(followed); + if (!type) + { + if (const UnionType* unionType = get(followed)) + { + if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) + type = *nonnullFunction; + } + } + + if (!type) + return std::nullopt; + + const ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(*module, position); + if (!scope) + return std::nullopt; + + AutocompleteEntry entry; + entry.kind = AutocompleteEntryKind::GeneratedFunction; + entry.typeCorrect = TypeCorrectKind::Correct; + entry.type = argType; + entry.insertText = makeAnonymous(scope, *type); + return std::make_optional(std::move(entry)); +} + +AutocompleteResult autocomplete_( + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + std::vector& ancestry, + Scope* globalScope, + const ScopePtr& scopeAtPosition, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) +{ + AstNode* node = ancestry.back(); + + AstExprConstantNil dummy{Location{}}; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; + + // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node + if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) + { + ancestry.pop_back(); + + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; + } + + if (auto indexName = node->as()) + { + auto it = module->astTypes.find(indexName->expr); + if (!it) + return {}; + + TypeId ty = follow(*it); + PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; + + return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + } + else if (auto typeReference = node->as()) + { + if (typeReference->prefix) + return {autocompleteModuleTypes(*module, scopeAtPosition, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; + else + return {autocompleteTypeNames(*module, scopeAtPosition, position, ancestry), ancestry, AutocompleteContext::Type}; + } + else if (node->is()) + { + return {autocompleteTypeNames(*module, scopeAtPosition, position, ancestry), ancestry, AutocompleteContext::Type}; + } + else if (AstStatLocal* statLocal = node->as()) + { + if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; + else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else + return {}; + } + + else if (AstStatFor* statFor = extractStat(ancestry)) + { + if (!statFor->hasDo || position < statFor->doLocation.begin) + { + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + return {}; + } + + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) + { + if (!statForIn->hasIn || position <= statForIn->inLocation.begin) + { + AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; + if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) + { + // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or + // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer + // any suggestions. + return {}; + } + + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + + if (!statForIn->hasDo || position <= statForIn->doLocation.begin) + { + LUAU_ASSERT(statForIn->values.size > 0); + AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; + + if (lastExpr->location.containsClosed(position)) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (position > lastExpr->location.end) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + + return {}; // Not sure what this means + } + } + else if (AstStatForIn* statForIn = extractStat(ancestry)) + { + // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. + // ex "for f in f do" + if (!statForIn->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) + { + if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) + { + return autocompleteWhileLoopKeywords(ancestry); + } + + if (!statWhile->hasDo || position < statWhile->doLocation.begin) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (statWhile->hasDo && position > statWhile->doLocation.end) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatWhile* statWhile = extractStat(ancestry); + (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && + !statWhile->condition->location.containsClosed(position))) + { + return autocompleteWhileLoopKeywords(ancestry); + } + else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) + { + return { + {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + ancestry, + AutocompleteContext::Keyword + }; + } + else if (AstStatIf* statIf = parent->as(); statIf && node->is()) + { + if (statIf->condition->is()) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + else if (AstStatIf* statIf = extractStat(ancestry); statIf && + (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + (statIf->condition && !statIf->condition->location.containsClosed(position))) + { + AutocompleteEntryMap ret; + ret["then"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), ancestry, AutocompleteContext::Keyword}; + } + else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + else if (AstExprTable* exprTable = parent->as(); + exprTable && (node->is() || node->is() || node->is())) + { + for (const auto& [kind, key, value] : exprTable->items) + { + // If item doesn't have a key, maybe the value is actually the key + if (key ? key == node : node->is() && value == node) + { + if (auto it = module->astExpectedTypes.find(exprTable)) + { + auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); + + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + + // If we know for sure that a key is being written, do not offer general expression suggestions + if (!key) + autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } + + break; + } + } + } + else if (AstExprTable* exprTable = node->as()) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(exprTable)) + { + result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + } + + // Also offer general expression suggestions + autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } + else if (isIdentifier(node) && (parent->is() || parent->is())) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + + if (std::optional ret = autocompleteStringParams(module, ancestry, position, fileResolver, callback)) + { + return {*ret, ancestry, AutocompleteContext::String}; + } + else if (node->is() || isSimpleInterpolatedString(node)) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, node, position, result); + + if (ancestry.size() >= 2) + { + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); + } + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) + { + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, node, position, result); + } + } + } + + return {result, ancestry, AutocompleteContext::String}; + } + else if (stringPartOfInterpString(node, position)) + { + // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we + // can't know what to format to + AutocompleteEntryMap map; + return {map, ancestry, AutocompleteContext::String}; + } + + if (node->is()) + return {}; + + if (node->asExpr()) + { + AutocompleteResult ret = autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + if (std::optional generated = makeAnonymousAutofilled(module, scopeAtPosition, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; + } + else if (node->asStat()) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + + return {}; +} + + +} // namespace Luau diff --git a/Analysis/src/AutocompleteCore.h b/Analysis/src/AutocompleteCore.h new file mode 100644 index 00000000..d4264da2 --- /dev/null +++ b/Analysis/src/AutocompleteCore.h @@ -0,0 +1,27 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AutocompleteTypes.h" + +namespace Luau +{ +struct Module; +struct FileResolver; + +using ModulePtr = std::shared_ptr; +using ModuleName = std::string; + + +AutocompleteResult autocomplete_( + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + std::vector& ancestry, + Scope* globalScope, + const ScopePtr& scopeAtPosition, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +); + +} // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 30fc2696..3dacae04 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -33,7 +33,7 @@ LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauStringFormatArityFix) -LUAU_FASTFLAG(AutocompleteRequirePathSuggestions) +LUAU_FASTFLAG(AutocompleteRequirePathSuggestions2); namespace Luau { @@ -426,7 +426,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachDcrMagicFunction(ttv->props["freeze"].type(), dcrMagicFunctionFreeze); } - if (FFlag::AutocompleteRequirePathSuggestions) + if (FFlag::AutocompleteRequirePathSuggestions2) { TypeId requireTy = getGlobalBinding(globals, "require"); attachTag(requireTy, kRequireTagName); diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index a62879fa..a0b5fcf4 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -3,6 +3,8 @@ #include "Luau/Constraint.h" #include "Luau/VisitType.h" +LUAU_FASTFLAGVARIABLE(LuauDontRefCountTypesInTypeFunctions) + namespace Luau { @@ -46,6 +48,21 @@ struct ReferenceCountInitializer : TypeOnceVisitor // ClassTypes never contain free types. return false; } + + bool visit(TypeId, const TypeFunctionInstanceType&) override + { + // We do not consider reference counted types that are inside a type + // function to be part of the reachable reference counted types. + // Otherwise, code can be constructed in just the right way such + // that two type functions both claim to mutate a free type, which + // prevents either type function from trying to generalize it, so + // we potentially get stuck. + // + // The default behavior here is `true` for "visit the child types" + // of this type, hence: + return !FFlag::LuauDontRefCountTypesInTypeFunctions; + } + }; bool isReferenceCountedType(const TypeId typ) diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index d05623a8..ee602999 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -10,6 +10,7 @@ #include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" +#include "Luau/NotNull.h" #include "Luau/RecursionCounter.h" #include "Luau/Refinement.h" #include "Luau/Scope.h" @@ -30,11 +31,13 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTFLAG(DebugLuauLogSolverToJson) LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAG(LuauTypestateBuiltins2) LUAU_FASTFLAGVARIABLE(LuauNewSolverVisitErrorExprLvalues) LUAU_FASTFLAGVARIABLE(LuauNewSolverPrePopulateClasses) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunExportedAndLocal) LUAU_FASTFLAGVARIABLE(LuauNewSolverPopulateTableLocations) namespace Luau @@ -172,6 +175,7 @@ bool hasFreeType(TypeId ty) ConstraintGenerator::ConstraintGenerator( ModulePtr module, NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, @@ -188,6 +192,7 @@ ConstraintGenerator::ConstraintGenerator( , rootScope(nullptr) , dfg(dfg) , normalizer(normalizer) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , moduleResolver(moduleResolver) , ice(ice) @@ -257,7 +262,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) d = follow(d); if (d == ty) continue; - domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + domainTy = simplifyUnion(scope, Location{}, domainTy, d); } LUAU_ASSERT(get(ty)); @@ -267,7 +272,15 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) { + // We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes + prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block); + // Pre + // We need to pop the interior types, + interiorTypes.emplace_back(); visitBlockWithoutChildScope(resumeScope, block); + // Post + interiorTypes.pop_back(); + fillInInferredBindings(resumeScope, block); if (logger) @@ -282,7 +295,7 @@ void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStat d = follow(d); if (d == ty) continue; - domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + domainTy = simplifyUnion(resumeScope, resumeScope->location, domainTy, d); } LUAU_ASSERT(get(ty)); @@ -711,7 +724,7 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc continue; } - if (scope->parent != globalScope) + if (!FFlag::LuauUserTypeFunExportedAndLocal && scope->parent != globalScope) { reportError(function->location, GenericError{"Local user-defined functions are not supported yet"}); continue; @@ -740,17 +753,26 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc if (std::optional error = typeFunctionRuntime->registerFunction(function)) reportError(function->location, GenericError{*error}); - TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ - NotNull{&builtinTypeFunctions().userFunc}, - std::move(typeParams), - {}, - function->name, - }); + UserDefinedFunctionData udtfData; + + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + udtfData.owner = module; + udtfData.definition = function; + } + + TypeId typeFunctionTy = arena->addType( + TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData} + ); TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; // Set type bindings and definition locations for this user-defined type function - scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + if (FFlag::LuauUserTypeFunExportedAndLocal && function->exported) + scope->exportedTypeBindings[function->name.value] = std::move(typeFunction); + else + scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + aliasDefinitionLocations[function->name.value] = function->location; } else if (auto classDeclaration = stat->as()) @@ -780,6 +802,55 @@ void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* bloc classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location; } } + + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Additional pass for user-defined type functions to fill in their environments completely + for (AstStat* stat : block->body) + { + if (auto function = stat->as()) + { + // Find the type function we have already created + TypeFunctionInstanceType* mainTypeFun = nullptr; + + if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + + if (!mainTypeFun) + { + if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + } + + // Fill it with all visible type functions + if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + { + for (auto& [name, tf] : curr->privateTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = ty->userFuncData.definition; + } + + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = ty->userFuncData.definition; + } + } + } + } + } + } } ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) @@ -900,12 +971,8 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope_DEPRECATED(const Sc if (std::optional error = typeFunctionRuntime->registerFunction(function)) reportError(function->location, GenericError{*error}); - TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ - NotNull{&builtinTypeFunctions().userFunc}, - std::move(typeParams), - {}, - function->name, - }); + TypeId typeFunctionTy = + arena->addType(TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, {}}); TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; @@ -2807,7 +2874,7 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local case ErrorSuppression::DoNotSuppress: break; case ErrorSuppression::Suppress: - ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; + ty = simplifyUnion(scope, local->location, *ty, builtinTypes->errorType); break; case ErrorSuppression::NormalizationFailed: reportError(local->local->annotation->location, NormalizationTooComplex{}); @@ -3673,6 +3740,32 @@ TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location locati return resultType; } +struct FragmentTypeCheckGlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull currentScope; + const NotNull dfg; + + FragmentTypeCheckGlobalPrepopulator(NotNull globalScope, NotNull currentScope, NotNull dfg) + : globalScope(globalScope) + , currentScope(currentScope) + , dfg(dfg) + { + } + + bool visit(AstExprGlobal* global) override + { + if (auto ty = globalScope->lookup(global->name)) + { + DefId def = dfg->getDef(global); + // We only want to write into the current scope the type of the global + currentScope->lvalueTypes[def] = *ty; + } + + return true; + } +}; + struct GlobalPrepopulator : AstVisitor { const NotNull globalScope; @@ -3719,6 +3812,14 @@ struct GlobalPrepopulator : AstVisitor } }; +void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program) +{ + FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg}; + if (prepareModuleScope) + prepareModuleScope(module->name, resumeScope); + program->visit(&gp); +} + void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) { GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; @@ -3870,6 +3971,24 @@ TypeId ConstraintGenerator::createTypeFunctionInstance( return result; } +TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(UnionType{{left, right}}); + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId tyFun : res->newTypeFunctions) + addConstraint(scope, location, ReduceConstraint{tyFun}); + + return res->result; + } + else + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; +} + std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 398f0aa5..2b7a7232 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) LUAU_FASTFLAGVARIABLE(LuauRemoveNotAnyHack) +LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) LUAU_FASTFLAG(LuauNewSolverPopulateTableLocations) namespace Luau @@ -320,6 +321,7 @@ struct InstantiationQueuer : TypeOnceVisitor ConstraintSolver::ConstraintSolver( NotNull normalizer, + NotNull simplifier, NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, @@ -333,6 +335,7 @@ ConstraintSolver::ConstraintSolver( : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , simplifier(simplifier) , typeFunctionRuntime(typeFunctionRuntime) , constraints(std::move(constraints)) , rootScope(rootScope) @@ -1802,7 +1805,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNullprops[c.propName] = rhsType; // Food for thought: Could we block if simplification encounters a blocked type? - lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFreeUpperBound, newUpperBound).result; + lhsFree->upperBound = simplifyIntersection(constraint->scope, constraint->location, lhsFreeUpperBound, newUpperBound); bind(constraint, c.propType, rhsType); return true; @@ -2016,7 +2019,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullscope, constraint->location, std::move(parts)); unify(constraint, rhsType, res); } @@ -2596,9 +2599,9 @@ std::pair, std::optional> ConstraintSolver::lookupTa // if we're in an lvalue context, we need the _common_ type here. if (context == ValueContext::LValue) - return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)}; - return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + return {{}, simplifyUnion(constraint->scope, constraint->location, one, two)}; } // if we're in an lvalue context, we need the _common_ type here. else if (context == ValueContext::LValue) @@ -2630,7 +2633,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa { TypeId one = *begin(options); TypeId two = *(++begin(options)); - return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)}; } else return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; @@ -3019,6 +3022,63 @@ bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) return false; } +TypeId ConstraintSolver::simplifyIntersection(NotNull scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(IntersectionType{{left, right}}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyIntersection(builtinTypes, arena, left, right).result; +} + +TypeId ConstraintSolver::simplifyIntersection(NotNull scope, Location location, std::set parts) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(IntersectionType{std::vector(parts.begin(), parts.end())}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyIntersection(builtinTypes, arena, std::move(parts)).result; +} + +TypeId ConstraintSolver::simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(UnionType{{left, right}}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; +} + TypeId ConstraintSolver::errorRecoveryType() const { return builtinTypes->errorRecoveryType(); diff --git a/Analysis/src/EqSatSimplification.cpp b/Analysis/src/EqSatSimplification.cpp new file mode 100644 index 00000000..41e87de2 --- /dev/null +++ b/Analysis/src/EqSatSimplification.cpp @@ -0,0 +1,2449 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/EqSatSimplification.h" +#include "Luau/EqSatSimplificationImpl.h" + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFunction.h" +#include "Luau/VisitType.h" + +#include +#include +#include +#include +#include +#include +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplification) +LUAU_FASTFLAGVARIABLE(DebugLuauExtraEqSatSanityChecks) + +namespace Luau::EqSatSimplification +{ +using Id = Luau::EqSat::Id; + +using EGraph = Luau::EqSat::EGraph; +using Luau::EqSat::Slice; + +TTable::TTable(Id basis) +{ + storage.push_back(basis); +} + +// I suspect that this is going to become a performance hotspot. It would be +// nice to avoid allocating propTypes_ +TTable::TTable(Id basis, std::vector propNames_, std::vector propTypes_) + : propNames(std::move(propNames_)) +{ + storage.reserve(propTypes_.size() + 1); + storage.push_back(basis); + storage.insert(storage.end(), propTypes_.begin(), propTypes_.end()); + + LUAU_ASSERT(storage.size() == 1 + propTypes_.size()); +} + +Id TTable::getBasis() const +{ + LUAU_ASSERT(!storage.empty()); + return storage[0]; +} + +Slice TTable::propTypes() const +{ + LUAU_ASSERT(propNames.size() + 1 == storage.size()); + + return Slice{storage.data() + 1, propNames.size()}; +} + +Slice TTable::mutableOperands() +{ + return Slice{storage.data(), storage.size()}; +} + +Slice TTable::operands() const +{ + return Slice{storage.data(), storage.size()}; +} + +bool TTable::operator==(const TTable& rhs) const +{ + return storage == rhs.storage && propNames == rhs.propNames; +} + +size_t TTable::Hash::operator()(const TTable& value) const +{ + size_t hash = 0; + + // We're using pointers here, which does mean platform divergence. I think + // it's okay? (famous last words, I know) + for (StringId s : value.propNames) + EqSat::hashCombine(hash, EqSat::languageHash(s)); + + EqSat::hashCombine(hash, EqSat::languageHash(value.storage)); + + return hash; +} + +uint32_t StringCache::add(std::string_view s) +{ + size_t hash = std::hash()(s); + if (uint32_t* it = strings.find(hash)) + return *it; + + char* storage = static_cast(allocator.allocate(s.size())); + memcpy(storage, s.data(), s.size()); + + uint32_t result = uint32_t(views.size()); + views.emplace_back(storage, s.size()); + strings[hash] = result; + return result; +} + +std::string_view StringCache::asStringView(StringId id) const +{ + LUAU_ASSERT(id < views.size()); + return views[id]; +} + +std::string StringCache::asString(StringId id) const +{ + return std::string{asStringView(id)}; +} + +template +Simplify::Data Simplify::make(const EGraph&, const T&) const +{ + return true; +} + +void Simplify::join(Data& left, const Data& right) const +{ + left = left || right; +} + +using EClass = Luau::EqSat::EClass; + +// A terminal type is a type that does not contain any other types. +// Examples: any, unknown, number, string, boolean, nil, table, class, thread, function +// +// All class types are also terminal. +static bool isTerminal(const EType& node) +{ + return node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get(); +} + +static bool isTerminal(const EGraph& egraph, Id eclass) +{ + const auto& nodes = egraph[eclass].nodes; + return std::any_of( + nodes.begin(), + nodes.end(), + [](auto& a) + { + return isTerminal(a); + } + ); +} + +Id mkUnion(EGraph& egraph, std::vector parts) +{ + if (parts.size() == 0) + return egraph.add(TNever{}); + else if (parts.size() == 1) + return parts[0]; + else + return egraph.add(Union{std::move(parts)}); +} + +Id mkIntersection(EGraph& egraph, std::vector parts) +{ + if (parts.size() == 0) + return egraph.add(TUnknown{}); + else if (parts.size() == 1) + return parts[0]; + else + return egraph.add(Intersection{std::move(parts)}); +} + +struct ListRemover +{ + std::unordered_map>& mappings2; + TypeId ty; + + ~ListRemover() + { + mappings2.erase(ty); + } +}; + +/* + * Crucial subtlety: It is very extremely important that enodes and eclasses are + * immutable. Mutating an enode would mean that it is no longer equivalent to + * other nodes in the same eclass. + * + * At the same time, many TypeIds are NOT immutable! + * + * The thing that makes this navigable is that it is okay if the same TypeId is + * imported as a different Id at different times as type inference runs. For + * example, if we at one point import a BlockedType as a TOpaque, and later + * import that same TypeId as some other enode type, this is all completely + * okay. + * + * The main thing we have to be very cautious about, I think, is unsealed + * tables. Unsealed table types have properties imperatively inserted into them + * as type inference runs. If we were to encode that TypeId as part of an + * enode, we could run into a situation where the egraph makes incorrect + * assumptions about the table. + * + * The solution is pretty simple: Never use the contents of a mutable TypeId in + * any reduction rule. TOpaque is always okay because we never actually poke + * around inside the TypeId to do anything. + */ +Id toId( + EGraph& egraph, + NotNull builtinTypes, + std::unordered_map& mappingIdToClass, + std::unordered_map>& typeToMappingId, // (TypeId: (MappingId, count)) + std::unordered_set& boundNodes, + StringCache& strings, + TypeId ty +) +{ + ty = follow(ty); + + // First, handle types which do not contain other types. They obviously + // cannot participate in cycles, so we don't have to check for that. + + if (auto freeTy = get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (auto prim = get(ty)) + { + switch (prim->type) + { + case Luau::PrimitiveType::NilType: + return egraph.add(TNil{}); + case Luau::PrimitiveType::Boolean: + return egraph.add(TBoolean{}); + case Luau::PrimitiveType::Number: + return egraph.add(TNumber{}); + case Luau::PrimitiveType::String: + return egraph.add(TString{}); + case Luau::PrimitiveType::Thread: + return egraph.add(TThread{}); + case Luau::PrimitiveType::Function: + return egraph.add(TTopFunction{}); + case Luau::PrimitiveType::Table: + return egraph.add(TTopTable{}); + case Luau::PrimitiveType::Buffer: + return egraph.add(TBuffer{}); + default: + LUAU_ASSERT(!"Unimplemented"); + return egraph.add(Invalid{}); + } + } + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + return egraph.add(SBoolean{bs->value}); + else if (auto ss = get(s)) + return egraph.add(SString{strings.add(ss->value)}); + else + LUAU_ASSERT(!"Unexpected"); + } + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TFunction{ty}); + else if (ty == builtinTypes->classType) + return egraph.add(TTopClass{}); + else if (get(ty)) + return egraph.add(TClass{ty}); + else if (get(ty)) + return egraph.add(TAny{}); + else if (get(ty)) + return egraph.add(TError{}); + else if (get(ty)) + return egraph.add(TUnknown{}); + else if (get(ty)) + return egraph.add(TNever{}); + + // Now handle composite types. + + if (auto it = typeToMappingId.find(ty); it != typeToMappingId.end()) + { + auto& [mappingId, count] = it->second; + ++count; + Id res = egraph.add(TBound{mappingId}); + boundNodes.insert(res); + return res; + } + + typeToMappingId.emplace(ty, std::pair{mappingIdToClass.size(), 0}); + ListRemover lr{typeToMappingId, ty}; + + auto cache = [&](Id res) + { + const auto& [mappingId, count] = typeToMappingId.at(ty); + if (count > 0) + mappingIdToClass.emplace(mappingId, res); + return res; + }; + + if (auto tt = get(ty)) + return egraph.add(TImportedTable{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (auto ut = get(ty)) + { + std::vector parts; + for (TypeId part : ut) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + return cache(mkUnion(egraph, std::move(parts))); + } + else if (auto it = get(ty)) + { + std::vector parts; + for (TypeId part : it) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + LUAU_ASSERT(parts.size() > 1); + + return cache(mkIntersection(egraph, std::move(parts))); + } + else if (auto negation = get(ty)) + { + Id part = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, negation->ty); + return cache(egraph.add(Negation{std::array{part}})); + } + else if (auto tfun = get(ty)) + { + LUAU_ASSERT(tfun->packArguments.empty()); + + std::vector parts; + for (TypeId part : tfun->typeArguments) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + return cache(egraph.add(TTypeFun{tfun->function.get(), std::move(parts)})); + } + else if (get(ty)) + return egraph.add(TNoRefine{}); + else + { + LUAU_ASSERT(!"Unhandled Type"); + return cache(egraph.add(Invalid{})); + } +} + +Id toId(EGraph& egraph, NotNull builtinTypes, std::unordered_map& mappingIdToClass, StringCache& strings, TypeId ty) +{ + std::unordered_map> typeToMappingId; + std::unordered_set boundNodes; + Id id = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, ty); + + for (Id id : boundNodes) + { + for (const auto [tb, _index] : Query(&egraph, id)) + { + Id bindee = mappingIdToClass.at(tb->value()); + egraph.merge(id, bindee); + } + } + + egraph.rebuild(); + + return egraph.find(id); +} + +// We apply a penalty to cyclic types to guide the system away from them where +// possible. +static const int CYCLE_PENALTY = 5000; + +// Composite types have cost equal to the sum of the costs of their parts plus a +// constant factor. +static const int SET_TYPE_PENALTY = 1; +static const int TABLE_TYPE_PENALTY = 2; +static const int NEGATION_PENALTY = 2; +static const int TFUN_PENALTY = 2; + +// FIXME. We don't have an accurate way to score a TImportedTable table against +// a TTable. +static const int IMPORTED_TABLE_PENALTY = 50; + +// TBound shouldn't ever be selected as the best node of a class unless we are +// debugging eqsat itself and need to stringify eclasses. We thus penalize it +// so heavily that we'll use any other alternative. +static const int BOUND_PENALTY = 999999999; + +// TODO iteration count limit +// TODO also: accept an argument which is the maximum cost to consider before +// abandoning the count. +// TODO: the egraph should be the first parameter. +static size_t computeCost(std::unordered_map& bestNodes, const EGraph& egraph, std::unordered_map& costs, Id id) +{ + if (auto it = costs.find(id); it != costs.end()) + return it->second; + + const std::vector& nodes = egraph[id].nodes; + + size_t minCost = std::numeric_limits::max(); + size_t bestNode = std::numeric_limits::max(); + + const auto updateCost = [&](size_t cost, size_t node) + { + if (cost < minCost) + { + minCost = cost; + bestNode = node; + } + }; + + // First, quickly scan for a terminal type. If we can find one, it is obviously the best. + for (size_t index = 0; index < nodes.size(); ++index) + { + if (isTerminal(nodes[index])) + { + minCost = 1; + bestNode = index; + + costs[id] = 1; + const auto [iter, isFresh] = bestNodes.insert({id, index}); + + // If we are forcing the cost function to select a specific node, + // then we still need to traverse into that node, even if this + // particular node is the obvious choice under normal circumstances. + if (isFresh || iter->second == index) + return 1; + } + } + + // If we recur into this type before this call frame completes, it is + // because this type participates in a cycle. + costs[id] = CYCLE_PENALTY; + + auto computeChildren = [&](Slice parts, size_t maxCost) -> std::optional + { + size_t cost = 0; + for (Id part : parts) + { + cost += computeCost(bestNodes, egraph, costs, part); + + // Abandon this node if it is too costly + if (cost > maxCost) + return std::nullopt; + } + return cost; + }; + + size_t startIndex = 0; + size_t endIndex = nodes.size(); + + // FFlag::DebugLuauLogSimplification will sometimes stringify an Id and pass + // in a prepopulated bestNodes map. If that mapping already has an index + // for this Id, don't look at the other nodes of this class. + if (auto it = bestNodes.find(id); it != bestNodes.end()) + { + LUAU_ASSERT(it->second < nodes.size()); + + startIndex = it->second; + endIndex = startIndex + 1; + } + + for (size_t index = startIndex; index < endIndex; ++index) + { + const auto& node = nodes[index]; + + if (node.get()) + updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. + else if (node.get()) + { + minCost = 1; + bestNode = index; + } + else if (auto tbl = node.get()) + { + // TODO: We could make the penalty a parameter to computeChildren. + std::optional maybeCost = computeChildren(tbl->operands(), minCost); + if (maybeCost) + updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); + } + else if (node.get()) + { + minCost = IMPORTED_TABLE_PENALTY; + bestNode = index; + } + else if (auto u = node.get()) + { + std::optional maybeCost = computeChildren(u->operands(), minCost); + if (maybeCost) + updateCost(SET_TYPE_PENALTY + *maybeCost, index); + } + else if (auto i = node.get()) + { + std::optional maybeCost = computeChildren(i->operands(), minCost); + if (maybeCost) + updateCost(SET_TYPE_PENALTY + *maybeCost, index); + } + else if (auto negation = node.get()) + { + std::optional maybeCost = computeChildren(negation->operands(), minCost); + if (maybeCost) + updateCost(NEGATION_PENALTY + *maybeCost, index); + } + else if (auto tfun = node.get()) + { + std::optional maybeCost = computeChildren(tfun->operands(), minCost); + if (maybeCost) + updateCost(TFUN_PENALTY + *maybeCost, index); + } + } + + LUAU_ASSERT(bestNode < nodes.size()); + + costs[id] = minCost; + bestNodes.insert({id, bestNode}); + return minCost; +} + +static std::unordered_map computeBestResult(const EGraph& egraph, Id id, const std::unordered_map& forceNodes) +{ + std::unordered_map costs; + std::unordered_map bestNodes = forceNodes; + computeCost(bestNodes, egraph, costs, id); + return bestNodes; +} + +static std::unordered_map computeBestResult(const EGraph& egraph, Id id) +{ + std::unordered_map costs; + std::unordered_map bestNodes; + computeCost(bestNodes, egraph, costs, id); + return bestNodes; +} + +TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +); + +TypeId flattenTableNode( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +) +{ + std::vector stack; + std::unordered_set seenIds; + + Id id = rootId; + const TImportedTable* importedTable = nullptr; + while (true) + { + size_t index = bestNodes.at(id); + const auto& eclass = egraph[id]; + + const auto [_iter, isFresh] = seenIds.insert(id); + if (!isFresh) + { + // If a TTable is its own basis, it must be the case that some other + // node on this eclass is a TImportedTable. Let's use that. + + for (size_t i = 0; i < eclass.nodes.size(); ++i) + { + if (eclass.nodes[i].get()) + { + index = i; + break; + } + } + + // If we couldn't find one, we don't know what to do. Use ErrorType. + LUAU_ASSERT(0); + return builtinTypes->errorType; + } + + const auto& node = eclass.nodes[index]; + if (const TTable* ttable = node.get()) + { + stack.push_back(ttable); + id = ttable->getBasis(); + continue; + } + else if (const TImportedTable* ti = node.get()) + { + importedTable = ti; + break; + } + else + LUAU_ASSERT(0); + } + + TableType resultTable; + if (importedTable) + { + const TableType* t = Luau::get(importedTable->value()); + LUAU_ASSERT(t); + resultTable = *t; // Intentional shallow clone here + } + + while (!stack.empty()) + { + const TTable* t = stack.back(); + stack.pop_back(); + + for (size_t i = 0; i < t->propNames.size(); ++i) + { + StringId propName = t->propNames[i]; + const Id propType = t->propTypes()[i]; + + resultTable.props[strings.asString(propName)] = Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)}; + } + } + + return arena->addType(std::move(resultTable)); +} + +TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +) +{ + if (auto it = seen.find(rootId); it != seen.end()) + return it->second; + + size_t index = bestNodes.at(rootId); + LUAU_ASSERT(index <= egraph[rootId].nodes.size()); + + const EType& node = egraph[rootId].nodes[index]; + + if (node.get()) + return builtinTypes->nilType; + else if (node.get()) + return builtinTypes->booleanType; + else if (node.get()) + return builtinTypes->numberType; + else if (node.get()) + return builtinTypes->stringType; + else if (node.get()) + return builtinTypes->threadType; + else if (node.get()) + return builtinTypes->functionType; + else if (node.get()) + return builtinTypes->tableType; + else if (node.get()) + return builtinTypes->classType; + else if (node.get()) + return builtinTypes->bufferType; + else if (auto opaque = node.get()) + return opaque->value(); + else if (auto b = node.get()) + return b->value() ? builtinTypes->trueType : builtinTypes->falseType; + else if (auto s = node.get()) + return arena->addType(SingletonType{StringSingleton{strings.asString(s->value())}}); + else if (auto fun = node.get()) + return fun->value(); + else if (auto tbl = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + TypeId flattened = flattenTableNode(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); + + asMutable(res)->ty.emplace(flattened); + return flattened; + } + else if (auto tbl = node.get()) + return tbl->value(); + else if (auto cls = node.get()) + return cls->value(); + else if (node.get()) + return builtinTypes->anyType; + else if (node.get()) + return builtinTypes->errorType; + else if (node.get()) + return builtinTypes->unknownType; + else if (node.get()) + return builtinTypes->neverType; + else if (auto u = node.get()) + { + Slice parts = u->operands(); + + if (parts.empty()) + return builtinTypes->neverType; + else if (parts.size() == 1) + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + else + { + TypeId res = arena->addType(BlockedType{}); + + seen[rootId] = res; + + std::vector partTypes; + partTypes.reserve(parts.size()); + + for (Id part : parts) + partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(std::move(partTypes)); + + return res; + } + } + else if (auto i = node.get()) + { + Slice parts = i->operands(); + + if (parts.empty()) + return builtinTypes->neverType; + else if (parts.size() == 1) + { + LUAU_ASSERT(parts[0] != rootId); + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + } + else + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + std::vector partTypes; + partTypes.reserve(parts.size()); + + for (Id part : parts) + partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(std::move(partTypes)); + + return res; + } + } + else if (auto negation = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + TypeId ty = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, negation->operands()[0]); + + asMutable(res)->ty.emplace(ty); + + return res; + } + else if (auto tfun = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + std::vector args; + for (Id part : tfun->operands()) + args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(*tfun->value(), std::move(args)); + + newTypeFunctions.push_back(res); + + return res; + } + else if (node.get()) + return builtinTypes->errorType; + else if (node.get()) + return builtinTypes->noRefineType; + else + { + LUAU_ASSERT(!"Unimplemented"); + return nullptr; + } +} + +static TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& forceNodes, + std::vector& newTypeFunctions, + Id rootId +) +{ + const std::unordered_map bestNodes = computeBestResult(egraph, rootId, forceNodes); + std::unordered_map seen; + + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); +} + +static TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + std::vector& newTypeFunctions, + Id rootId +) +{ + const std::unordered_map bestNodes = computeBestResult(egraph, rootId); + std::unordered_map seen; + + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); +} + +Subst::Subst(Id eclass, Id newClass, std::string desc) + : eclass(std::move(eclass)) + , newClass(std::move(newClass)) + , desc(std::move(desc)) +{ +} + +std::string mkDesc( + EGraph& egraph, + const StringCache& strings, + NotNull arena, + NotNull builtinTypes, + Id from, + Id to, + const std::unordered_map& forceNodes, + const std::string& rule +) +{ + if (!FFlag::DebugLuauLogSimplification) + return ""; + + std::vector newTypeFunctions; + + TypeId fromTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, from); + TypeId toTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, to); + + ToStringOptions opts; + opts.useQuestionMarks = false; + + const int RULE_PADDING = 35; + const std::string rulePadding(std::max(0, RULE_PADDING - rule.size()), ' '); + const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") "; + const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") "; + + return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts); +} + +std::string mkDesc(EGraph& egraph, const StringCache& strings, NotNull arena, NotNull builtinTypes, Id from, Id to, const std::string& rule) +{ + if (!FFlag::DebugLuauLogSimplification) + return ""; + + return mkDesc(egraph, strings, arena, builtinTypes, from, to, {}, rule); +} + +static std::string getNodeName(const StringCache& strings, const EType& node) +{ + if (node.get()) + return "nil"; + else if (node.get()) + return "boolean"; + else if (node.get()) + return "number"; + else if (node.get()) + return "string"; + else if (node.get()) + return "thread"; + else if (node.get()) + return "function"; + else if (node.get()) + return "table"; + else if (node.get()) + return "class"; + else if (node.get()) + return "buffer"; + else if (node.get()) + return "opaque"; + else if (auto b = node.get()) + return b->value() ? "true" : "false"; + else if (auto s = node.get()) + return "\"" + strings.asString(s->value()) + "\""; + else if (node.get()) + return "\xe2\x88\xaa"; + else if (node.get()) + return "\xe2\x88\xa9"; + else if (auto cls = node.get()) + { + const ClassType* ct = get(cls->value()); + LUAU_ASSERT(ct); + return ct->name; + } + else if (node.get()) + return "any"; + else if (node.get()) + return "error"; + else if (node.get()) + return "unknown"; + else if (node.get()) + return "never"; + else if (auto tfun = node.get()) + return "tfun " + tfun->value()->name; + else if (node.get()) + return "~"; + else if (node.get()) + return "invalid?"; + else if (node.get()) + return "bound"; + + return "???"; +} + +std::string toDot(const StringCache& strings, const EGraph& egraph) +{ + std::stringstream ss; + ss << "digraph G {" << '\n'; + ss << " graph [fontsize=10 fontname=\"Verdana\" compound=true];" << '\n'; + ss << " node [shape=record fontsize=10 fontname=\"Verdana\"];" << '\n'; + + std::set populated; + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + for (const auto& node : eclass.nodes) + { + if (!node.operands().empty()) + populated.insert(id); + for (Id op : node.operands()) + populated.insert(op); + } + } + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + if (!populated.count(id)) + continue; + + const std::string className = "cluster_" + std::to_string(uint32_t(id)); + ss << " subgraph " << className << " {" << '\n'; + ss << " node [style=\"rounded,filled\"];" << '\n'; + ss << " label = \"" << uint32_t(id) << "\";" << '\n'; + ss << " color = blue;" << '\n'; + + for (size_t index = 0; index < eclass.nodes.size(); ++index) + { + const auto& node = eclass.nodes[index]; + + const std::string label = getNodeName(strings, node); + const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); + + ss << " " << nodeName << " [label=\"" << label << "\"];" << '\n'; + } + + ss << " }" << '\n'; + } + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + for (size_t index = 0; index < eclass.nodes.size(); ++index) + { + const auto& node = eclass.nodes[index]; + + const std::string label = getNodeName(strings, node); + const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); + + for (Id op : node.operands()) + { + op = egraph.find(op); + const std::string destNodeName = "n" + std::to_string(uint32_t(op)) + "_0"; + ss << " " << nodeName << " -> " << destNodeName << " [lhead=cluster_" << uint32_t(op) << "];" << '\n'; + } + } + } + + ss << "}" << '\n'; + + return ss.str(); +} + +template +static Tag const* isTag(const EType& node) +{ + return node.get(); +} + +/// Important: Only use this to test for leaf node types like TUnknown and +/// TNumber. Things that we know cannot be simplified any further and are safe +/// to short-circuit on. +/// +/// It does a linear scan and exits early, so if a particular eclass has +/// multiple "interesting" representations, this function can surprise you. +template +static Tag const* isTag(const EGraph& egraph, Id id) +{ + for (const auto& node : egraph[id].nodes) + { + if (auto n = isTag(node)) + return n; + } + return nullptr; +} + +struct RewriteRule +{ + explicit RewriteRule(EGraph* egraph) + : egraph(egraph) + { + } + + virtual void read(std::vector& substs, Id eclass, const EType* enode) = 0; + +protected: + const EqSat::EClass& get(Id id) + { + return (*egraph)[id]; + } + + Id find(Id id) + { + return egraph->find(id); + } + + Id add(EType enode) + { + return egraph->add(std::move(enode)); + } + + template + const Tag* isTag(Id id) + { + for (const auto& node : (*egraph)[id].nodes) + { + if (auto n = node.get()) + return n; + } + return nullptr; + } + + template + bool isTag(const EType& enode) + { + return enode.get(); + } + +public: + EGraph* egraph; +}; + +enum SubclassRelationship +{ + LeftSuper, + RightSuper, + Unrelated +}; + +static SubclassRelationship relateClasses(const TClass* leftClass, const TClass* rightClass) +{ + const ClassType* leftClassType = Luau::get(leftClass->value()); + const ClassType* rightClassType = Luau::get(rightClass->value()); + + if (isSubclass(leftClassType, rightClassType)) + return RightSuper; + else if (isSubclass(rightClassType, leftClassType)) + return LeftSuper; + else + return Unrelated; +} + +// Entirely analogous to NormalizedType except that it operates on eclasses instead of TypeIds. +struct CanonicalizedType +{ + std::optional nilPart; + std::optional truePart; + std::optional falsePart; + std::optional numberPart; + std::optional stringPart; + std::vector stringSingletons; + std::optional threadPart; + std::optional functionPart; + std::optional tablePart; + std::vector classParts; + std::optional bufferPart; + std::optional errorPart; + + // Functions that have been union'd into the type + std::unordered_set functionParts; + + // Anything that isn't canonical: Intersections, unions, free types, and so on. + std::unordered_set otherParts; + + bool isUnknown() const + { + return nilPart && truePart && falsePart && numberPart && stringPart && threadPart && functionPart && tablePart && bufferPart; + } +}; + +void unionUnknown(EGraph& egraph, CanonicalizedType& ct) +{ + ct.nilPart = egraph.add(TNil{}); + ct.truePart = egraph.add(SBoolean{true}); + ct.falsePart = egraph.add(SBoolean{false}); + ct.numberPart = egraph.add(TNumber{}); + ct.stringPart = egraph.add(TString{}); + ct.threadPart = egraph.add(TThread{}); + ct.functionPart = egraph.add(TTopFunction{}); + ct.tablePart = egraph.add(TTopTable{}); + ct.bufferPart = egraph.add(TBuffer{}); + + ct.functionParts.clear(); + ct.otherParts.clear(); +} + +void unionAny(EGraph& egraph, CanonicalizedType& ct) +{ + unionUnknown(egraph, ct); + ct.errorPart = egraph.add(TError{}); +} + +void unionClasses(EGraph& egraph, std::vector& hereParts, Id there) +{ + if (1 == hereParts.size() && isTag(egraph, hereParts[0])) + return; + + const auto thereClass = isTag(egraph, there); + if (!thereClass) + return; + + for (size_t index = 0; index < hereParts.size(); ++index) + { + const Id herePart = hereParts[index]; + + if (auto partClass = isTag(egraph, herePart)) + { + switch (relateClasses(partClass, thereClass)) + { + case LeftSuper: + return; + case RightSuper: + hereParts[index] = there; + std::sort(hereParts.begin(), hereParts.end()); + return; + case Unrelated: + continue; + } + } + } + + hereParts.push_back(there); + std::sort(hereParts.begin(), hereParts.end()); +} + +void unionWithType(EGraph& egraph, CanonicalizedType& ct, Id part) +{ + if (isTag(egraph, part)) + ct.nilPart = part; + else if (isTag(egraph, part)) + ct.truePart = ct.falsePart = part; + else if (auto b = isTag(egraph, part)) + { + if (b->value()) + ct.truePart = part; + else + ct.falsePart = part; + } + else if (isTag(egraph, part)) + ct.numberPart = part; + else if (isTag(egraph, part)) + ct.stringPart = part; + else if (isTag(egraph, part)) + ct.stringSingletons.push_back(part); + else if (isTag(egraph, part)) + ct.threadPart = part; + else if (isTag(egraph, part)) + { + ct.functionPart = part; + ct.functionParts.clear(); + } + else if (isTag(egraph, part)) + ct.tablePart = part; + else if (isTag(egraph, part)) + ct.classParts = {part}; + else if (isTag(egraph, part)) + ct.bufferPart = part; + else if (isTag(egraph, part)) + { + if (!ct.functionPart) + ct.functionParts.insert(part); + } + else if (auto tclass = isTag(egraph, part)) + unionClasses(egraph, ct.classParts, part); + else if (isTag(egraph, part)) + { + unionAny(egraph, ct); + return; + } + else if (isTag(egraph, part)) + ct.errorPart = part; + else if (isTag(egraph, part)) + unionUnknown(egraph, ct); + else if (isTag(egraph, part)) + { + // Nothing + } + else + ct.otherParts.insert(part); +} + +// Find an enode under the given eclass which is simple enough that it could be +// subtracted from a CanonicalizedType easily. +// +// A union is "simple enough" if it is acyclic and is only comprised of terminal +// types and unions that are themselves subtractable +const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set& seen, Id id) +{ + if (seen.count(id)) + return nullptr; + + const EType* bestUnion = nullptr; + std::optional unionSize; + + for (const auto& node : egraph[id].nodes) + { + if (isTerminal(node)) + return &node; + + if (const auto u = node.get()) + { + seen.insert(id); + + for (Id part : u->operands()) + { + if (!findSubtractableClass(egraph, seen, part)) + return nullptr; + } + + // If multiple unions in this class are all simple enough, prefer + // the shortest one. + if (!unionSize || u->operands().size() < unionSize) + { + unionSize = u->operands().size(); + bestUnion = &node; + } + } + } + + return bestUnion; +} + +const EType* findSubtractableClass(const EGraph& egraph, Id id) +{ + std::unordered_set seen; + + return findSubtractableClass(egraph, seen, id); +} + +// Subtract the type 'part' from 'ct' +// Returns true if the subtraction succeeded. This function will fail if 'part` is too complicated. +bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part) +{ + const EType* etype = findSubtractableClass(egraph, part); + if (!etype) + return false; + + if (etype->get()) + ct.nilPart.reset(); + else if (etype->get()) + { + ct.truePart.reset(); + ct.falsePart.reset(); + } + else if (auto b = etype->get()) + { + if (b->value()) + ct.truePart.reset(); + else + ct.falsePart.reset(); + } + else if (etype->get()) + ct.numberPart.reset(); + else if (etype->get()) + ct.stringPart.reset(); + else if (etype->get()) + return false; + else if (etype->get()) + ct.threadPart.reset(); + else if (etype->get()) + ct.functionPart.reset(); + else if (etype->get()) + ct.tablePart.reset(); + else if (etype->get()) + ct.classParts.clear(); + else if (auto tclass = etype->get()) + { + auto it = std::find(ct.classParts.begin(), ct.classParts.end(), part); + if (it != ct.classParts.end()) + ct.classParts.erase(it); + else + return false; + } + else if (etype->get()) + ct.bufferPart.reset(); + else if (etype->get()) + ct = {}; + else if (etype->get()) + ct.errorPart.reset(); + else if (etype->get()) + { + std::optional errorPart = ct.errorPart; + ct = {}; + ct.errorPart = errorPart; + } + else if (etype->get()) + { + // Nothing + } + else if (auto u = etype->get()) + { + // TODO cycles + // TODO this is super promlematic because 'part' represents a whole group of equivalent enodes. + for (Id unionPart : u->operands()) + { + // TODO: This recursive call will require that we re-traverse this + // eclass to find the subtractible enode. It would be nice to do the + // work just once and reuse it. + bool ok = subtract(egraph, ct, unionPart); + if (!ok) + return false; + } + } + else if (etype->get()) + return false; + else + return false; + + return true; +} + +Id fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) +{ + if (ct.isUnknown()) + { + if (ct.errorPart) + return egraph.add(TAny{}); + else + return egraph.add(TUnknown{}); + } + + std::vector parts; + + if (ct.nilPart) + parts.push_back(*ct.nilPart); + + if (ct.truePart && ct.falsePart) + parts.push_back(egraph.add(TBoolean{})); + else if (ct.truePart) + parts.push_back(*ct.truePart); + else if (ct.falsePart) + parts.push_back(*ct.falsePart); + + if (ct.numberPart) + parts.push_back(*ct.numberPart); + + if (ct.stringPart) + parts.push_back(*ct.stringPart); + else if (!ct.stringSingletons.empty()) + parts.insert(parts.end(), ct.stringSingletons.begin(), ct.stringSingletons.end()); + + if (ct.threadPart) + parts.push_back(*ct.threadPart); + if (ct.functionPart) + parts.push_back(*ct.functionPart); + if (ct.tablePart) + parts.push_back(*ct.tablePart); + parts.insert(parts.end(), ct.classParts.begin(), ct.classParts.end()); + if (ct.bufferPart) + parts.push_back(*ct.bufferPart); + if (ct.errorPart) + parts.push_back(*ct.errorPart); + + parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); + parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); + + return mkUnion(egraph, std::move(parts)); +} + +void addChildren(const EGraph& egraph, const EType* enode, VecDeque& worklist) +{ + for (Id id : enode->operands()) + worklist.push_back(id); +} + +static bool occurs(EGraph& egraph, Id outerId, Slice operands) +{ + for (const Id i : operands) + { + if (egraph.find(i) == outerId) + return true; + } + return false; +} + +Simplifier::Simplifier(NotNull arena, NotNull builtinTypes) + : arena(arena) + , builtinTypes(builtinTypes) + , egraph(Simplify{}) +{ +} + +const EqSat::EClass& Simplifier::get(Id id) const +{ + return egraph[id]; +} + +Id Simplifier::find(Id id) const +{ + return egraph.find(id); +} + +Id Simplifier::add(EType enode) +{ + return egraph.add(std::move(enode)); +} + +template +const Tag* Simplifier::isTag(Id id) const +{ + for (const auto& node : get(id).nodes) + { + if (const Tag* ty = node.get()) + return ty; + } + + return nullptr; +} + +template +const Tag* Simplifier::isTag(const EType& enode) const +{ + return enode.get(); +} + +void Simplifier::subst(Id from, Id to) +{ + substs.emplace_back(from, to, " - "); +} + +void Simplifier::subst(Id from, Id to, const std::string& ruleName) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, std::move(ruleName)); + substs.emplace_back(from, to, desc); +} + +void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); + substs.emplace_back(from, to, desc); +} + +void Simplifier::unionClasses(std::vector& hereParts, Id there) +{ + if (1 == hereParts.size() && isTag(hereParts[0])) + return; + + const auto thereClass = isTag(there); + if (!thereClass) + return; + + for (size_t index = 0; index < hereParts.size(); ++index) + { + const Id herePart = hereParts[index]; + + if (auto partClass = isTag(herePart)) + { + switch (relateClasses(partClass, thereClass)) + { + case LeftSuper: + return; + case RightSuper: + hereParts[index] = there; + std::sort(hereParts.begin(), hereParts.end()); + return; + case Unrelated: + continue; + } + } + } + + hereParts.push_back(there); + std::sort(hereParts.begin(), hereParts.end()); +} + +void Simplifier::simplifyUnion(Id id) +{ + id = find(id); + + for (const auto [u, unionIndex] : Query(&egraph, id)) + { + std::vector newParts; + std::unordered_set seen; + + CanonicalizedType canonicalized; + + if (occurs(egraph, id, u->operands())) + continue; + + for (Id part : u->operands()) + unionWithType(egraph, canonicalized, find(part)); + + Id resultId = fromCanonicalized(egraph, canonicalized); + + subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); + } +} + +// If one of the nodes matches the given Tag, succeed and return the id and node for the other half. +// If neither matches, return nullopt. +template +static std::optional> matchOne(Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) +{ + if (hereNode->get()) + return std::pair{thereId, thereNode}; + else if (thereNode->get()) + return std::pair{hereId, hereNode}; + else + return std::nullopt; +} + +// If the two nodes can be intersected into a "simple" type, return that, else return nullopt. +std::optional intersectOne(EGraph& egraph, Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) +{ + hereId = egraph.find(hereId); + thereId = egraph.find(thereId); + + if (hereId == thereId) + return *hereNode; + + if (hereNode->get() || thereNode->get()) + return TNever{}; + + if (hereNode->get() || hereNode->get() || hereNode->get() || thereNode->get() || + thereNode->get() || thereNode->get() || hereNode->get() || thereNode->get()) + return std::nullopt; + + if (hereNode->get()) + return *thereNode; + if (thereNode->get()) + return *hereNode; + + if (hereNode->get()) + return *thereNode; + if (thereNode->get()) + return *hereNode; + + if (hereNode->get() || thereNode->get()) + return std::nullopt; + + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get() || otherNode->get()) + return *otherNode; + else + return TNever{}; + } + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get() || otherNode->get()) + return *otherNode; + } + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get()) + return std::nullopt; // TODO + else + return TNever{}; + } + if (auto hereClass = hereNode->get()) + { + if (auto thereClass = thereNode->get()) + { + switch (relateClasses(hereClass, thereClass)) + { + case LeftSuper: + return *thereNode; + case RightSuper: + return *hereNode; + case Unrelated: + return TNever{}; + } + } + else + return TNever{}; + } + if (auto hereBool = hereNode->get()) + { + if (auto thereBool = thereNode->get()) + { + if (hereBool->value() == thereBool->value()) + return *hereNode; + else + return TNever{}; + } + else if (thereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (auto thereBool = thereNode->get()) + { + if (auto hereBool = hereNode->get()) + { + if (thereBool->value() == hereBool->value()) + return *thereNode; + else + return TNever{}; + } + else if (hereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get()) + return TBoolean{}; + else if (thereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get()) + return TBoolean{}; + else if (hereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get() || thereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get() || hereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (hereNode->get() && thereNode->get()) + return std::nullopt; + if (hereNode->get() && isTerminal(*thereNode)) + return TNever{}; + if (thereNode->get() && isTerminal(*hereNode)) + return TNever{}; + if (isTerminal(*hereNode) && isTerminal(*thereNode)) + { + // We already know that 'here' and 'there' are different classes. + return TNever{}; + } + + return std::nullopt; +} + +void Simplifier::uninhabitedIntersection(Id id) +{ + for (const auto [intersection, index] : Query(&egraph, id)) + { + Slice parts = intersection->operands(); + + if (parts.empty()) + { + Id never = egraph.add(TNever{}); + subst(id, never, "uninhabitedIntersection"); + return; + } + else if (1 == parts.size()) + { + subst(id, parts[0], "uninhabitedIntersection"); + return; + } + + Id accumulator = egraph.add(TUnknown{}); + EType accumulatorNode = TUnknown{}; + + std::vector unsimplified; + + if (occurs(egraph, id, parts)) + continue; + + for (Id partId : parts) + { + if (isTag(partId)) + return; + + bool found = false; + + const auto& partNodes = egraph[partId].nodes; + for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) + { + const EType& N = partNodes[partIndex]; + if (std::optional intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) + { + if (isTag(*intersection)) + { + subst(id, egraph.add(TNever{}), "uninhabitedIntersection", {{id, index}, {partId, partIndex}}); + return; + } + + accumulator = egraph.add(*intersection); + accumulatorNode = *intersection; + found = true; + break; + } + } + + if (!found) + unsimplified.push_back(partId); + } + + if ((unsimplified.empty() || !isTag(accumulator)) && find(accumulator) != id) + unsimplified.push_back(accumulator); + + const Id result = mkIntersection(egraph, std::move(unsimplified)); + + subst(id, result, "uninhabitedIntersection", {{id, index}}); + } +} + +void Simplifier::intersectWithNegatedClass(Id id) +{ + for (const auto pair : Query(&egraph, id)) + { + const Intersection* intersection = pair.first; + const size_t intersectionIndex = pair.second; + + auto trySubst = [&](size_t i, size_t j) + { + Id iId = intersection->operands()[i]; + Id jId = intersection->operands()[j]; + + for (const auto [negation, negationIndex] : Query(&egraph, jId)) + { + const Id negated = negation->operands()[0]; + + if (iId == negated) + { + subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {jId, negationIndex}}); + return; + } + + for (const auto [negatedClass, negatedClassIndex] : Query(&egraph, negated)) + { + const auto& iNodes = egraph[iId].nodes; + for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) + { + const EType& iNode = iNodes[iIndex]; + if (isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || + isTag(iNode) || + // isTag(iNode) || // I'm not sure about this one. + isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode)) + { + // eg string & ~SomeClass + subst(id, iId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + return; + } + + if (const TClass* class_ = iNode.get()) + { + switch (relateClasses(class_, negatedClass)) + { + case LeftSuper: + // eg Instance & ~Part + // This cannot be meaningfully reduced. + continue; + case RightSuper: + subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + return; + case Unrelated: + // Part & ~Folder == Part + { + std::vector newParts; + newParts.reserve(intersection->operands().size() - 1); + for (Id part : intersection->operands()) + { + if (part != jId) + newParts.push_back(part); + } + + Id substId = egraph.add(Intersection{newParts.begin(), newParts.end()}); + subst(id, substId, "intersectClassWithNegatedClass", {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}}); + } + } + } + } + } + } + }; + + if (2 != intersection->operands().size()) + continue; + + trySubst(0, 1); + trySubst(1, 0); + } +} + +void Simplifier::intersectWithNoRefine(Id id) +{ + for (const auto pair : Query(&egraph, id)) + { + const Intersection* intersection = pair.first; + const size_t intersectionIndex = pair.second; + + const Slice intersectionOperands = intersection->operands(); + + for (size_t index = 0; index < intersectionOperands.size(); ++index) + { + const auto replace = [&]() + { + std::vector newOperands{intersectionOperands.begin(), intersectionOperands.end()}; + newOperands.erase(newOperands.begin() + index); + + Id substId = egraph.add(Intersection{std::move(newOperands)}); + + subst(id, substId, "intersectWithNoRefine", {{id, intersectionIndex}}); + }; + + if (isTag(intersectionOperands[index])) + replace(); + else + { + for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[index])) + { + if (isTag(negation->operands()[0])) + { + replace(); + break; + } + } + } + } + } +} + +/* + * Replace x where x = A & (B | x) with A + * + * Important subtlety: The egraph is routinely going to create cyclic unions and + * intersections. We can't arbitrarily remove things from a union just because + * it can be referred to in a cyclic way. We must only do this for things that + * can only be expressed in a cyclic way. + * + * As an example, we will bind the following type to true: + * + * (true | buffer | class | function | number | string | table | thread) & + * boolean + * + * The egraph represented by this type will indeed be cyclic as the 'true' class + * includes both 'true' itself and the above type, but removing true from the + * union will result is an incorrect judgment! + * + * The solution (for now) is only to consider a type to be cyclic if it was + * cyclic on its original import. + * + * FIXME: I still don't think this is quite right, but I don't know how to + * articulate what the actual rule ought to be. + */ +void Simplifier::cyclicIntersectionOfUnion(Id id) +{ + // FIXME: This has pretty terrible runtime complexity. + + for (const auto [i, intersectionIndex] : Query(&egraph, id)) + { + Slice intersectionParts = i->operands(); + for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionParts.size(); ++intersectionOperandIndex) + { + const Id intersectionPart = find(intersectionParts[intersectionOperandIndex]); + + for (const auto [bound, _boundIndex] : Query(&egraph, intersectionPart)) + { + const Id pointee = find(mappingIdToClass.at(bound->value())); + + for (const auto [u, unionIndex] : Query(&egraph, pointee)) + { + const Slice& unionOperands = u->operands(); + for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) + { + Id unionOperand = find(unionOperands[unionOperandIndex]); + if (unionOperand == id) + { + std::vector newIntersectionParts(intersectionParts.begin(), intersectionParts.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); + + subst( + id, + mkIntersection(egraph, std::move(newIntersectionParts)), + "cyclicIntersectionOfUnion", + {{id, intersectionIndex}, {pointee, unionIndex}} + ); + } + } + } + } + } + } +} + +void Simplifier::cyclicUnionOfIntersection(Id id) +{ + // FIXME: This has pretty terrible runtime complexity. + + for (const auto [union_, unionIndex] : Query(&egraph, id)) + { + Slice unionOperands = union_->operands(); + for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) + { + const Id unionPart = find(unionOperands[unionOperandIndex]); + + for (const auto [bound, _boundIndex] : Query(&egraph, unionPart)) + { + const Id pointee = find(mappingIdToClass.at(bound->value())); + + for (const auto [intersection, intersectionIndex] : Query(&egraph, pointee)) + { + Slice intersectionOperands = intersection->operands(); + for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionOperands.size(); ++intersectionOperandIndex) + { + const Id intersectionPart = find(intersectionOperands[intersectionOperandIndex]); + if (intersectionPart == id) + { + std::vector newIntersectionParts(intersectionOperands.begin(), intersectionOperands.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); + + if (!newIntersectionParts.empty()) + { + Id newIntersection = mkIntersection(egraph, std::move(newIntersectionParts)); + + std::vector newIntersectionParts(unionOperands.begin(), unionOperands.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + unionOperandIndex); + newIntersectionParts.push_back(newIntersection); + + subst( + id, + mkUnion(egraph, std::move(newIntersectionParts)), + "cyclicUnionOfIntersection", + {{id, unionIndex}, {pointee, intersectionIndex}} + ); + } + } + } + } + } + } + } +} + +void Simplifier::expandNegation(Id id) +{ + for (const auto [negation, index] : Query{&egraph, id}) + { + if (isTag(negation->operands()[0])) + return; + + CanonicalizedType canonicalized; + unionUnknown(egraph, canonicalized); + + const bool ok = subtract(egraph, canonicalized, negation->operands()[0]); + if (!ok) + continue; + + subst(id, fromCanonicalized(egraph, canonicalized), "expandNegation", {{id, index}}); + } +} + +/** + * Let A be a class-node having the form B & C1 & ... & Cn + * And B be a class-node having the form (D | E) + * + * Create a class containing the node (C1 & ... & Cn & D) | (C1 & ... & Cn & E) + * + * This function does nothing and returns nullopt if A and B are cyclic. + */ +static std::optional distributeIntersectionOfUnion( + EGraph& egraph, + Id outerClass, + const Intersection* outerIntersection, + Id innerClass, + const Union* innerUnion +) +{ + Slice outerOperands = outerIntersection->operands(); + + std::vector newOperands; + newOperands.reserve(innerUnion->operands().size()); + for (Id innerOperand : innerUnion->operands()) + { + if (isTag(egraph, innerOperand)) + continue; + + if (innerOperand == outerClass) + { + // Skip cyclic intersections of unions. There's a separate + // rule to get rid of those. + return std::nullopt; + } + + std::vector intersectionParts; + intersectionParts.reserve(outerOperands.size()); + intersectionParts.push_back(innerOperand); + + for (const Id op : outerOperands) + { + if (isTag(egraph, op)) + { + break; + } + if (op != innerClass) + intersectionParts.push_back(op); + } + + newOperands.push_back(mkIntersection(egraph, intersectionParts)); + } + + return mkUnion(egraph, std::move(newOperands)); +} + +// A & (B | C) -> (A & B) | (A & C) +// +// A & B & (C | D) -> A & (B & (C | D)) +// -> A & ((B & C) | (B & D)) +// -> (A & B & C) | (A & B & D) +void Simplifier::intersectionOfUnion(Id id) +{ + id = find(id); + + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + // For each operand O + // For each node N + // If N is a union U + // Create a new union comprised of every operand except O intersected with every operand of U + const Slice operands = intersection->operands(); + + if (operands.size() < 2) + return; + + if (occurs(egraph, id, operands)) + continue; + + for (Id operand : operands) + { + operand = find(operand); + if (operand == id) + break; + // Optimization: Decline to distribute any unions on an eclass that + // also contains a terminal node. + if (isTerminal(egraph, operand)) + continue; + + for (const auto [operandUnion, unionIndex] : Query(&egraph, operand)) + { + if (occurs(egraph, id, operandUnion->operands())) + continue; + + std::optional distributed = distributeIntersectionOfUnion(egraph, id, intersection, operand, operandUnion); + + if (distributed) + subst(id, *distributed, "intersectionOfUnion", {{id, intersectionIndex}, {operand, unionIndex}}); + } + } + } +} + +// {"a": b} & {"a": c, ...} => {"a": b & c, ...} +void Simplifier::intersectTableProperty(Id id) +{ + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + const Slice intersectionParts = intersection->operands(); + for (size_t i = 0; i < intersection->operands().size(); ++i) + { + const Id iId = intersection->operands()[i]; + + for (size_t j = 0; j < intersection->operands().size(); ++j) + { + if (i == j) + continue; + + const Id jId = intersection->operands()[j]; + + if (iId == jId) + continue; + + for (const auto [table1, table1Index] : Query(&egraph, iId)) + { + const TableType* table1Ty = Luau::get(table1->value()); + LUAU_ASSERT(table1Ty); + + if (table1Ty->props.size() != 1) + continue; + + for (const auto [table2, table2Index] : Query(&egraph, jId)) + { + const TableType* table2Ty = Luau::get(table2->value()); + LUAU_ASSERT(table2Ty); + + auto it = table2Ty->props.find(table1Ty->props.begin()->first); + if (it != table2Ty->props.end()) + { + std::vector newIntersectionParts; + newIntersectionParts.reserve(intersectionParts.size() - 1); + + for (size_t index = 0; index < intersectionParts.size(); ++index) + { + if (index != i && index != j) + newIntersectionParts.push_back(intersectionParts[index]); + } + + Id newTableProp = egraph.add(Intersection{ + toId(egraph, builtinTypes, mappingIdToClass, stringCache, it->second.type()), + toId(egraph, builtinTypes, mappingIdToClass, stringCache, table1Ty->props.begin()->second.type()) + }); + + newIntersectionParts.push_back(egraph.add(TTable{jId, {stringCache.add(it->first)}, {newTableProp}})); + + subst( + id, + egraph.add(Intersection{std::move(newIntersectionParts)}), + "intersectTableProperty", + {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} + ); + } + } + } + } + } + } +} + +// { prop: never } == never +void Simplifier::uninhabitedTable(Id id) +{ + for (const auto [table, tableIndex] : Query(&egraph, id)) + { + const TableType* tt = Luau::get(table->value()); + LUAU_ASSERT(tt); + + for (const auto& [propName, prop] : tt->props) + { + if (prop.readTy && Luau::get(follow(*prop.readTy))) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + + if (prop.writeTy && Luau::get(follow(*prop.writeTy))) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + } + } + + for (const auto [table, tableIndex] : Query(&egraph, id)) + { + for (Id propType : table->propTypes()) + { + if (isTag(propType)) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + } + } +} + +void Simplifier::unneededTableModification(Id id) +{ + for (const auto [tbl, tblIndex] : Query(&egraph, id)) + { + const Id basis = tbl->getBasis(); + for (const auto [importedTbl, importedTblIndex] : Query(&egraph, basis)) + { + const TableType* tt = Luau::get(importedTbl->value()); + LUAU_ASSERT(tt); + + bool skip = false; + + for (size_t i = 0; i < tbl->propNames.size(); ++i) + { + StringId propName = tbl->propNames[i]; + const Id propType = tbl->propTypes()[i]; + + Id importedProp = toId(egraph, builtinTypes, mappingIdToClass, stringCache, tt->props.at(stringCache.asString(propName)).type()); + + if (find(importedProp) != find(propType)) + { + skip = true; + break; + } + } + + if (!skip) + subst(id, basis, "unneededTableModification", {{id, tblIndex}, {basis, importedTblIndex}}); + } + } +} + +void Simplifier::builtinTypeFunctions(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + if (args.size() != 2) + continue; + + const std::string& name = tfun->value()->name; + if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") + { + if (isTag(args[0]) && isTag(args[1])) + { + subst(id, add(TNumber{}), "builtinTypeFunctions", {{id, index}}); + } + } + } +} + +// Replace union<>, intersect<>, and refine<> with unions or intersections. +// These type functions exist primarily to cause simplification to defer until +// particular points in execution, so it is safe to get rid of them here. +// +// It's not clear that these type functions should exist at all. +void Simplifier::iffyTypeFunctions(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + const std::string& name = tfun->value()->name; + + if (name == "union") + subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); + else if (name == "intersect" || name == "refine") + subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); + } +} + +static void deleteSimplifier(Simplifier* s) +{ + delete s; +} + +SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes) +{ + return SimplifierPtr{new Simplifier(arena, builtinTypes), &deleteSimplifier}; +} + +} // namespace Luau::EqSatSimplification + +namespace Luau +{ + +std::optional eqSatSimplify(NotNull simplifier, TypeId ty) +{ + using namespace Luau::EqSatSimplification; + + std::unordered_map newMappings; + Id rootId = toId(simplifier->egraph, simplifier->builtinTypes, newMappings, simplifier->stringCache, ty); + simplifier->mappingIdToClass.insert(newMappings.begin(), newMappings.end()); + + Simplifier::RewriteRuleFn rules[] = { + &Simplifier::simplifyUnion, + &Simplifier::uninhabitedIntersection, + &Simplifier::intersectWithNegatedClass, + &Simplifier::intersectWithNoRefine, + &Simplifier::cyclicIntersectionOfUnion, + &Simplifier::cyclicUnionOfIntersection, + &Simplifier::expandNegation, + &Simplifier::intersectionOfUnion, + &Simplifier::intersectTableProperty, + &Simplifier::uninhabitedTable, + &Simplifier::unneededTableModification, + &Simplifier::builtinTypeFunctions, + &Simplifier::iffyTypeFunctions, + }; + + std::unordered_set seen; + VecDeque worklist; + + bool progressed = true; + + int count = 0; + const int MAX_COUNT = 1000; + + if (FFlag::DebugLuauLogSimplification) + std::ofstream("begin.dot") << toDot(simplifier->stringCache, simplifier->egraph); + + auto& egraph = simplifier->egraph; + const auto& builtinTypes = simplifier->builtinTypes; + auto& arena = simplifier->arena; + + if (FFlag::DebugLuauLogSimplification) + printf(">> simplify %s\n", toString(ty).c_str()); + + while (progressed && count < MAX_COUNT) + { + progressed = false; + worklist.clear(); + seen.clear(); + + rootId = egraph.find(rootId); + + worklist.push_back(rootId); + + if (FFlag::DebugLuauLogSimplification) + { + std::vector newTypeFunctions; + const TypeId t = fromId(egraph, simplifier->stringCache, builtinTypes, arena, newTypeFunctions, rootId); + + std::cout << "Begin (" << uint32_t(egraph.find(rootId)) << ")\t" << toString(t) << '\n'; + } + + while (!worklist.empty() && count < MAX_COUNT) + { + Id id = egraph.find(worklist.front()); + worklist.pop_front(); + + const bool isFresh = seen.insert(id).second; + if (!isFresh) + continue; + + simplifier->substs.clear(); + + // Optimization: If this class alraedy has a terminal node, don't + // try to run any rules on it. + bool shouldAbort = false; + + for (const EType& enode : egraph[id].nodes) + { + if (isTerminal(enode)) + { + shouldAbort = true; + break; + } + } + + if (shouldAbort) + continue; + + for (const EType& enode : egraph[id].nodes) + addChildren(egraph, &enode, worklist); + + for (Simplifier::RewriteRuleFn rule : rules) + (simplifier.get()->*rule)(id); + + if (simplifier->substs.empty()) + continue; + + for (const Subst& subst : simplifier->substs) + { + if (subst.newClass == subst.eclass) + continue; + + if (FFlag::DebugLuauExtraEqSatSanityChecks) + { + const Id never = egraph.find(egraph.add(TNever{})); + const Id str = egraph.find(egraph.add(TString{})); + const Id unk = egraph.find(egraph.add(TUnknown{})); + LUAU_ASSERT(never != str); + LUAU_ASSERT(never != unk); + } + + const bool isFresh = egraph.merge(subst.newClass, subst.eclass); + + ++count; + + if (FFlag::DebugLuauLogSimplification) + { + if (isFresh) + std::cout << "count=" << std::setw(3) << count << "\t" << subst.desc << '\n'; + + std::string filename = format("step%03d.dot", count); + std::ofstream(filename) << toDot(simplifier->stringCache, egraph); + } + + if (FFlag::DebugLuauExtraEqSatSanityChecks) + { + const Id never = egraph.find(egraph.add(TNever{})); + const Id str = egraph.find(egraph.add(TString{})); + const Id unk = egraph.find(egraph.add(TUnknown{})); + const Id trueId = egraph.find(egraph.add(SBoolean{true})); + + LUAU_ASSERT(never != str); + LUAU_ASSERT(never != unk); + LUAU_ASSERT(never != trueId); + } + + progressed |= isFresh; + } + + egraph.rebuild(); + } + } + + EqSatSimplificationResult result; + result.result = fromId(egraph, simplifier->stringCache, builtinTypes, arena, result.newTypeFunctions, rootId); + + if (FFlag::DebugLuauLogSimplification) + printf("<< simplify %s\n", toString(result.result).c_str()); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp index d4f3ebd9..3395f125 100644 --- a/Analysis/src/FragmentAutocomplete.cpp +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/AstQuery.h" #include "Luau/Common.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Parser.h" #include "Luau/ParseOptions.h" #include "Luau/Module.h" @@ -18,11 +19,14 @@ #include "Luau/ParseOptions.h" #include "Luau/Module.h" +#include "AutocompleteCore.h" + LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauStoreDFGOnModule2); +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) namespace { @@ -41,7 +45,6 @@ void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) +/** + * Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that + * document and attempts to get the concrete text between those points. It returns a pair of: + * - start offset that represents an index in the source `char*` corresponding to startPos + * - length, that represents how many more bytes to read to get to endPos. + * Example - your document is "foo bar baz" and getDocumentOffsets is passed (1, 4) - (1, 8). This function returns the pair {3, 7}, + * which corresponds to the string " bar " + */ +std::pair getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) { - unsigned int lineCount = 0; - unsigned int colCount = 0; + size_t lineCount = 0; + size_t colCount = 0; - unsigned int docOffset = 0; - unsigned int startOffset = 0; - unsigned int endOffset = 0; + size_t docOffset = 0; + size_t startOffset = 0; + size_t endOffset = 0; bool foundStart = false; bool foundEnd = false; for (char c : src) @@ -115,6 +126,13 @@ std::pair getDocumentOffsets(const std::string_view& foundEnd = true; } + // We put a cursor position that extends beyond the extents of the current line + if (foundStart && !foundEnd && (lineCount > endPos.line)) + { + foundEnd = true; + endOffset = docOffset - 1; + } + if (c == '\n') { lineCount++; @@ -125,20 +143,24 @@ std::pair getDocumentOffsets(const std::string_view& docOffset++; } + if (foundStart && !foundEnd) + endOffset = src.length(); - unsigned int min = std::min(startOffset, endOffset); - unsigned int len = std::max(startOffset, endOffset) - min; + size_t min = std::min(startOffset, endOffset); + size_t len = std::max(startOffset, endOffset) - min; return {min, len}; } -ScopePtr findClosestScope(const ModulePtr& module, const Position& cursorPos) +ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement) { LUAU_ASSERT(module->hasModuleScope()); ScopePtr closest = module->getModuleScope(); + + // find the scope the nearest statement belonged to. for (auto [loc, sc] : module->scopes) { - if (loc.begin <= cursorPos && closest->location.begin <= loc.begin) + if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin) closest = sc; } @@ -152,13 +174,27 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie opts.allowDeclarationSyntax = false; opts.captureComments = false; opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack)}; - AstStat* enclosingStatement = result.nearestStatement; + AstStat* nearestStatement = result.nearestStatement; - const Position& endPos = cursorPos; - // If the statement starts on a previous line, grab the statement beginning - // otherwise, grab the statement end to whatever is being typed right now - const Position& startPos = - enclosingStatement->location.begin.line == cursorPos.line ? enclosingStatement->location.begin : enclosingStatement->location.end; + const Location& rootSpan = srcModule.root->location; + // Did we append vs did we insert inline + bool appended = cursorPos >= rootSpan.end; + // statement spans multiple lines + bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line; + + const Position endPos = cursorPos; + + // We start by re-parsing everything (we'll refine this as we go) + Position startPos = srcModule.root->location.begin; + + // If we added to the end of the sourceModule, use the end of the nearest location + if (appended && multiline) + startPos = nearestStatement->location.end; + // Statement spans one line && cursorPos is on a different line + else if (!multiline && cursorPos.line != nearestStatement->location.end.line) + startPos = nearestStatement->location.end; + else + startPos = nearestStatement->location.begin; auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); @@ -173,10 +209,11 @@ FragmentParseResult parseFragment(const SourceModule& srcModule, std::string_vie std::vector fabricatedAncestry = std::move(result.ancestry); std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, p.root->location.end); fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); - if (enclosingStatement == nullptr) - enclosingStatement = p.root; + if (nearestStatement == nullptr) + nearestStatement = p.root; fragmentResult.root = std::move(p.root); fragmentResult.ancestry = std::move(fabricatedAncestry); + fragmentResult.nearestStatement = nearestStatement; return fragmentResult; } @@ -205,7 +242,7 @@ ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) return incrementalModule; } -FragmentTypeCheckResult typeCheckFragmentHelper( +FragmentTypeCheckResult typecheckFragment_( Frontend& frontend, AstStatBlock* root, const ModulePtr& stale, @@ -245,15 +282,18 @@ FragmentTypeCheckResult typeCheckFragmentHelper( /// Create a DataFlowGraph just for the surrounding context auto updatedDfg = DataFlowGraphBuilder::updateGraph(*stale->dataFlowGraph.get(), stale->dfgScopes, root, cursorPos, iceHandler); + SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes); + /// Contraint Generator ConstraintGenerator cg{ incrementalModule, NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull{&frontend.moduleResolver}, frontend.builtinTypes, iceHandler, - frontend.globals.globalScope, + stale->getModuleScope(), nullptr, nullptr, NotNull{&updatedDfg}, @@ -262,7 +302,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper( cg.rootScope = stale->getModuleScope().get(); // Any additions to the scope must occur in a fresh scope auto freshChildOfNearestScope = std::make_shared(closestScope); - incrementalModule->scopes.push_back({root->location, freshChildOfNearestScope}); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); // closest Scope -> children = { ...., freshChildOfNearestScope} // We need to trim nearestChild from the scope hierarcy @@ -274,9 +314,11 @@ FragmentTypeCheckResult typeCheckFragmentHelper( LUAU_ASSERT(back == freshChildOfNearestScope.get()); closestScope->children.pop_back(); + /// Initialize the constraint solver and run it ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), @@ -307,7 +349,7 @@ FragmentTypeCheckResult typeCheckFragmentHelper( freeze(incrementalModule->internalTypes); freeze(incrementalModule->interfaceTypes); - return {std::move(incrementalModule), freshChildOfNearestScope.get()}; + return {std::move(incrementalModule), std::move(freshChildOfNearestScope)}; } @@ -327,27 +369,51 @@ FragmentTypeCheckResult typecheckFragment( } ModulePtr module = frontend.moduleResolver.getModule(moduleName); - const ScopePtr& closestScope = findClosestScope(module, cursorPos); - - - FragmentParseResult r = parseFragment(*sourceModule, src, cursorPos); + FragmentParseResult parseResult = parseFragment(*sourceModule, src, cursorPos); FrontendOptions frontendOptions = opts.value_or(frontend.options); - return typeCheckFragmentHelper(frontend, r.root, module, closestScope, cursorPos, std::move(r.alloc), frontendOptions); + const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement); + FragmentTypeCheckResult result = + typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions); + result.ancestry = std::move(parseResult.ancestry); + return result; } -AutocompleteResult fragmentAutocomplete( + +FragmentAutocompleteResult fragmentAutocomplete( Frontend& frontend, std::string_view src, const ModuleName& moduleName, - Position& cursorPosition, - const FrontendOptions& opts, + Position cursorPosition, + std::optional opts, StringCompletionCallback callback ) { LUAU_ASSERT(FFlag::LuauSolverV2); LUAU_ASSERT(FFlag::LuauAllowFragmentParsing); LUAU_ASSERT(FFlag::LuauStoreDFGOnModule2); - return {}; + LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); + + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + { + LUAU_ASSERT(!"Expected Source Module for fragment typecheck"); + return {}; + } + + auto tcResult = typecheckFragment(frontend, moduleName, cursorPosition, opts, src); + TypeArena arenaForFragmentAutocomplete; + auto result = Luau::autocomplete_( + tcResult.incrementalModule, + frontend.builtinTypes, + &arenaForFragmentAutocomplete, + tcResult.ancestry, + frontend.globals.globalScope.get(), + tcResult.freshScope, + cursorPosition, + frontend.fileResolver, + callback + ); + return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)}; } } // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index e94b4a29..261e3781 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -10,6 +10,7 @@ #include "Luau/ConstraintSolver.h" #include "Luau/DataFlowGraph.h" #include "Luau/DcrLogger.h" +#include "Luau/EqSatSimplification.h" #include "Luau/FileResolver.h" #include "Luau/NonStrictTypeChecker.h" #include "Luau/Parser.h" @@ -46,7 +47,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode) LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) -LUAU_FASTFLAGVARIABLE(LuauMoreThoroughCycleDetection) LUAU_FASTFLAG(StudioReportLuauAny2) LUAU_FASTFLAGVARIABLE(LuauStoreDFGOnModule2) @@ -287,8 +287,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector getRequireCycles( const FileResolver* resolver, const std::unordered_map>& sourceNodes, - const SourceNode* start, - bool stopAtFirst = false + const SourceNode* start ) { std::vector result; @@ -358,9 +357,6 @@ std::vector getRequireCycles( { result.push_back({depLocation, std::move(cycle)}); - if (stopAtFirst) - return result; - // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start // so it's safe to *only* clear seen vector when we find a cycle // if we don't do it, we will not have correct reporting for some cycles @@ -884,18 +880,11 @@ void Frontend::addBuildQueueItems( data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; - const Mode mode = sourceModule->mode.value_or(data.config.mode); - // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - { - if (FFlag::LuauMoreThoroughCycleDetection) - data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), false); - else - data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); - } + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get()); data.options = frontendOptions; @@ -1334,6 +1323,7 @@ ModulePtr check( unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; + SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes); TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) @@ -1342,6 +1332,7 @@ ModulePtr check( ConstraintGenerator cg{ result, NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, moduleResolver, builtinTypes, @@ -1358,6 +1349,7 @@ ModulePtr check( ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index dd5a2f85..1618b78f 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -132,7 +132,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a return dest.addType(NegationType{a.ty}); else if constexpr (std::is_same_v) { - TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName}; + TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData}; return dest.addType(std::move(clone)); } else diff --git a/Analysis/src/Symbol.cpp b/Analysis/src/Symbol.cpp index 5e5b9d8c..a5117608 100644 --- a/Analysis/src/Symbol.cpp +++ b/Analysis/src/Symbol.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauSymbolEquality) namespace Luau { @@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const return local == rhs.local; else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - else if (FFlag::LuauSolverV2) + else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality) return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. else return false; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 0bb7344a..60ed3027 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -870,6 +870,8 @@ struct TypeStringifier return; } + LUAU_ASSERT(uv.options.size() > 1); + bool optional = false; bool hasNonNilDisjunct = false; @@ -878,7 +880,7 @@ struct TypeStringifier { el = follow(el); - if (isNil(el)) + if (state.opts.useQuestionMarks && isNil(el)) { optional = true; continue; diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 0193f4f1..d0ad82ec 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -51,6 +51,7 @@ LUAU_FASTFLAG(LuauUserDefinedTypeFunctionNoEvaluation) LUAU_FASTFLAG(LuauUserTypeFunFixRegister) LUAU_FASTFLAG(LuauRemoveNotAnyHack) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionResetState) +LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) LUAU_DYNAMIC_FASTINT(LuauTypeSolverRelease) @@ -610,10 +611,29 @@ TypeFunctionReductionResult userDefinedTypeFunction( NotNull ctx ) { - if (!ctx->userFuncName) + auto typeFunction = getMutable(instance); + + if (FFlag::LuauUserTypeFunExportedAndLocal) { - ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; + if (typeFunction->userFuncData.owner.expired()) + { + ctx->ice->ice("user-defined type function module has expired"); + return {std::nullopt, true, {}, {}}; + } + + if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition) + { + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, true, {}, {}}; + } + } + else + { + if (!ctx->userFuncName) + { + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, true, {}, {}}; + } } if (FFlag::LuauUserDefinedTypeFunctionNoEvaluation) @@ -632,7 +652,22 @@ TypeFunctionReductionResult userDefinedTypeFunction( return {std::nullopt, false, {ty}, {}}; } - AstName name = *ctx->userFuncName; + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Ensure that whole type function environment is registered + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition)) + { + // Failure to register at this point means that original definition had to error out and should not have been present in the + // environment + ctx->ice->ice("user-defined type function reference cannot be registered"); + return {std::nullopt, true, {}, {}}; + } + } + } + + AstName name = FFlag::LuauUserTypeFunExportedAndLocal ? typeFunction->userFuncData.definition->name : *ctx->userFuncName; lua_State* global = ctx->typeFunctionRuntime->state.get(); @@ -643,8 +678,44 @@ TypeFunctionReductionResult userDefinedTypeFunction( lua_State* L = lua_newthread(global); LuauTempThreadPopper popper(global); - lua_getglobal(global, name.value); - lua_xmove(global, L, 1); + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Fetch the function we want to evaluate + lua_pushlightuserdata(L, typeFunction->userFuncData.definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, true, {}, {}}; + } + + // Build up the environment + lua_getfenv(L, -1); + lua_setreadonly(L, -1, false); + + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + lua_pushlightuserdata(L, definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, true, {}, {}}; + } + + lua_setfield(L, -2, name.c_str()); + } + + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + } + else + { + lua_getglobal(global, name.value); + lua_xmove(global, L, 1); + } if (FFlag::LuauUserDefinedTypeFunctionResetState) resetTypeFunctionState(L); @@ -693,7 +764,7 @@ TypeFunctionReductionResult userDefinedTypeFunction( TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); - // At least 1 error occured while deserializing + // At least 1 error occurred while deserializing if (runtimeBuilder->errors.size() > 0) return {std::nullopt, true, {}, {}, runtimeBuilder->errors.front()}; @@ -935,6 +1006,23 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc prepareState(); + lua_State* global = state.get(); + + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Fetch to check if function is already registered + lua_pushlightuserdata(global, function); + lua_gettable(global, LUA_REGISTRYINDEX); + + if (!lua_isnil(global, -1)) + { + lua_pop(global, 1); + return std::nullopt; + } + + lua_pop(global, 1); + } + AstName name = function->name; // Construct ParseResult containing the type function @@ -961,7 +1049,6 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc std::string bytecode = builder.getBytecode(); - lua_State* global = state.get(); // Separate sandboxed thread for individual execution and private globals lua_State* L = lua_newthread(global); @@ -989,9 +1076,19 @@ std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunc return format("Could not find '%s' type function in the global scope", name.value); } - // Store resulting function in the global environment - lua_xmove(L, global, 1); - lua_setglobal(global, name.value); + if (FFlag::LuauUserTypeFunExportedAndLocal) + { + // Store resulting function in the registry + lua_pushlightuserdata(global, function); + lua_xmove(L, global, 1); + lua_settable(global, LUA_REGISTRYINDEX); + } + else + { + // Store resulting function in the global environment + lua_xmove(L, global, 1); + lua_setglobal(global, name.value); + } return std::nullopt; } diff --git a/Ast/include/Luau/Allocator.h b/Ast/include/Luau/Allocator.h new file mode 100644 index 00000000..7fd951ae --- /dev/null +++ b/Ast/include/Luau/Allocator.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Location.h" +#include "Luau/DenseHash.h" +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +class Allocator +{ +public: + Allocator(); + Allocator(Allocator&&); + + Allocator& operator=(Allocator&&) = delete; + + ~Allocator(); + + void* allocate(size_t size); + + template + T* alloc(Args&&... args) + { + static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); + + T* t = static_cast(allocate(sizeof(T))); + new (t) T(std::forward(args)...); + return t; + } + +private: + struct Page + { + Page* next; + + char data[8192]; + }; + + Page* root; + size_t offset; +}; + +} diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 7845cca2..736f24a2 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -316,16 +316,18 @@ public: enum QuoteStyle { - Quoted, + QuotedSimple, + QuotedRaw, Unquoted }; - AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle); void visit(AstVisitor* visitor) override; + bool isQuoted() const; AstArray value; - QuoteStyle quoteStyle = Quoted; + QuoteStyle quoteStyle; }; class AstExprLocal : public AstExpr @@ -876,13 +878,14 @@ class AstStatTypeFunction : public AstStat public: LUAU_RTTI(AstStatTypeFunction); - AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); + AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported); void visit(AstVisitor* visitor) override; AstName name; Location nameLocation; AstExprFunction* body; + bool exported; }; class AstStatDeclareGlobal : public AstStat diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index f6ac28ad..6c8f21c1 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Allocator.h" #include "Luau/Ast.h" #include "Luau/Location.h" #include "Luau/DenseHash.h" @@ -11,40 +12,6 @@ namespace Luau { -class Allocator -{ -public: - Allocator(); - Allocator(Allocator&&); - - Allocator& operator=(Allocator&&) = delete; - - ~Allocator(); - - void* allocate(size_t size); - - template - T* alloc(Args&&... args) - { - static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); - - T* t = static_cast(allocate(sizeof(T))); - new (t) T(std::forward(args)...); - return t; - } - -private: - struct Page - { - Page* next; - - char data[8192]; - }; - - Page* root; - size_t offset; -}; - struct Lexeme { enum Type diff --git a/Ast/src/Allocator.cpp b/Ast/src/Allocator.cpp new file mode 100644 index 00000000..f8a99db4 --- /dev/null +++ b/Ast/src/Allocator.cpp @@ -0,0 +1,66 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Allocator.h" + +namespace Luau +{ + +Allocator::Allocator() + : root(static_cast(operator new(sizeof(Page)))) + , offset(0) +{ + root->next = nullptr; +} + +Allocator::Allocator(Allocator&& rhs) + : root(rhs.root) + , offset(rhs.offset) +{ + rhs.root = nullptr; + rhs.offset = 0; +} + +Allocator::~Allocator() +{ + Page* page = root; + + while (page) + { + Page* next = page->next; + + operator delete(page); + + page = next; + } +} + +void* Allocator::allocate(size_t size) +{ + constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); + + if (root) + { + uintptr_t data = reinterpret_cast(root->data); + uintptr_t result = (data + offset + align - 1) & ~(align - 1); + if (result + size <= data + sizeof(root->data)) + { + offset = result - data + size; + return reinterpret_cast(result); + } + } + + // allocate new page + size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); + void* pageData = operator new(offsetof(Page, data) + pageSize); + + Page* page = static_cast(pageData); + + page->next = root; + + root = page; + offset = size; + + return page->data; +} + +} diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index a72aca86..a06fcb09 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -92,6 +92,11 @@ void AstExprConstantString::visit(AstVisitor* visitor) visitor->visit(this); } +bool AstExprConstantString::isQuoted() const +{ + return quoteStyle == QuoteStyle::QuotedSimple || quoteStyle == QuoteStyle::QuotedRaw; +} + AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue) : AstExpr(ClassIndex(), location) , local(local) @@ -760,11 +765,18 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) } } -AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) +AstStatTypeFunction::AstStatTypeFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + AstExprFunction* body, + bool exported +) : AstStat(ClassIndex(), location) , name(name) , nameLocation(nameLocation) , body(body) + , exported(exported) { } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 54540215..4fb9c936 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Lexer.h" +#include "Luau/Allocator.h" #include "Luau/Common.h" #include "Luau/Confusables.h" #include "Luau/StringUtils.h" @@ -10,64 +11,6 @@ namespace Luau { -Allocator::Allocator() - : root(static_cast(operator new(sizeof(Page)))) - , offset(0) -{ - root->next = nullptr; -} - -Allocator::Allocator(Allocator&& rhs) - : root(rhs.root) - , offset(rhs.offset) -{ - rhs.root = nullptr; - rhs.offset = 0; -} - -Allocator::~Allocator() -{ - Page* page = root; - - while (page) - { - Page* next = page->next; - - operator delete(page); - - page = next; - } -} - -void* Allocator::allocate(size_t size) -{ - constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); - - if (root) - { - uintptr_t data = reinterpret_cast(root->data); - uintptr_t result = (data + offset + align - 1) & ~(align - 1); - if (result + size <= data + sizeof(root->data)) - { - offset = result - data + size; - return reinterpret_cast(result); - } - } - - // allocate new page - size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); - void* pageData = operator new(offsetof(Page, data) + pageSize); - - Page* page = static_cast(pageData); - - page->next = root; - - root = page; - offset = size; - - return page->data; -} - Lexeme::Lexeme(const Location& location, Type type) : type(type) , location(location) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1ca028f2..02d17c1d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -21,6 +21,7 @@ LUAU_FASTFLAGVARIABLE(LuauSolverV2) LUAU_FASTFLAGVARIABLE(LuauNativeAttribute) LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunParseExport) LUAU_FASTFLAGVARIABLE(LuauAllowFragmentParsing) LUAU_FASTFLAGVARIABLE(LuauPortableStringZeroCheck) @@ -943,8 +944,11 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported) Lexeme matchFn = lexer.current(); nextLexeme(); - if (exported) - report(start, "Type function cannot be exported"); + if (!FFlag::LuauUserDefinedTypeFunParseExport) + { + if (exported) + report(start, "Type function cannot be exported"); + } // parse the name of the type function std::optional fnName = parseNameOpt("type function name"); @@ -962,7 +966,7 @@ AstStat* Parser::parseTypeFunction(const Location& start, bool exported) matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; - return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body); + return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -3012,8 +3016,23 @@ std::optional> Parser::parseCharArray() AstExpr* Parser::parseString() { Location location = lexer.current().location; + + AstExprConstantString::QuoteStyle style; + switch (lexer.current().type) + { + case Lexeme::QuotedString: + case Lexeme::InterpStringSimple: + style = AstExprConstantString::QuotedSimple; + break; + case Lexeme::RawString: + style = AstExprConstantString::QuotedRaw; + break; + default: + LUAU_ASSERT(false && "Invalid string type"); + } + if (std::optional> value = parseCharArray()) - return allocator.alloc(location, *value); + return allocator.alloc(location, *value, style); else return reportExprError(location, {}, "String literal contains malformed escape sequence"); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index be1f23f0..80ede2d0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Config.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" @@ -224,7 +225,14 @@ struct CliConfigResolver : Luau::ConfigResolver if (std::optional contents = readFile(configPath)) { - std::optional error = Luau::parseConfig(*contents, result); + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = configPath; + aliasOpts.overwriteAliases = true; + + Luau::ConfigOptions opts; + opts.aliasOptions = std::move(aliasOpts); + + std::optional error = Luau::parseConfig(*contents, result, opts); if (error) configErrors.push_back({configPath, *error}); } diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index e9f40a09..4906d55a 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -181,6 +181,16 @@ std::string resolvePath(std::string_view path, std::string_view baseFilePath) return resolvedPath; } +bool hasFileExtension(std::string_view name, const std::vector& extensions) +{ + for (const std::string& extension : extensions) + { + if (name.size() >= extension.size() && name.substr(name.size() - extension.size()) == extension) + return true; + } + return false; +} + std::optional readFile(const std::string& name) { #ifdef _WIN32 diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index dce94ace..f723c765 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -15,6 +15,8 @@ std::string resolvePath(std::string_view relativePath, std::string_view baseFile std::optional readFile(const std::string& name); std::optional readStdin(); +bool hasFileExtension(std::string_view name, const std::vector& extensions); + bool isAbsolutePath(std::string_view path); bool isFile(const std::string& path); bool isDirectory(const std::string& path); diff --git a/CLI/Require.cpp b/CLI/Require.cpp index 9a00597a..2c45d0ac 100644 --- a/CLI/Require.cpp +++ b/CLI/Require.cpp @@ -3,6 +3,7 @@ #include "FileUtils.h" #include "Luau/Common.h" +#include "Luau/Config.h" #include #include @@ -83,6 +84,9 @@ RequireResolver::ModuleStatus RequireResolver::findModuleImpl() absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix } + if (hasFileExtension(absolutePath, {".luau", ".lua"}) && isFile(absolutePath)) + luaL_argerrorL(L, 1, "error requiring module: consider removing the file extension"); + return ModuleStatus::NotFound; } @@ -235,14 +239,15 @@ std::optional RequireResolver::getAlias(std::string alias) return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; } ); - while (!config.aliases.count(alias) && !isConfigFullyResolved) + while (!config.aliases.contains(alias) && !isConfigFullyResolved) { parseNextConfig(); } - if (!config.aliases.count(alias) && isConfigFullyResolved) + if (!config.aliases.contains(alias) && isConfigFullyResolved) return std::nullopt; // could not find alias - return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); + const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias]; + return resolvePath(aliasInfo.value, aliasInfo.configLocation); } void RequireResolver::parseNextConfig() @@ -275,9 +280,16 @@ void RequireResolver::parseConfigInDirectory(const std::string& directory) { std::string configPath = joinPaths(directory, Luau::kConfigName); + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = configPath; + aliasOpts.overwriteAliases = false; + + Luau::ConfigOptions opts; + opts.aliasOptions = std::move(aliasOpts); + if (std::optional contents = readFile(configPath)) { - std::optional error = Luau::parseConfig(*contents, config); + std::optional error = Luau::parseConfig(*contents, config, opts); if (error) luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); } diff --git a/Common/include/Luau/Variant.h b/Common/include/Luau/Variant.h index 88722257..14eb8c4e 100644 --- a/Common/include/Luau/Variant.h +++ b/Common/include/Luau/Variant.h @@ -19,7 +19,7 @@ class Variant static_assert(std::disjunction_v...> == false, "variant does not allow references as an alternative type"); static_assert(std::disjunction_v...> == false, "variant does not allow arrays as an alternative type"); -private: +public: template static constexpr int getTypeId() { @@ -35,6 +35,7 @@ private: return -1; } +private: template struct First { diff --git a/Config/include/Luau/Config.h b/Config/include/Luau/Config.h index 3866547b..d6016229 100644 --- a/Config/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -1,12 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/LinterConfig.h" #include "Luau/ParseOptions.h" +#include #include #include -#include +#include #include namespace Luau @@ -19,6 +21,10 @@ constexpr const char* kConfigName = ".luaurc"; struct Config { Config(); + Config(const Config& other) noexcept; + Config& operator=(const Config& other) noexcept; + Config(Config&& other) noexcept = default; + Config& operator=(Config&& other) noexcept = default; Mode mode = Mode::Nonstrict; @@ -32,7 +38,19 @@ struct Config std::vector globals; - std::unordered_map aliases; + struct AliasInfo + { + std::string value; + std::string_view configLocation; + }; + + DenseHashMap aliases{""}; + + void setAlias(std::string alias, const std::string& value, const std::string configLocation); + +private: + // Prevents making unnecessary copies of the same config location string. + DenseHashMap> configLocationCache{""}; }; struct ConfigResolver @@ -60,6 +78,18 @@ std::optional parseLintRuleString( bool isValidAlias(const std::string& alias); -std::optional parseConfig(const std::string& contents, Config& config, bool compat = false); +struct ConfigOptions +{ + bool compat = false; + + struct AliasOptions + { + std::string configLocation; + bool overwriteAliases; + }; + std::optional aliasOptions = std::nullopt; +}; + +std::optional parseConfig(const std::string& contents, Config& config, const ConfigOptions& options = ConfigOptions{}); } // namespace Luau diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index cf7d4b22..3760fd9e 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -4,7 +4,8 @@ #include "Luau/Lexer.h" #include "Luau/StringUtils.h" #include -#include +#include +#include namespace Luau { @@ -16,6 +17,50 @@ Config::Config() enabledLint.setDefaults(); } +Config::Config(const Config& other) noexcept + : mode(other.mode) + , parseOptions(other.parseOptions) + , enabledLint(other.enabledLint) + , fatalLint(other.fatalLint) + , lintErrors(other.lintErrors) + , typeErrors(other.typeErrors) + , globals(other.globals) +{ + for (const auto& [alias, aliasInfo] : other.aliases) + { + std::string configLocation = std::string(aliasInfo.configLocation); + + if (!configLocationCache.contains(configLocation)) + configLocationCache[configLocation] = std::make_unique(configLocation); + + AliasInfo newAliasInfo; + newAliasInfo.value = aliasInfo.value; + newAliasInfo.configLocation = *configLocationCache[configLocation]; + aliases[alias] = std::move(newAliasInfo); + } +} + +Config& Config::operator=(const Config& other) noexcept +{ + if (this != &other) + { + Config copy(other); + std::swap(*this, copy); + } + return *this; +} + +void Config::setAlias(std::string alias, const std::string& value, const std::string configLocation) +{ + AliasInfo& info = aliases[alias]; + info.value = value; + + if (!configLocationCache.contains(configLocation)) + configLocationCache[configLocation] = std::make_unique(configLocation); + + info.configLocation = *configLocationCache[configLocation]; +} + static Error parseBoolean(bool& result, const std::string& value) { if (value == "true") @@ -136,7 +181,12 @@ bool isValidAlias(const std::string& alias) return true; } -Error parseAlias(std::unordered_map& aliases, std::string aliasKey, const std::string& aliasValue) +static Error parseAlias( + Config& config, + std::string aliasKey, + const std::string& aliasValue, + const std::optional& aliasOptions +) { if (!isValidAlias(aliasKey)) return Error{"Invalid alias " + aliasKey}; @@ -150,8 +200,12 @@ Error parseAlias(std::unordered_map& aliases, std::str return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; } ); - if (!aliases.count(aliasKey)) - aliases[std::move(aliasKey)] = aliasValue; + + if (!aliasOptions) + return Error("Cannot parse aliases without alias options"); + + if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey)) + config.setAlias(std::move(aliasKey), aliasValue, aliasOptions->configLocation); return std::nullopt; } @@ -285,16 +339,16 @@ static Error parseJson(const std::string& contents, Action action) return {}; } -Error parseConfig(const std::string& contents, Config& config, bool compat) +Error parseConfig(const std::string& contents, Config& config, const ConfigOptions& options) { return parseJson( contents, [&](const std::vector& keys, const std::string& value) -> Error { if (keys.size() == 1 && keys[0] == "languageMode") - return parseModeString(config.mode, value, compat); + return parseModeString(config.mode, value, options.compat); else if (keys.size() == 2 && keys[0] == "lint") - return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat); + return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, options.compat); else if (keys.size() == 1 && keys[0] == "lintErrors") return parseBoolean(config.lintErrors, value); else if (keys.size() == 1 && keys[0] == "typeErrors") @@ -305,9 +359,9 @@ Error parseConfig(const std::string& contents, Config& config, bool compat) return std::nullopt; } else if (keys.size() == 2 && keys[0] == "aliases") - return parseAlias(config.aliases, keys[1], value); - else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") - return parseModeString(config.mode, value, compat); + return parseAlias(config, keys[1], value, options.aliasOptions); + else if (options.compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") + return parseModeString(config.mode, value, options.compat); else { std::vector keysv(keys.begin(), keys.end()); diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 480aa07d..924da974 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -23,6 +23,13 @@ struct Analysis final using D = typename N::Data; + Analysis() = default; + + Analysis(N a) + : analysis(std::move(a)) + { + } + template static D fnMake(const N& analysis, const EGraph& egraph, const L& enode) { @@ -59,6 +66,15 @@ struct EClass final template struct EGraph final { + using EClassT = EClass; + + EGraph() = default; + + explicit EGraph(N analysis) + : analysis(std::move(analysis)) + { + } + Id find(Id id) const { return unionfind.find(id); @@ -85,33 +101,59 @@ struct EGraph final return id; } - void merge(Id id1, Id id2) + // Returns true if the two IDs were not previously merged. + bool merge(Id id1, Id id2) { id1 = find(id1); id2 = find(id2); if (id1 == id2) - return; + return false; - unionfind.merge(id1, id2); + const Id mergedId = unionfind.merge(id1, id2); - EClass& eclass1 = get(id1); - EClass eclass2 = std::move(get(id2)); + // Ensure that id1 is the Id that we keep, and id2 is the id that we drop. + if (mergedId == id2) + std::swap(id1, id2); + + EClassT& eclass1 = get(id1); + EClassT eclass2 = std::move(get(id2)); classes.erase(id2); - worklist.reserve(worklist.size() + eclass2.parents.size()); - for (auto [enode, id] : eclass2.parents) - worklist.push_back({std::move(enode), id}); + eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end()); + eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end()); + + std::sort( + eclass1.nodes.begin(), + eclass1.nodes.end(), + [](const L& left, const L& right) + { + return left.index() < right.index(); + } + ); + + worklist.reserve(worklist.size() + eclass1.parents.size()); + for (const auto& [eclass, id] : eclass1.parents) + worklist.push_back(id); analysis.join(eclass1.data, eclass2.data); + + return true; } void rebuild() { + std::unordered_set seen; + while (!worklist.empty()) { - auto [enode, id] = worklist.back(); + Id id = worklist.back(); worklist.pop_back(); - repair(get(find(id))); + + const bool isFresh = seen.insert(id).second; + if (!isFresh) + continue; + + repair(find(id)); } } @@ -120,16 +162,21 @@ struct EGraph final return classes.size(); } - EClass& operator[](Id id) + EClassT& operator[](Id id) { return get(find(id)); } - const EClass& operator[](Id id) const + const EClassT& operator[](Id id) const { return const_cast(this)->get(find(id)); } + const std::unordered_map& getAllClasses() const + { + return classes; + } + private: Analysis analysis; @@ -139,19 +186,19 @@ private: /// The e-class map 𝑀 maps e-class ids to e-classes. All equivalent e-class ids map to the same /// e-class, i.e., 𝑎 ≡id 𝑏 iff 𝑀[𝑎] is the same set as 𝑀[𝑏]. An e-class id 𝑎 is said to refer to the /// e-class 𝑀[find(𝑎)]. - std::unordered_map> classes; + std::unordered_map classes; /// The hashcons 𝐻 is a map from e-nodes to e-class ids. std::unordered_map hashcons; - std::vector> worklist; + std::vector worklist; private: void canonicalize(L& enode) { // An e-node 𝑛 is canonical iff 𝑛 = canonicalize(𝑛), where // canonicalize(𝑓(𝑎1, 𝑎2, ...)) = 𝑓(find(𝑎1), find(𝑎2), ...). - for (Id& id : enode.operands()) + for (Id& id : enode.mutableOperands()) id = find(id); } @@ -171,7 +218,7 @@ private: classes.insert_or_assign( id, - EClass{ + EClassT{ id, {enode}, analysis.make(*this, enode), @@ -182,7 +229,7 @@ private: for (Id operand : enode.operands()) get(operand).parents.push_back({enode, id}); - worklist.emplace_back(enode, id); + worklist.emplace_back(id); hashcons.insert_or_assign(enode, id); return id; @@ -190,12 +237,13 @@ private: // Looks up for an eclass from a given non-canonicalized `id`. // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. - EClass& get(Id id) + EClassT& get(Id id) { + LUAU_ASSERT(classes.count(id)); return classes.at(id); } - void repair(EClass& eclass) + void repair(Id id) { // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. @@ -204,26 +252,54 @@ private: // Here, we unify the two loops. I think it's equivalent? // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. - std::unordered_map map; - for (auto& [enode, id] : eclass.parents) + std::unordered_map newParents; + + // The eclass can be deallocated if it is merged into another eclass, so + // we take what we need from it and avoid retaining a pointer. + std::vector> parents = get(id).parents; + for (auto& pair : parents) { + L& enode = pair.first; + Id id = pair.second; + // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. hashcons.erase(enode); canonicalize(enode); hashcons.insert_or_assign(enode, find(id)); - if (auto it = map.find(enode); it != map.end()) + if (auto it = newParents.find(enode); it != newParents.end()) merge(id, it->second); - map.insert_or_assign(enode, find(id)); + newParents.insert_or_assign(enode, find(id)); } - eclass.parents.clear(); - for (auto it = map.begin(); it != map.end();) + // We reacquire the pointer because the prior loop potentially merges + // the eclass into another, which might move it around in memory. + EClassT* eclass = &get(find(id)); + + eclass->parents.clear(); + + for (const auto& [node, id] : newParents) + eclass->parents.emplace_back(std::move(node), std::move(id)); + + std::unordered_set newNodes; + for (L node : eclass->nodes) { - auto node = map.extract(it++); - eclass.parents.emplace_back(std::move(node.key()), node.mapped()); + canonicalize(node); + newNodes.insert(std::move(node)); } + + eclass->nodes.assign(newNodes.begin(), newNodes.end()); + + // FIXME: Extract into sortByTag() + std::sort( + eclass->nodes.begin(), + eclass->nodes.end(), + [](const L& left, const L& right) + { + return left.index() < right.index(); + } + ); } }; diff --git a/EqSat/include/Luau/Id.h b/EqSat/include/Luau/Id.h index c56a6ab6..7069f23c 100644 --- a/EqSat/include/Luau/Id.h +++ b/EqSat/include/Luau/Id.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include namespace Luau::EqSat @@ -9,15 +10,17 @@ namespace Luau::EqSat struct Id final { - explicit Id(size_t id); + explicit Id(uint32_t id); - explicit operator size_t() const; + explicit operator uint32_t() const; bool operator==(Id rhs) const; bool operator!=(Id rhs) const; + bool operator<(Id rhs) const; + private: - size_t id; + uint32_t id; }; } // namespace Luau::EqSat diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index 8855d851..56fc7202 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -6,9 +6,19 @@ #include "Luau/Slice.h" #include "Luau/Variant.h" +#include #include #include +#include #include +#include + +#define LUAU_EQSAT_UNIT(name) \ + struct name : ::Luau::EqSat::Unit \ + { \ + static constexpr const char* tag = #name; \ + using Unit::Unit; \ + } #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ @@ -31,21 +41,57 @@ using NodeVector::NodeVector; \ } -#define LUAU_EQSAT_FIELD(name) \ - struct name : public ::Luau::EqSat::Field \ - { \ - } - -#define LUAU_EQSAT_NODE_FIELDS(name, ...) \ - struct name : public ::Luau::EqSat::NodeFields \ +#define LUAU_EQSAT_NODE_SET(name) \ + struct name : public ::Luau::EqSat::NodeSet> \ { \ static constexpr const char* tag = #name; \ - using NodeFields::NodeFields; \ + using NodeSet::NodeSet; \ + } + +#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \ + struct name : public ::Luau::EqSat::NodeAtomAndVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeAtomAndVector::NodeAtomAndVector; \ } namespace Luau::EqSat { +template +struct Unit +{ + Slice mutableOperands() + { + return {}; + } + + Slice operands() const + { + return {}; + } + + bool operator==(const Unit& rhs) const + { + return true; + } + + bool operator!=(const Unit& rhs) const + { + return false; + } + + struct Hash + { + size_t operator()(const Unit& value) const + { + // chosen by fair dice roll. + // guaranteed to be random. + return 4; + } + }; +}; + template struct Atom { @@ -60,7 +106,7 @@ struct Atom } public: - Slice operands() + Slice mutableOperands() { return {}; } @@ -92,6 +138,62 @@ private: T _value; }; +template +struct NodeAtomAndVector +{ + template + NodeAtomAndVector(const X& value, Args&&... args) + : _value(value) + , vector{std::forward(args)...} + { + } + + Id operator[](size_t i) const + { + return vector[i]; + } + +public: + const X& value() const + { + return _value; + } + + Slice mutableOperands() + { + return Slice{vector.data(), vector.size()}; + } + + Slice operands() const + { + return Slice{vector.data(), vector.size()}; + } + + bool operator==(const NodeAtomAndVector& rhs) const + { + return _value == rhs._value && vector == rhs.vector; + } + + bool operator!=(const NodeAtomAndVector& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const NodeAtomAndVector& value) const + { + size_t result = languageHash(value._value); + hashCombine(result, languageHash(value.vector)); + return result; + } + }; + +private: + X _value; + T vector; +}; + template struct NodeVector { @@ -107,7 +209,7 @@ struct NodeVector } public: - Slice operands() + Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } @@ -139,90 +241,61 @@ private: T vector; }; -/// Empty base class just for static_asserts. -struct FieldBase +template +struct NodeSet { - FieldBase() = delete; - - FieldBase(FieldBase&&) = delete; - FieldBase& operator=(FieldBase&&) = delete; - - FieldBase(const FieldBase&) = delete; - FieldBase& operator=(const FieldBase&) = delete; -}; - -template -struct Field : FieldBase -{ -}; - -template -struct NodeFields -{ - static_assert(std::conjunction...>::value); - - template - static constexpr int getIndex() + template + NodeSet(Args&&... args) + : vector{std::forward(args)...} { - constexpr int N = sizeof...(Fields); - constexpr bool is[N] = {std::is_same_v, Fields>...}; + std::sort(begin(vector), end(vector)); + auto it = std::unique(begin(vector), end(vector)); + vector.erase(it, end(vector)); + } - for (int i = 0; i < N; ++i) - if (is[i]) - return i; - - return -1; + Id operator[](size_t i) const + { + return vector[i]; } public: - template - NodeFields(Args&&... args) - : array{std::forward(args)...} + Slice mutableOperands() { - } - - Slice operands() - { - return Slice{array}; + return Slice{vector.data(), vector.size()}; } Slice operands() const { - return Slice{array.data(), array.size()}; + return Slice{vector.data(), vector.size()}; } - template - Id field() const + bool operator==(const NodeSet& rhs) const { - static_assert(std::disjunction_v, Fields>...>); - return array[getIndex()]; + return vector == rhs.vector; } - bool operator==(const NodeFields& rhs) const - { - return array == rhs.array; - } - - bool operator!=(const NodeFields& rhs) const + bool operator!=(const NodeSet& rhs) const { return !(*this == rhs); } struct Hash { - size_t operator()(const NodeFields& value) const + size_t operator()(const NodeSet& value) const { - return languageHash(value.array); + return languageHash(value.vector); } }; -private: - std::array array; +protected: + T vector; }; template struct Language final { + using VariantTy = Luau::Variant; + template using WithinDomain = std::disjunction, Ts>...>; @@ -237,14 +310,14 @@ struct Language final return v.index(); } - /// You should never call this function with the intention of mutating the `Id`. - /// Reading is ok, but you should also never assume that these `Id`s are stable. - Slice operands() noexcept + /// This should only be used in canonicalization! + /// Always prefer operands() + Slice mutableOperands() noexcept { return visit( [](auto&& v) -> Slice { - return v.operands(); + return v.mutableOperands(); }, v ); @@ -306,7 +379,7 @@ public: }; private: - Variant v; + VariantTy v; }; } // namespace Luau::EqSat diff --git a/EqSat/include/Luau/LanguageHash.h b/EqSat/include/Luau/LanguageHash.h index 506f352b..cfc33b83 100644 --- a/EqSat/include/Luau/LanguageHash.h +++ b/EqSat/include/Luau/LanguageHash.h @@ -3,6 +3,7 @@ #include #include +#include #include namespace Luau::EqSat diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index 559ee119..22a61628 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -14,7 +14,9 @@ struct UnionFind final Id makeSet(); Id find(Id id) const; Id find(Id id); - void merge(Id a, Id b); + + // Merge aSet with bSet and return the canonicalized Id into the merged set. + Id merge(Id aSet, Id bSet); private: std::vector parents; diff --git a/EqSat/src/Id.cpp b/EqSat/src/Id.cpp index 960249ba..eae6a974 100644 --- a/EqSat/src/Id.cpp +++ b/EqSat/src/Id.cpp @@ -1,15 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Id.h" +#include "Luau/Common.h" namespace Luau::EqSat { -Id::Id(size_t id) +Id::Id(uint32_t id) : id(id) { } -Id::operator size_t() const +Id::operator uint32_t() const { return id; } @@ -24,9 +25,14 @@ bool Id::operator!=(Id rhs) const return id != rhs.id; } +bool Id::operator<(Id rhs) const +{ + return id < rhs.id; +} + } // namespace Luau::EqSat size_t std::hash::operator()(Luau::EqSat::Id id) const { - return std::hash()(size_t(id)); + return std::hash()(uint32_t(id)); } diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 619c3f47..6a952999 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -3,12 +3,16 @@ #include "Luau/Common.h" +#include + namespace Luau::EqSat { Id UnionFind::makeSet() { - Id id{parents.size()}; + LUAU_ASSERT(parents.size() < std::numeric_limits::max()); + + Id id{uint32_t(parents.size())}; parents.push_back(id); ranks.push_back(0); @@ -25,42 +29,44 @@ Id UnionFind::find(Id id) Id set = canonicalize(id); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - while (id != parents[size_t(id)]) + while (id != parents[uint32_t(id)]) { // Note: we don't update the ranks here since a rank // represents the upper bound on the maximum depth of a tree - Id parent = parents[size_t(id)]; - parents[size_t(id)] = set; + Id parent = parents[uint32_t(id)]; + parents[uint32_t(id)] = set; id = parent; } return set; } -void UnionFind::merge(Id a, Id b) +Id UnionFind::merge(Id a, Id b) { Id aSet = find(a); Id bSet = find(b); if (aSet == bSet) - return; + return aSet; // Ensure that the rank of set A is greater than the rank of set B - if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + if (ranks[uint32_t(aSet)] > ranks[uint32_t(bSet)]) std::swap(aSet, bSet); - parents[size_t(bSet)] = aSet; + parents[uint32_t(bSet)] = aSet; - if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) - ranks[size_t(aSet)]++; + if (ranks[uint32_t(aSet)] == ranks[uint32_t(bSet)]) + ranks[uint32_t(aSet)]++; + + return aSet; } Id UnionFind::canonicalize(Id id) const { - LUAU_ASSERT(size_t(id) < parents.size()); + LUAU_ASSERT(uint32_t(id) < parents.size()); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - while (id != parents[size_t(id)]) - id = parents[size_t(id)]; + while (id != parents[uint32_t(id)]) + id = parents[uint32_t(id)]; return id; } diff --git a/Sources.cmake b/Sources.cmake index 4b99e867..1299b119 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -14,6 +14,7 @@ endif() # Luau.Ast Sources target_sources(Luau.Ast PRIVATE + Ast/include/Luau/Allocator.h Ast/include/Luau/Ast.h Ast/include/Luau/Confusables.h Ast/include/Luau/Lexer.h @@ -24,6 +25,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/StringUtils.h Ast/include/Luau/TimeTrace.h + Ast/src/Allocator.cpp Ast/src/Ast.cpp Ast/src/Confusables.cpp Ast/src/Lexer.cpp @@ -168,6 +170,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/AutocompleteTypes.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Clone.h @@ -181,6 +184,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Differ.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h + Analysis/include/Luau/EqSatSimplification.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/Frontend.h @@ -245,6 +249,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/AstJsonEncoder.cpp Analysis/src/AstQuery.cpp Analysis/src/Autocomplete.cpp + Analysis/src/AutocompleteCore.cpp Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Constraint.cpp @@ -256,6 +261,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Differ.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp + Analysis/src/EqSatSimplification.cpp Analysis/src/FragmentAutocomplete.cpp Analysis/src/Frontend.cpp Analysis/src/Generalization.cpp @@ -417,7 +423,7 @@ endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE - tests/AnyTypeSummary.test.cpp + tests/AnyTypeSummary.test.cpp tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp @@ -444,6 +450,7 @@ if(TARGET Luau.UnitTest) tests/EqSat.language.test.cpp tests/EqSat.propositional.test.cpp tests/EqSat.slice.test.cpp + tests/EqSatSimplification.test.cpp tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index d382a924..052d8c82 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -39,7 +39,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; -const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" +const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n" "$URL: luau.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index e170e9bc..e6e67020 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -67,7 +67,7 @@ TEST_CASE("encode_constants") charString.data = const_cast("a\x1d\0\\\"b"); charString.size = 6; - AstExprConstantString needsEscaping{Location(), charString}; + AstExprConstantString needsEscaping{Location(), charString, AstExprConstantString::QuotedSimple}; CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); @@ -83,7 +83,7 @@ TEST_CASE("basic_escaping") { std::string s = "hello \"world\""; AstArray theString{s.data(), s.size()}; - AstExprConstantString str{Location(), theString}; + AstExprConstantString str{Location(), theString, AstExprConstantString::QuotedSimple}; std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; CHECK_EQ(expected, toJson(&str)); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index de4049a9..0424e3df 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -151,40 +151,6 @@ struct ACBuiltinsFixture : ACFixtureImpl { }; -#define LUAU_CHECK_HAS_KEY(map, key) \ - do \ - { \ - auto&& _m = (map); \ - auto&& _k = (key); \ - const size_t count = _m.count(_k); \ - CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ - if (!count) \ - { \ - MESSAGE("Keys: (count " << _m.size() << ")"); \ - for (const auto& [k, v] : _m) \ - { \ - MESSAGE("\tkey: " << k); \ - } \ - } \ - } while (false) - -#define LUAU_CHECK_HAS_NO_KEY(map, key) \ - do \ - { \ - auto&& _m = (map); \ - auto&& _k = (key); \ - const size_t count = _m.count(_k); \ - CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ - if (count) \ - { \ - MESSAGE("Keys: (count " << _m.size() << ")"); \ - for (const auto& [k, v] : _m) \ - { \ - MESSAGE("\tkey: " << k); \ - } \ - } \ - } while (false) - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") diff --git a/tests/Config.test.cpp b/tests/Config.test.cpp index 70d6d6d7..690c4c37 100644 --- a/tests/Config.test.cpp +++ b/tests/Config.test.cpp @@ -58,7 +58,11 @@ TEST_CASE("report_a_syntax_error") TEST_CASE("noinfer_is_still_allowed") { Config config; - auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true); + + ConfigOptions opts; + opts.compat = true; + + auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, opts); REQUIRE(!err); CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode)); @@ -147,6 +151,10 @@ TEST_CASE("extra_globals") TEST_CASE("lint_rules_compat") { Config config; + + ConfigOptions opts; + opts.compat = true; + auto err = parseConfig( R"( {"lint": { @@ -156,7 +164,7 @@ TEST_CASE("lint_rules_compat") }} )", config, - true + opts ); REQUIRE(!err); diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index 1b84d4c9..ef91fdf7 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -10,6 +10,7 @@ namespace Luau ConstraintGeneratorFixture::ConstraintGeneratorFixture() : Fixture() , mainModule(new Module) + , simplifier(newSimplifier(NotNull{&arena}, builtinTypes)) , forceTheFlag{FFlag::LuauSolverV2, true} { mainModule->name = "MainModule"; @@ -25,6 +26,7 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code) cg = std::make_unique( mainModule, NotNull{&normalizer}, + NotNull{simplifier.get()}, NotNull{&typeFunctionRuntime}, NotNull(&moduleResolver), builtinTypes, @@ -44,8 +46,19 @@ void ConstraintGeneratorFixture::solve(const std::string& code) { generateConstraints(code); ConstraintSolver cs{ - NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, NotNull{dfg.get()}, {} + NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{rootScope}, + constraints, + "MainModule", + NotNull(&moduleResolver), + {}, + &logger, + NotNull{dfg.get()}, + {} }; + cs.run(); } diff --git a/tests/ConstraintGeneratorFixture.h b/tests/ConstraintGeneratorFixture.h index 782747c7..800bf873 100644 --- a/tests/ConstraintGeneratorFixture.h +++ b/tests/ConstraintGeneratorFixture.h @@ -4,8 +4,9 @@ #include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" -#include "Luau/TypeArena.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Module.h" +#include "Luau/TypeArena.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -20,6 +21,7 @@ struct ConstraintGeneratorFixture : Fixture DcrLogger logger; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + SimplifierPtr simplifier; TypeCheckLimits limits; TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}}; diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index 282d4ad2..fd1bde57 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -11,9 +11,7 @@ LUAU_EQSAT_ATOM(I32, int); LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Str, std::string); -LUAU_EQSAT_FIELD(Left); -LUAU_EQSAT_FIELD(Right); -LUAU_EQSAT_NODE_FIELDS(Add, Left, Right); +LUAU_EQSAT_NODE_ARRAY(Add, 2); using namespace Luau; @@ -117,8 +115,8 @@ TEST_CASE("node_field") Add add{left, right}; - EqSat::Id left2 = add.field(); - EqSat::Id right2 = add.field(); + EqSat::Id left2 = add.operands()[0]; + EqSat::Id right2 = add.operands()[1]; CHECK(left == left2); CHECK(left != right2); @@ -135,10 +133,10 @@ TEST_CASE("language_operands") const Add* add = v2.get(); REQUIRE(add); - EqSat::Slice actual = v2.operands(); + EqSat::Slice actual = v2.operands(); CHECK(actual.size() == 2); - CHECK(actual[0] == add->field()); - CHECK(actual[1] == add->field()); + CHECK(actual[0] == add->operands()[0]); + CHECK(actual[1] == add->operands()[1]); } TEST_SUITE_END(); diff --git a/tests/EqSatSimplification.test.cpp b/tests/EqSatSimplification.test.cpp new file mode 100644 index 00000000..aaaec456 --- /dev/null +++ b/tests/EqSatSimplification.test.cpp @@ -0,0 +1,728 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/EqSatSimplification.h" + +using namespace Luau; + +struct ESFixture : Fixture +{ + ScopedFastFlag newSolverOnly{FFlag::LuauSolverV2, true}; + + TypeArena arena_; + const NotNull arena{&arena_}; + + SimplifierPtr simplifier; + + TypeId parentClass; + TypeId childClass; + TypeId anotherChild; + TypeId unrelatedClass; + + TypeId genericT = arena_.addType(GenericType{"T"}); + TypeId genericU = arena_.addType(GenericType{"U"}); + + TypeId numberToString = arena_.addType(FunctionType{ + arena_.addTypePack({builtinTypes->numberType}), + arena_.addTypePack({builtinTypes->stringType}) + }); + + TypeId stringToNumber = arena_.addType(FunctionType{ + arena_.addTypePack({builtinTypes->stringType}), + arena_.addTypePack({builtinTypes->numberType}) + }); + + ESFixture() + : simplifier(newSimplifier(arena, builtinTypes)) + { + createSomeClasses(&frontend); + + ScopePtr moduleScope = frontend.globals.globalScope; + + parentClass = moduleScope->linearSearchForBinding("Parent")->typeId; + childClass = moduleScope->linearSearchForBinding("Child")->typeId; + anotherChild = moduleScope->linearSearchForBinding("AnotherChild")->typeId; + unrelatedClass = moduleScope->linearSearchForBinding("Unrelated")->typeId; + } + + std::optional simplifyStr(TypeId ty) + { + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + LUAU_ASSERT(res); + return toString(res->result); + } + + TypeId tbl(TableType::Props props) + { + return arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, TableState::Sealed}); + } +}; + +TEST_SUITE_BEGIN("EqSatSimplification"); + +TEST_CASE_FIXTURE(ESFixture, "primitive") +{ + CHECK("number" == simplifyStr(builtinTypes->numberType)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | number") +{ + TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->numberType}}); + + CHECK("number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | string") +{ + CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1") +{ + TypeId ty = arena->freshType(nullptr); + asMutable(ty)->ty.emplace(std::vector{builtinTypes->numberType, ty}); + + CHECK("number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | string | number") +{ + TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->numberType}}); + + CHECK("number | string" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string | (number | string) | number") +{ + TypeId u1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + TypeId u2 = arena->addType(UnionType{{builtinTypes->stringType, u1, builtinTypes->numberType}}); + + CHECK("number | string" == simplifyStr(u2)); +} + +TEST_CASE_FIXTURE(ESFixture, "string | any") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->anyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "any | string") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "any | never") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | unknown") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "unknown | string") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "unknown | never") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | never") +{ + CHECK("string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | never | number") +{ + CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & string") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & number") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & unknown") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never & string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->stringType, + arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "true | false") +{ + CHECK("boolean" == simplifyStr(arena->addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}}))); +} + +/* + * Intuitively, if we have a type like + * + * x where x = A & B & (C | D | x) + * + * We know that x is certainly not larger than A & B. + * We also know that the union (C | D | x) can be rewritten `(C | D | (A & B & (C | D | x))) + * This tells us that the union part is not smaller than A & B. + * We can therefore discard the union entirely and simplify this type to A & B + */ +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (number | t1)") +{ + TypeId intersectionTy = arena->addType(BlockedType{}); + TypeId unionTy = arena->addType(UnionType{{builtinTypes->numberType, intersectionTy}}); + + asMutable(intersectionTy)->ty.emplace(std::vector{builtinTypes->stringType, unionTy}); + + CHECK("string" == simplifyStr(intersectionTy)); +} + +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (unknown | t1)") +{ + TypeId intersectionTy = arena->addType(BlockedType{}); + TypeId unionTy = arena->addType(UnionType{{builtinTypes->unknownType, intersectionTy}}); + + asMutable(intersectionTy)->ty.emplace(std::vector{builtinTypes->stringType, unionTy}); + + CHECK("string" == simplifyStr(intersectionTy)); +} + +TEST_CASE_FIXTURE(ESFixture, "error | unknown") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->errorType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string") +{ + CHECK("string" == simplifyStr(arena->addType(UnionType{{ + arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"") +{ + CHECK("\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{ + arena->addType(SingletonType{StringSingleton{"hello"}}), + arena->addType(SingletonType{StringSingleton{"world"}}), + arena->addType(SingletonType{StringSingleton{"hello"}}), + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->nilType, + builtinTypes->booleanType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent & number") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + parentClass, builtinTypes->numberType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & Parent") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ + childClass, parentClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + childClass, unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child | Parent") +{ + CHECK("Parent" == simplifyStr(arena->addType(UnionType{{ + childClass, parentClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "class | Child") +{ + CHECK("class" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->classType, childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child") +{ + CHECK("class" == simplifyStr(arena->addType(UnionType{{ + parentClass, builtinTypes->classType, childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ + parentClass, unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->neverType, parentClass, unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->neverType, parentClass, + arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}), + unrelatedClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "T & U") +{ + CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{ + genericT, genericU + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & true") +{ + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->booleanType, builtinTypes->trueType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)") +{ + TypeId truthy = arena->addType(UnionType{{ + builtinTypes->trueType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }}); + + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->booleanType, truthy + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)") +{ + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->booleanType, builtinTypes->truthyType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->falseType, builtinTypes->truthyType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, numberToString}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (number) -> string") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(UnionType{{numberToString, numberToString}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & function") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->functionType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & boolean") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->booleanType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & ~function") +{ + TypeId notFunction = arena->addType(NegationType{builtinTypes->functionType}); + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, notFunction}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | function") +{ + CHECK("function" == simplifyStr(arena->addType(UnionType{{numberToString, builtinTypes->functionType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (string) -> number") +{ + CHECK("((number) -> string) & ((string) -> number)" == simplifyStr(arena->addType(IntersectionType{{numberToString, stringToNumber}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number") +{ + CHECK("((number) -> string) | ((string) -> number)" == simplifyStr(arena->addType(UnionType{{numberToString, stringToNumber}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "add") +{ + CHECK("number" == simplifyStr(arena->addType( + TypeFunctionInstanceType{builtinTypeFunctions().addFunc, { + builtinTypes->numberType, builtinTypes->numberType + }} + ))); +} + +TEST_CASE_FIXTURE(ESFixture, "union") +{ + CHECK("number" == simplifyStr(arena->addType( + TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, { + builtinTypes->numberType, builtinTypes->numberType + }} + ))); +} + +TEST_CASE_FIXTURE(ESFixture, "never & ~string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->neverType, + arena->addType(NegationType{builtinTypes->stringType}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "blocked & never") +{ + const TypeId blocked = arena->addType(BlockedType{}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{blocked, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function") +{ + const TypeId blocked = arena->addType(BlockedType{}); + const TypeId notNumber = arena->addType(NegationType{builtinTypes->numberType}); + + const TypeId ty = arena->addType(IntersectionType{{blocked, notNumber, builtinTypes->functionType}}); + + std::string expected = toString(blocked) + " & function"; + + CHECK(expected == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number | boolean | nil) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->nilType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->booleanType, builtinTypes->nilType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +// (('a & false) | ('a & nil)) | number + +// Child & ~Parent +// ~Parent & Child +// ~Child & Parent +// Parent & ~Child +// ~Child & ~Parent +// ~Parent & ~Child + +TEST_CASE_FIXTURE(ESFixture, "free & string & number") +{ + Scope scope{builtinTypes->anyTypePack}; + const TypeId freeTy = arena->addType(FreeType{&scope}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)") +{ + const TypeId blocked = arena->addType(BlockedType{}); + const TypeId u = arena->addType(IntersectionType{{blocked, builtinTypes->numberType}}); + const TypeId ty = arena->addType(UnionType{{u, u}}); + + const std::string blockedStr = toString(blocked); + + CHECK(blockedStr + " & number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & unknown") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ + tbl({}), + builtinTypes->unknownType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & table") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ + tbl({}), + builtinTypes->tableType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{ + tbl({}), + builtinTypes->truthyType + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}") +{ + const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId hasX = tbl({{"x", builtinTypes->numberType}}); + + const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}}); + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + + CHECK("{ x: number }" == toString(res->result)); + + // Also assert that we don't allocate a fresh TableType in this case. + CHECK(follow(res->result) == hasX); +} + +TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: ~(false?)}") +{ + const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId hasX = tbl({{"x", builtinTypes->truthyType}}); + + const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}}); + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + + CHECK("{ x: number }" == toString(res->result)); +} + +TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) }") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + const TypeId ty = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "never | (({ x: number? }?) & { x: ~(false?) })") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + // ({x: number?}?) & {x: ~(false?)} + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}}); + + const TypeId ty = arena->addType(UnionType{{builtinTypes->neverType, intersectionTy}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "({ x: number? }?) & { x: ~(false?) } & ~(false?)") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + // ({x: number?}?) & {x: ~(false?)} & ~(false?) + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}}); + + CHECK("{ x: number }" == simplifyStr(intersectionTy)); +} + +#if 0 +// TODO +TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) } & ~(false?)) | number") +{ + // ({ x: number? }?) & { x: ~(false?) } & ~(false?) + const TypeId xWithOptionalNumber = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}}); + const TypeId ty = arena->addType(UnionType{{intersectionTy, builtinTypes->numberType}}); + + CHECK("{ x: number } | number" == simplifyStr(ty)); +} +#endif + +TEST_CASE_FIXTURE(ESFixture, "number & no-refine") +{ + CHECK("number" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->noRefineType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean") +{ + const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}}); + + const TypeId ty = arena->addType(IntersectionType{{ + tblTy, + arena->addType(NegationType{builtinTypes->booleanType}) + }}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(nil & string)?") +{ + const TypeId nilAndString = arena->addType(IntersectionType{{builtinTypes->nilType, builtinTypes->stringType}}); + const TypeId ty = arena->addType(UnionType{{nilAndString, builtinTypes->nilType}}); + + CHECK("nil" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string & \"hi\"") +{ + const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); + + CHECK("\"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, hi}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")") +{ + const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); + const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}}); + + CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->stringType, + arena->addType(UnionType{{hi, bye}}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child") +{ + const TypeId ty = arena->addType(IntersectionType{{ + arena->addType(UnionType{{childClass, unrelatedClass}}), + arena->addType(NegationType{childClass}) + }}); + + CHECK("Unrelated" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string & ~Child") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{ + builtinTypes->stringType, + arena->addType(NegationType{childClass}) + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ + arena->addType(UnionType{{childClass, unrelatedClass}}), + childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{ + arena->addType(UnionType{{childClass, anotherChild}}), + childClass + }}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }") +{ + const TypeId ty = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->neverType}}); + + CHECK("never" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }") +{ + const TypeId leftTable = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->optionalNumberType}}); + const TypeId rightTable = tbl({{"x", builtinTypes->stringType}}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{leftTable, rightTable}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & add") +{ + const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{ + builtinTypeFunctions().addFunc, + {u, parentClass}, + {} + }); + + const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); + + CHECK("Child & add" == simplifyStr(intersection)); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & intersect") +{ + const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{ + builtinTypeFunctions().intersectFunc, + {u, parentClass}, + {} + }); + + const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); + + CHECK("Child" == simplifyStr(intersection)); +} + +// {someKey: ~any} +// +// Maybe something we could do here is to try to reduce the key, get the +// class->node mapping, and skip the extraction process if the class corresponds +// to TNever. + +// t1 where t1 = add, number> + +TEST_SUITE_END(); diff --git a/tests/Fixture.h b/tests/Fixture.h index 0db208d9..39222a25 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -293,3 +293,37 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric; } while (false) #define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) + +#define LUAU_CHECK_HAS_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ + if (!count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + +#define LUAU_CHECK_HAS_NO_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ + if (count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp index de2e9832..81e42f87 100644 --- a/tests/FragmentAutocomplete.test.cpp +++ b/tests/FragmentAutocomplete.test.cpp @@ -4,19 +4,37 @@ #include "Fixture.h" #include "Luau/Ast.h" #include "Luau/AstQuery.h" +#include "Luau/Autocomplete.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" #include "Luau/Frontend.h" +#include "Luau/AutocompleteTypes.h" using namespace Luau; LUAU_FASTFLAG(LuauAllowFragmentParsing); LUAU_FASTFLAG(LuauStoreDFGOnModule2); +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) +{ + return std::nullopt; +} struct FragmentAutocompleteFixture : Fixture { - ScopedFastFlag sffs[3] = {{FFlag::LuauAllowFragmentParsing, true}, {FFlag::LuauSolverV2, true}, {FFlag::LuauStoreDFGOnModule2, true}}; + ScopedFastFlag sffs[4] = { + {FFlag::LuauAllowFragmentParsing, true}, + {FFlag::LuauSolverV2, true}, + {FFlag::LuauStoreDFGOnModule2, true}, + {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true} + }; + FragmentAutocompleteFixture() + { + addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType}); + } FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) { ParseResult p = tryParse(source); // We don't care about parsing incomplete asts @@ -26,7 +44,6 @@ struct FragmentAutocompleteFixture : Fixture CheckResult checkBase(const std::string& document) { - ScopedFastFlag sff{FFlag::LuauSolverV2, true}; FrontendOptions opts; opts.retainFullTypeGraphs = true; return this->frontend.check("MainModule", opts); @@ -48,6 +65,16 @@ struct FragmentAutocompleteFixture : Fixture options.runLintChecks = false; return Luau::typecheckFragment(frontend, "MainModule", cursorPos, options, document); } + + FragmentAutocompleteResult autocompleteFragment(const std::string& document, Position cursorPos) + { + FrontendOptions options; + options.retainFullTypeGraphs = true; + // Don't strictly need this in the new solver + options.forAutocomplete = true; + options.runLintChecks = false; + return Luau::fragmentAutocomplete(frontend, document, "MainModule", cursorPos, options, nullCallback); + } }; TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); @@ -172,6 +199,13 @@ TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") +{ + check("local a ="); + auto fragment = parseFragment("local a =", Position(0, 10)); + CHECK_EQ("local a =", fragment.fragmentToParse); +} + TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") { auto res = check(R"( @@ -278,6 +312,33 @@ local y = 5 CHECK_EQ("y", std::string(rhs->name.value)); } +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope") +{ + + check(R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"); + + auto fragment = parseFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{6, 0} + ); + + + + CHECK_EQ("function abc()\n local myInnerLocal = 1\n\n end\n", fragment.fragmentToParse); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); @@ -302,7 +363,7 @@ local z = x + y Position{3, 15} ); - auto opt = linearSearchForBinding(fragment.freshScope, "z"); + auto opt = linearSearchForBinding(fragment.freshScope.get(), "z"); REQUIRE(opt); CHECK_EQ("number", toString(*opt)); } @@ -326,9 +387,222 @@ local y = 5 Position{2, 11} ); - auto correct = linearSearchForBinding(fragment.freshScope, "z"); + auto correct = linearSearchForBinding(fragment.freshScope.get(), "z"); REQUIRE(correct); CHECK_EQ("number", toString(*correct)); } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("FragmentAutocompleteTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access") +{ + auto res = check( + R"( +local tbl = { abc = 1234} +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +local tbl = { abc = 1234} +tbl. +)", + Position{2, 5} + ); + + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(1, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("abc")); + CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_nested_property_access") +{ + auto res = check( + R"( +local tbl = { abc = { def = 1234, egh = false } } +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +local tbl = { abc = { def = 1234, egh = false } } +tbl.abc. +)", + Position{2, 8} + ); + + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(2, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("def")); + CHECK(fragment.acResults.entryMap.count("egh")); + CHECK_EQ(fragment.acResults.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "inline_autocomplete_picks_the_right_scope") +{ + auto res = check( + R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } +end +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } + local a : T +end +)", + Position{4, 15} + ); + + LUAU_ASSERT(fragment.freshScope); + + REQUIRE(fragment.acResults.entryMap.count("Table")); + REQUIRE(fragment.acResults.entryMap["Table"].type); + const TableType* tv = get(follow(*fragment.acResults.entryMap["Table"].type)); + REQUIRE(tv); + CHECK(tv->props.count("x")); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function") +{ + auto res = check(R"( +function foo() +end +)"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +function foo() +end +)", + Position{2, 0} + ); + + CHECK(fragment.acResults.entryMap.count("foo")); + CHECK_EQ(AutocompleteContext::Statement, fragment.acResults.context); +} + + +// Start compatibility tests! + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "empty_program") +{ + check(""); + + auto frag = autocompleteFragment(" ", Position{0, 1}); + auto ac = frag.acResults; + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") +{ + check("local a ="); + auto frag = autocompleteFragment("local a =", Position{0, 9}); + auto ac = frag.acResults; + + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "leave_numbers_alone") +{ + check("local a = 3."); + + auto frag = autocompleteFragment("local a = 3.", Position{0, 12}); + auto ac = frag.acResults; + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_globals") +{ + check("local myLocal = 4; "); + + auto frag = autocompleteFragment("local myLocal = 4; ", Position{0, 18}); + auto ac = frag.acResults; + + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "dont_suggest_local_before_its_definition") +{ + check(R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"); + + // autocomplete after abc but before myInnerLocal + auto fragment = autocompleteFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end +)", + Position{3, 0} + ); + auto ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); + + // autocomplete after my inner local + fragment = autocompleteFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{4, 0} + ); + ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("myInnerLocal")); + + fragment = autocompleteFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{6, 0} + ); + + ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 0ab402b5..69330057 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -18,6 +18,7 @@ LUAU_FASTINT(LuauParseErrorLimit) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionsSyntax2) +LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) namespace { @@ -2377,10 +2378,15 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag sff2{FFlag::LuauUserDefinedTypeFunParseExport, true}; AstStat* stat = parse(R"( type function foo() - return + return types.number + end + + export type function bar() + return types.string end )"); @@ -2417,7 +2423,6 @@ TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") { ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; - matchParseError("export type function foo() end", "Type function cannot be exported"); matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); } diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index 641323c2..f9bc3afb 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -424,6 +424,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath") assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension") +{ + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau"; + runProtectedRequire(path); + assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"}); +} + TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") { std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/alias_requirer"; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 422315f9..dedf8824 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -964,6 +964,7 @@ TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions" std::vector{builtinTypes->numberType}, // Type Function Arguments {}, {AstName{"woohoo"}}, // Type Function Name + {}, }; Type tv{tftt}; diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp index 29d7e8a7..eca633a8 100644 --- a/tests/TypeFunction.user.test.cpp +++ b/tests/TypeFunction.user.test.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAG(LuauUserTypeFunFixNoReadWrite) LUAU_FASTFLAG(LuauUserTypeFunFixMetatable) LUAU_FASTFLAG(LuauUserDefinedTypeFunctionResetState) LUAU_FASTFLAG(LuauUserTypeFunNonstrict) +LUAU_FASTFLAG(LuauUserTypeFunExportedAndLocal) +LUAU_FASTFLAG(LuauUserDefinedTypeFunParseExport) TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); @@ -1298,4 +1300,92 @@ local a: foo<> = "a" LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + + fileResolver.source["game/A"] = R"( +type function concat(a, b) + return types.singleton(a:value() .. b:value()) +end +export type Concat = concat +local a: concat<'first', 'second'> +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")"); + + CheckResult bResult = check(R"( +local Test = require(game.A); +local b: Test.Concat<'third', 'fourth'> + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + CHECK(toString(requireType("b")) == R"("thirdfourth")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + + CheckResult result = check(R"( +type function foo() + return "hi" +end +local function test() + type function bar() + return types.singleton(foo()) + end + + return ("" :: any) :: bar<> +end +local a = test() + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireType("a")) == R"("hi")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag udtfSyntax{FFlag::LuauUserDefinedTypeFunctionsSyntax2, true}; + ScopedFastFlag udtf{FFlag::LuauUserDefinedTypeFunctions2, true}; + ScopedFastFlag luauUserTypeFunFixRegister{FFlag::LuauUserTypeFunFixRegister, true}; + ScopedFastFlag luauUserTypeFunExportedAndLocal{FFlag::LuauUserTypeFunExportedAndLocal, true}; + ScopedFastFlag luauUserDefinedTypeFunParseExport{FFlag::LuauUserDefinedTypeFunParseExport, true}; + + fileResolver.source["game/A"] = R"( +export type function concat(a, b) + return types.singleton(a:value() .. b:value()) +end +local a: concat<'first', 'second'> +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")"); + + CheckResult bResult = check(R"( +local Test = require(game.A); +local b: Test.concat<'third', 'fourth'> + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + CHECK(toString(requireType("b")) == R"("thirdfourth")"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index ad4f9a85..3686f2d4 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -20,6 +20,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauRetrySubtypingWithoutHiddenPack) +LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -681,6 +682,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") { + // CLI-114134: this code *probably* wants the egraph in order + // to work properly. The new solver either falls over or + // forces so many constraints as to be unreliable. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) local i, j = left, mid @@ -743,6 +749,11 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") { + // CLI-114134: this code *probably* wants the egraph in order + // to work properly. The new solver either falls over or + // forces so many constraints as to be unreliable. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) local i, j = left, mid @@ -2554,8 +2565,17 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") { - if (!FFlag::LuauSolverV2) - return; + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontRefCountTypesInTypeFunctions, true} + }; + + // CLI-114134: This test: + // a) Has a kind of weird result (suggesting `number | false` is not great); + // b) Is force solving some constraints. + // We end up with a weird recursive type that, if you roughly look at it, is + // clearly `number`. Hopefully the egraph will be able to unfold this. + CheckResult result = check(R"( function fib(n) return n < 2 and 1 or fib(n-1) + fib(n-2) @@ -2565,9 +2585,7 @@ end LUAU_REQUIRE_ERRORS(result); auto err = get(result.errors.back()); LUAU_ASSERT(err); - CHECK("number" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("number" == toString(err->recommendedArgs[0].second)); + CHECK("false | number" == toString(err->recommendedReturn)); } TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") @@ -2862,6 +2880,8 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun") TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") { + ScopedFastFlag _{FFlag::LuauDontRefCountTypesInTypeFunctions, true}; + CheckResult result = check(R"( function foo(player) local success,result = player:thing() @@ -2889,7 +2909,7 @@ TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK(toString(tm2->wantedTp) == "string"); - CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true"); + CHECK(toString(tm2->givenTp) == "(buffer | class | function | number | string | table | thread | true) & unknown"); } else { diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index fd8e06a7..80dddc67 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -24,6 +24,7 @@ LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauRecursionLimit); LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTFLAG(LuauNewSolverVisitErrorExprLvalues) +LUAU_FASTFLAG(LuauDontRefCountTypesInTypeFunctions) using namespace Luau; @@ -1730,4 +1731,36 @@ TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") )")); } +TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontRefCountTypesInTypeFunctions, true} + }; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!strict + local function foo(a : string?) + local b = a or "" + return b:upper() + end + )")); +} + +TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontRefCountTypesInTypeFunctions, true} + }; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!strict + local function wtf(name: string?) + local message + message = "invalid alternate fiber: " .. (name or "UNNAMED alternate") + end + )")); +} + TEST_SUITE_END(); From e6bf71871a6b9f601545dba8a42ce89c6069675c Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 8 Nov 2024 16:23:09 -0800 Subject: [PATCH 6/8] CodeGen: Rewrite dot product lowering using a dedicated IR instruction (#1512) Instead of doing the dot product related math in scalar IR, we lift the computation into a dedicated IR instruction. On x64, we can use VDPPS which was more or less tailor made for this purpose. This is better than manual scalar lowering that requires reloading components from memory; it's not always a strict improvement over the shuffle+add version (which we never had), but this can now be adjusted in the IR lowering in an optimal fashion (maybe even based on CPU vendor, although that'd create issues for offline compilation). On A64, we can either use naive adds or paired adds, as there is no dedicated vector-wide horizontal instruction until SVE. Both run at about the same performance on M2, but paired adds require fewer instructions and temporaries. I've measured this using mesh-normal-vector benchmark, changing the benchmark to just report the time of the second loop inside `calculate_normals`, testing master vs #1504 vs this PR, also increasing the grid size to 400 for more stable timings. On Zen 4 (7950X), this PR is comfortably ~8% faster vs master, while I see neutral to negative results in #1504. On M2 (base), this PR is ~28% faster vs master, while #1504 is only about ~10% faster. If I measure the second loop in `calculate_tangent_space` instead, I get: On Zen 4 (7950X), this PR is ~12% faster vs master, while #1504 is ~3% faster On M2 (base), this PR is ~24% faster vs master, while #1504 is only about ~13% faster. Note that the loops in question are not quite optimal, as they store and reload various vectors to dictionary values due to inappropriate use of locals. The underlying gains in individual functions are thus larger than the numbers above; for example, changing the `calculate_normals` loop to use a local variable to store the normalized vector (but still saving the result to dictionary value), I get a ~24% performance increase from this PR on Zen4 vs master instead of just 8% (#1504 is ~15% slower in this setup). --- CodeGen/include/Luau/AssemblyBuilderA64.h | 1 + CodeGen/include/Luau/AssemblyBuilderX64.h | 2 + CodeGen/include/Luau/IrData.h | 4 + CodeGen/include/Luau/IrUtils.h | 1 + CodeGen/src/AssemblyBuilderA64.cpp | 8 ++ CodeGen/src/AssemblyBuilderX64.cpp | 5 ++ CodeGen/src/IrDump.cpp | 2 + CodeGen/src/IrLoweringA64.cpp | 15 ++++ CodeGen/src/IrLoweringX64.cpp | 14 +++ CodeGen/src/IrTranslateBuiltins.cpp | 104 +++++++++++++++------- CodeGen/src/IrUtils.cpp | 2 + CodeGen/src/OptimizeConstProp.cpp | 4 +- tests/AssemblyBuilderA64.test.cpp | 3 + tests/AssemblyBuilderX64.test.cpp | 2 + 14 files changed, 135 insertions(+), 32 deletions(-) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index a4d857a4..9d337942 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -138,6 +138,7 @@ public: void fneg(RegisterA64 dst, RegisterA64 src); void fsqrt(RegisterA64 dst, RegisterA64 src); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void faddp(RegisterA64 dst, RegisterA64 src); // Vector component manipulation void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index c52d95c5..30790ee5 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -167,6 +167,8 @@ public: void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); + void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask); + // Run final checks bool finalize(); diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index ae406bbc..b603af9e 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -194,6 +194,10 @@ enum class IrCmd : uint8_t // A: TValue UNM_VEC, + // Compute dot product between two vectors + // A, B: TValue + DOT_VEC, + // Compute Luau 'not' operation on destructured TValue // A: tag // B: int (value) diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 8d48780f..08700573 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: case IrCmd::DIV_VEC: + case IrCmd::DOT_VEC: case IrCmd::UNM_VEC: case IrCmd::NOT_ANY: case IrCmd::CMP_ANY: diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index b98a21f2..23384e57 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src) placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000); } +void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src) +{ + CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s); + CODEGEN_ASSERT(dst.kind == src.kind); + + placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12)); +} + void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { if (dst.kind == KindA64::d) diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 73c40679..1e646bcb 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66); } +void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask) +{ + placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66); +} + bool AssemblyBuilderX64::finalize() { code.resize(codePos - code.data()); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 2846db54..f4806b31 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd) return "DIV_VEC"; case IrCmd::UNM_VEC: return "UNM_VEC"; + case IrCmd::DOT_VEC: + return "DOT_VEC"; case IrCmd::NOT_ANY: return "NOT_ANY"; case IrCmd::CMP_ANY: diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index a63655cc..45ae5eeb 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fneg(inst.regA64, regOp(inst.a)); break; } + case IrCmd::DOT_VEC: + { + inst.regA64 = regs.allocReg(KindA64::d, index); + + RegisterA64 temp = regs.allocTemp(KindA64::q); + RegisterA64 temps = castReg(KindA64::s, temp); + RegisterA64 regs = castReg(KindA64::s, inst.regA64); + + build.fmul(temp, regOp(inst.a), regOp(inst.b)); + build.faddp(regs, temps); // x+y + build.dup_4s(temp, temp, 2); + build.fadd(regs, regs, temps); // +z + build.fcvt(inst.regA64, regs); + break; + } case IrCmd::NOT_ANY: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index d06cef13..3e4592bf 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0)); break; } + case IrCmd::DOT_VEC: + { + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); + + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; + + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + + build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float + build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64); + break; + } case IrCmd::NOT_ANY: { // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index cec18204..ebded522 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5; static const int kBit32BinaryOpUnrolledParams = 5; LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen); +LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot); namespace Luau { @@ -907,15 +908,26 @@ static BuiltinImplResult translateBuiltinVectorMagnitude( build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); - IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); - IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); - IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + IrOp sum; - IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); - IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); - IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); - IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + sum = build.inst(IrCmd::DOT_VEC, a, a); + } + else + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + } IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); @@ -945,25 +957,43 @@ static BuiltinImplResult translateBuiltinVectorNormalize( build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); - IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); - IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); - IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); + IrOp sum = build.inst(IrCmd::DOT_VEC, a, a); - IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); - IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); - IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv); - IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec); - IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); - IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + result = build.inst(IrCmd::TAG_VECTOR, result); - IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); - IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); - IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); + } + else + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); - build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); - build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + } return {BuiltinImplType::Full, 1}; } @@ -1019,19 +1049,31 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); - IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); - IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); - IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + IrOp sum; - IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); - IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4)); - IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); + IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0)); - IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); - IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); - IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); + sum = build.inst(IrCmd::DOT_VEC, a, b); + } + else + { + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); - IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); + IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); + + sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); + } build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index ebf4c34b..c1183a47 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::DIV_VEC: case IrCmd::UNM_VEC: return IrValueKind::Tvalue; + case IrCmd::DOT_VEC: + return IrValueKind::Double; case IrCmd::NOT_ANY: case IrCmd::CMP_ANY: return IrValueKind::Int; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index fa1b18d3..6d453765 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (tag == LUA_TBOOLEAN && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) canSplitTvalueStore = true; - else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) + else if (tag == LUA_TNUMBER && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) canSplitTvalueStore = true; else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) canSplitTvalueStore = true; @@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: case IrCmd::DIV_VEC: + case IrCmd::DOT_VEC: if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR) replace(function, inst.a, a->a); diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 2cd821b5..ee319a5f 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD); + SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D); + SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D); + SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 655fa8f1..016616e0 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4); SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02); + + SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") From 53e6e4b8f0b74e8770c41ff9bf7165ecfa9da1e2 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Mon, 11 Nov 2024 12:39:09 -0800 Subject: [PATCH 7/8] Fix mesh-normal-vector benchmark array access (#1514) mesh-normal-scalar correctly fills sequential values in the output for triangle cone function, but mesh-normal-vector accidentally reuses the loop index, which results in writes to every third index of the array (1, 4, etc.). This is both slower (as the table turns into a hash map), and incorrect, especially as we have a scalar version of the benchmark that does the right thing. Note: there's a bunch of inefficiencies in the benchmark code that I have not fixed (around field access mostly, e.g. writing to `v.n` and then immediately reading it again). These are not ideal for performance, but they can be valuable to keep as is because this redundancy is common in real-world code, and it would be nice to see codegen optimizations eliminating most of that overhead. This one, however, is a straight up bug, and sparse arrays should not really be the thing this benchmark hits. --- bench/tests/mesh-normal-vector.lua | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bench/tests/mesh-normal-vector.lua b/bench/tests/mesh-normal-vector.lua index b34f48f8..bfc0f1c7 100644 --- a/bench/tests/mesh-normal-vector.lua +++ b/bench/tests/mesh-normal-vector.lua @@ -86,7 +86,8 @@ function test() function compute_triangle_cones() local mesh_area = 0 - local i = 1 + local pos = 1 + for i = 1,#mesh.indices,3 do local p0 = mesh.vertices[mesh.indices[i]] local p1 = mesh.vertices[mesh.indices[i + 1]] @@ -100,9 +101,9 @@ function test() local area = vector.magnitude(normal) local invarea = (area == 0) and 0 or 1 / area; - mesh.triangle_cone_p[i] = (p0.p + p1.p + p2.p) / 3 - mesh.triangle_cone_n[i] = normal * invarea - i += 1 + mesh.triangle_cone_p[pos] = (p0.p + p1.p + p2.p) / 3 + mesh.triangle_cone_n[pos] = normal * invarea + pos += 1 mesh_area += area end From d1025d00292c4fda84a94ead6b6705c0bedf8e06 Mon Sep 17 00:00:00 2001 From: Varun Saini <61795485+vrn-sn@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:25:04 -0800 Subject: [PATCH 8/8] Remove noexcepts from Config (#1523) Fixes https://github.com/luau-lang/luau/issues/1515. By removing these `noexcept`s, we guarantee that the internal call to `std::swap` uses move semantics when a `Config` is copy-assigned. --- Config/include/Luau/Config.h | 8 ++++---- Config/src/Config.cpp | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/Config/include/Luau/Config.h b/Config/include/Luau/Config.h index d6016229..64b76f07 100644 --- a/Config/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -21,10 +21,10 @@ constexpr const char* kConfigName = ".luaurc"; struct Config { Config(); - Config(const Config& other) noexcept; - Config& operator=(const Config& other) noexcept; - Config(Config&& other) noexcept = default; - Config& operator=(Config&& other) noexcept = default; + Config(const Config& other); + Config& operator=(const Config& other); + Config(Config&& other) = default; + Config& operator=(Config&& other) = default; Mode mode = Mode::Nonstrict; diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 3760fd9e..345e039c 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -17,7 +17,7 @@ Config::Config() enabledLint.setDefaults(); } -Config::Config(const Config& other) noexcept +Config::Config(const Config& other) : mode(other.mode) , parseOptions(other.parseOptions) , enabledLint(other.enabledLint) @@ -40,7 +40,7 @@ Config::Config(const Config& other) noexcept } } -Config& Config::operator=(const Config& other) noexcept +Config& Config::operator=(const Config& other) { if (this != &other) {