diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 8e1bf983..72a0c9ff 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -63,10 +63,10 @@ jobs: } valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-nonstrict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/LuauPolyfillMap.lua 2>&1 | filter map-strict | tee -a analyze-output.txt - valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/LuauPolyfillMap.lua 2>&1 | filter map-dcr | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=nonstrict bench/other/regex.lua 2>&1 | filter regex-nonstrict | tee -a analyze-output.txt valgrind --tool=callgrind ./luau-analyze --mode=strict bench/other/regex.lua 2>&1 | filter regex-strict | tee -a analyze-output.txt - valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=DebugLuauDeferredConstraintResolution bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt + valgrind --tool=callgrind ./luau-analyze --mode=strict --fflags=LuauSolverV2 bench/other/regex.lua 2>&1 | filter regex-dcr | tee -a analyze-output.txt - name: Run benchmark (compile) run: | diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7a2b5f10..0ac6e2e4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -46,9 +46,9 @@ jobs: - name: make cli run: | make -j2 config=sanitize werror=1 luau luau-analyze luau-compile # match config with tests to improve build time - ./luau tests/conformance/assert.lua - ./luau-analyze tests/conformance/assert.lua - ./luau-compile tests/conformance/assert.lua + ./luau tests/conformance/assert.luau + ./luau-analyze tests/conformance/assert.luau + ./luau-compile tests/conformance/assert.luau windows: runs-on: windows-latest @@ -81,12 +81,12 @@ jobs: shell: bash # necessary for fail-fast run: | cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Debug # match config with tests to improve build time - Debug/luau tests/conformance/assert.lua - Debug/luau-analyze tests/conformance/assert.lua - Debug/luau-compile tests/conformance/assert.lua + Debug/luau tests/conformance/assert.luau + Debug/luau-analyze tests/conformance/assert.luau + Debug/luau-compile tests/conformance/assert.luau coverage: - runs-on: ubuntu-20.04 # needed for clang++-10 to avoid gcov compatibility issues + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v2 - name: install @@ -94,7 +94,7 @@ jobs: sudo apt install llvm - name: make coverage run: | - CXX=clang++-10 make -j2 config=coverage native=1 coverage + CXX=clang++ make -j2 config=coverage native=1 coverage - name: upload coverage uses: codecov/codecov-action@v3 with: diff --git a/.github/workflows/new-release.yml b/.github/workflows/new-release.yml index 078a18f9..64a85a0c 100644 --- a/.github/workflows/new-release.yml +++ b/.github/workflows/new-release.yml @@ -29,8 +29,8 @@ jobs: build: needs: ["create-release"] strategy: - matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility - os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] + matrix: # not using ubuntu-latest to improve compatibility + os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] name: ${{matrix.os.name}} runs-on: ${{matrix.os.version}} steps: @@ -38,7 +38,7 @@ jobs: - name: configure run: cmake . -DCMAKE_BUILD_TYPE=Release - name: build - run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI --config Release -j 2 + run: cmake --build . --target Luau.Repl.CLI Luau.Analyze.CLI Luau.Compile.CLI Luau.Ast.CLI --config Release -j 2 - name: pack if: matrix.os.name != 'windows' run: zip luau-${{matrix.os.name}}.zip luau* diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5e18eb68..24454243 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,8 +13,8 @@ on: jobs: build: strategy: - matrix: # using ubuntu-20.04 to build a Linux binary targeting older glibc to improve compatibility - os: [{name: ubuntu, version: ubuntu-20.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] + matrix: # not using ubuntu-latest to improve compatibility + os: [{name: ubuntu, version: ubuntu-22.04}, {name: macos, version: macos-latest}, {name: windows, version: windows-latest}] name: ${{matrix.os.name}} runs-on: ${{matrix.os.version}} steps: diff --git a/.gitignore b/.gitignore index 764b97cf..5852b330 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ /luau /luau-tests /luau-analyze +/luau-bytecode /luau-compile __pycache__ .cache diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index bc709c7f..b54f7a44 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -1,10 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/AutocompleteTypes.h" #include "Luau/Location.h" #include "Luau/Type.h" -#include #include #include #include @@ -16,89 +16,8 @@ struct Frontend; struct SourceModule; struct Module; struct TypeChecker; - -using ModulePtr = std::shared_ptr; - -enum class AutocompleteContext -{ - Unknown, - Expression, - Statement, - Property, - Type, - Keyword, - String, -}; - -enum class AutocompleteEntryKind -{ - Property, - Binding, - Keyword, - String, - Type, - Module, - GeneratedFunction, -}; - -enum class ParenthesesRecommendation -{ - None, - CursorAfter, - CursorInside, -}; - -enum class TypeCorrectKind -{ - None, - Correct, - CorrectFunctionResult, -}; - -struct AutocompleteEntry -{ - AutocompleteEntryKind kind = AutocompleteEntryKind::Property; - // Nullopt if kind is Keyword - std::optional type = std::nullopt; - bool deprecated = false; - // Only meaningful if kind is Property. - bool wrongIndexType = false; - // Set if this suggestion matches the type expected in the context - TypeCorrectKind typeCorrect = TypeCorrectKind::None; - - std::optional containingClass = std::nullopt; - std::optional prop = std::nullopt; - std::optional documentationSymbol = std::nullopt; - Tags tags; - ParenthesesRecommendation parens = ParenthesesRecommendation::None; - std::optional insertText; - - // Only meaningful if kind is Property. - bool indexedWithSelf = false; -}; - -using AutocompleteEntryMap = std::unordered_map; -struct AutocompleteResult -{ - AutocompleteEntryMap entryMap; - std::vector ancestry; - AutocompleteContext context = AutocompleteContext::Unknown; - - AutocompleteResult() = default; - AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) - : entryMap(std::move(entryMap)) - , ancestry(std::move(ancestry)) - , context(context) - { - } -}; - -using ModuleName = std::string; -using StringCompletionCallback = - std::function(std::string tag, std::optional ctx, std::optional contents)>; +struct FileResolver; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); -constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; - } // namespace Luau diff --git a/Analysis/include/Luau/AutocompleteTypes.h b/Analysis/include/Luau/AutocompleteTypes.h new file mode 100644 index 00000000..37d45244 --- /dev/null +++ b/Analysis/include/Luau/AutocompleteTypes.h @@ -0,0 +1,92 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Type.h" + +#include + +namespace Luau +{ + +enum class AutocompleteContext +{ + Unknown, + Expression, + Statement, + Property, + Type, + Keyword, + String, +}; + +enum class AutocompleteEntryKind +{ + Property, + Binding, + Keyword, + String, + Type, + Module, + GeneratedFunction, + RequirePath, +}; + +enum class ParenthesesRecommendation +{ + None, + CursorAfter, + CursorInside, +}; + +enum class TypeCorrectKind +{ + None, + Correct, + CorrectFunctionResult, +}; + +struct AutocompleteEntry +{ + AutocompleteEntryKind kind = AutocompleteEntryKind::Property; + // Nullopt if kind is Keyword + std::optional type = std::nullopt; + bool deprecated = false; + // Only meaningful if kind is Property. + bool wrongIndexType = false; + // Set if this suggestion matches the type expected in the context + TypeCorrectKind typeCorrect = TypeCorrectKind::None; + + std::optional containingClass = std::nullopt; + std::optional prop = std::nullopt; + std::optional documentationSymbol = std::nullopt; + Tags tags; + ParenthesesRecommendation parens = ParenthesesRecommendation::None; + std::optional insertText; + + // Only meaningful if kind is Property. + bool indexedWithSelf = false; +}; + +using AutocompleteEntryMap = std::unordered_map; +struct AutocompleteResult +{ + AutocompleteEntryMap entryMap; + std::vector ancestry; + AutocompleteContext context = AutocompleteContext::Unknown; + + AutocompleteResult() = default; + AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry, AutocompleteContext context) + : entryMap(std::move(entryMap)) + , ancestry(std::move(ancestry)) + , context(context) + { + } +}; + +using StringCompletionCallback = + std::function(std::string tag, std::optional ctx, std::optional contents)>; + +constexpr char kGeneratedAnonymousFunctionEntryName[] = "function (anonymous autofilled)"; + +} // namespace Luau diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 71e50580..7dc38835 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -9,6 +9,8 @@ namespace Luau { +static constexpr char kRequireTagName[] = "require"; + struct Frontend; struct GlobalTypes; struct TypeChecker; @@ -63,10 +65,7 @@ TypeId makeFunction( // Polymorphic bool checked = false ); -void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn); -void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn); -void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn); +void attachMagicFunction(TypeId ty, std::shared_ptr fn); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableType::Props& props, const std::string& baseName); @@ -80,4 +79,16 @@ std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::stri Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); + +/** A number of built-in functions are magical enough that we need to match on them specifically by + * name when they are called. These are listed here to be used whenever necessary, instead of duplicating this logic repeatedly. + */ + +bool matchSetMetatable(const AstExprCall& call); +bool matchTableFreeze(const AstExprCall& call); +bool matchAssert(const AstExprCall& call); + +// Returns `true` if the function should introduce typestate for its first argument. +bool shouldTypestateForFirstArgument(const AstExprCall& call); + } // namespace Luau diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index 103b5bbd..7d5ce892 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -4,6 +4,7 @@ #include #include "Luau/TypeArena.h" #include "Luau/Type.h" +#include "Luau/Scope.h" #include @@ -22,8 +23,21 @@ struct CloneState SeenTypePacks seenTypePacks; }; +/** `shallowClone` will make a copy of only the _top level_ constructor of the type, + * while `clone` will make a deep copy of the entire type and its every component. + * + * Be mindful about which behavior you actually _want_. + * + * Persistent types are not cloned as an optimization. + * If a type is cloned in order to mutate it, 'ignorePersistent' has to be set + */ + +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false); +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent = false); + TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState); TypeId clone(TypeId tp, TypeArena& dest, CloneState& cloneState); TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState); +Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState); } // namespace Luau diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index 61253732..2b0fbeb7 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -109,6 +109,21 @@ struct FunctionCheckConstraint NotNull> astExpectedTypes; }; +// table_check expectedType exprType +// +// If `expectedType` is a table type and `exprType` is _also_ a table type, +// propogate the member types of `expectedType` into the types of `exprType`. +// This is used to implement bidirectional inference on table assignment. +// Also see: FunctionCheckConstraint. +struct TableCheckConstraint +{ + TypeId expectedType; + TypeId exprType; + AstExprTable* table = nullptr; + NotNull> astTypes; + NotNull> astExpectedTypes; +}; + // prim FreeType ExpectedType PrimitiveType // // FreeType is bounded below by the singleton type and above by PrimitiveType @@ -273,7 +288,8 @@ using ConstraintV = Variant< UnpackConstraint, ReduceConstraint, ReducePackConstraint, - EqualityConstraint>; + EqualityConstraint, + TableCheckConstraint>; struct Constraint { diff --git a/Analysis/include/Luau/ConstraintGenerator.h b/Analysis/include/Luau/ConstraintGenerator.h index eb6e18eb..8a072e82 100644 --- a/Analysis/include/Luau/ConstraintGenerator.h +++ b/Analysis/include/Luau/ConstraintGenerator.h @@ -5,6 +5,7 @@ #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DataFlowGraph.h" +#include "Luau/EqSatSimplification.h" #include "Luau/InsertionOrderedMap.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" @@ -15,7 +16,6 @@ #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Variant.h" -#include "Luau/Normalize.h" #include #include @@ -28,6 +28,7 @@ struct Scope; using ScopePtr = std::shared_ptr; struct DcrLogger; +struct TypeFunctionRuntime; struct Inference { @@ -95,6 +96,9 @@ struct ConstraintGenerator // will enqueue them during solving. std::vector unqueuedConstraints; + // Map a function's signature scope back to its signature type. + DenseHashMap scopeToFunction{nullptr}; + // The private scope of type aliases for which the type parameters belong to. DenseHashMap astTypeAliasDefiningScopes{nullptr}; @@ -108,6 +112,11 @@ struct ConstraintGenerator // Needed to be able to enable error-suppression preservation for immediate refinements. NotNull normalizer; + + NotNull simplifier; + + // Needed to register all available type functions for execution at later stages. + NotNull typeFunctionRuntime; // Needed to resolve modules to make 'require' import types properly. NotNull moduleResolver; // Occasionally constraint generation needs to produce an ICE. @@ -125,6 +134,8 @@ struct ConstraintGenerator ConstraintGenerator( ModulePtr module, NotNull normalizer, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, @@ -142,6 +153,8 @@ struct ConstraintGenerator */ void visitModuleRoot(AstStatBlock* block); + void visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block); + private: std::vector> interiorTypes; @@ -223,7 +236,10 @@ private: ); void applyRefinements(const ScopePtr& scope, Location location, RefinementId refinement); + LUAU_NOINLINE void checkAliases(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block); + ControlFlow visitBlockWithoutChildScope_DEPRECATED(const ScopePtr& scope, AstStatBlock* block); ControlFlow visit(const ScopePtr& scope, AstStat* stat); ControlFlow visit(const ScopePtr& scope, AstStatBlock* block); @@ -282,11 +298,25 @@ private: Inference check(const ScopePtr& scope, AstExprFunction* func, std::optional expectedType, bool generalize); Inference check(const ScopePtr& scope, AstExprUnary* unary); Inference check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + Inference checkAstExprBinary( + const ScopePtr& scope, + const Location& location, + AstExprBinary::Op op, + AstExpr* left, + AstExpr* right, + std::optional expectedType + ); Inference check(const ScopePtr& scope, AstExprIfElse* ifElse, std::optional expectedType); Inference check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); Inference check(const ScopePtr& scope, AstExprInterpString* interpString); Inference check(const ScopePtr& scope, AstExprTable* expr, std::optional expectedType); - std::tuple checkBinary(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType); + std::tuple checkBinary( + const ScopePtr& scope, + AstExprBinary::Op op, + AstExpr* left, + AstExpr* right, + std::optional expectedType + ); void visitLValue(const ScopePtr& scope, AstExpr* expr, TypeId rhsType); void visitLValue(const ScopePtr& scope, AstExprLocal* local, TypeId rhsType); @@ -321,6 +351,11 @@ private: */ void checkFunctionBody(const ScopePtr& scope, AstExprFunction* fn); + // Specializations of 'resolveType' below + TypeId resolveReferenceType(const ScopePtr& scope, AstType* ty, AstTypeReference* ref, bool inTypeArguments, bool replaceErrorWithFresh); + TypeId resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh); + TypeId resolveFunctionType(const ScopePtr& scope, AstType* ty, AstTypeFunction* fn, bool inTypeArguments, bool replaceErrorWithFresh); + /** * Resolves a type from its AST annotation. * @param scope the scope that the type annotation appears within. @@ -360,7 +395,7 @@ private: **/ std::vector> createGenerics( const ScopePtr& scope, - AstArray generics, + AstArray generics, bool useCache = false, bool addTypes = true ); @@ -377,7 +412,7 @@ private: **/ std::vector> createGenericPacks( const ScopePtr& scope, - AstArray packs, + AstArray packs, bool useCache = false, bool addTypes = true ); @@ -391,6 +426,7 @@ private: TypeId makeUnion(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); // make an intersect type function of these two types TypeId makeIntersect(const ScopePtr& scope, Location location, TypeId lhs, TypeId rhs); + void prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program); /** Scan the program for global definitions. * @@ -421,6 +457,8 @@ private: const ScopePtr& scope, Location location ); + + TypeId simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right); }; /** Borrow a vector of pointers from a vector of owning pointers to constraints. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 308b983b..8b5a6dec 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -3,7 +3,9 @@ #pragma once #include "Luau/Constraint.h" +#include "Luau/DataFlowGraph.h" #include "Luau/DenseHash.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/Location.h" #include "Luau/Module.h" @@ -12,6 +14,7 @@ #include "Luau/ToString.h" #include "Luau/Type.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeFwd.h" #include "Luau/Variant.h" @@ -56,17 +59,42 @@ struct HashInstantiationSignature size_t operator()(const InstantiationSignature& signature) const; }; + +struct TablePropLookupResult +{ + // What types are we blocked on for determining this type? + std::vector blockedTypes; + // The type of the property (if we were able to determine it). + std::optional propType; + // Whether or not this is _definitely_ derived as the result of an indexer. + // We use this to determine whether or not code like: + // + // t.lol = nil; + // + // ... is legal. If `t: { [string]: ~nil }` then this is legal as + // there's no guarantee on whether "lol" specifically exists. + // However, if `t: { lol: ~nil }`, then we cannot allow assignment as + // that would remove "lol" from the table entirely. + bool isIndex = false; +}; + struct ConstraintSolver { NotNull arena; NotNull builtinTypes; InternalErrorReporter iceReporter; NotNull normalizer; + NotNull simplifier; + NotNull typeFunctionRuntime; // The entire set of constraints that the solver is trying to resolve. std::vector> constraints; + NotNull> scopeToFunction; NotNull rootScope; ModuleName currentModuleName; + // The dataflow graph of the program, used in constraint generation and for magic functions. + NotNull dfg; + // Constraints that the solver has generated, rather than sourcing from the // scope tree. std::vector> solverConstraints; @@ -91,6 +119,9 @@ struct ConstraintSolver // A mapping from free types to the number of unresolved constraints that mention them. DenseHashMap unresolvedConstraints{{}}; + std::unordered_map, DenseHashSet> maybeMutatedFreeTypes; + std::unordered_map> mutatedFreeTypeToConstraint; + // Irreducible/uninhabited type functions or type pack functions. DenseHashSet uninhabitedTypeFunctions{{}}; @@ -114,12 +145,16 @@ struct ConstraintSolver explicit ConstraintSolver( NotNull normalizer, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, + NotNull> scopeToFunction, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, + NotNull dfg, TypeCheckLimits limits ); @@ -139,9 +174,11 @@ struct ConstraintSolver **/ void finalizeTypeFunctions(); - bool isDone(); + bool isDone() const; private: + void generalizeOneType(TypeId ty); + /** * Bind a type variable to another type. * @@ -167,13 +204,14 @@ public: */ bool tryDispatch(NotNull c, bool force); - bool tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const SubtypeConstraint& c, NotNull constraint); + bool tryDispatch(const PackSubtypeConstraint& c, NotNull constraint); + bool tryDispatch(const GeneralizationConstraint& c, NotNull constraint); bool tryDispatch(const IterableConstraint& c, NotNull constraint, bool force); bool tryDispatch(const NameConstraint& c, NotNull constraint); bool tryDispatch(const TypeAliasExpansionConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCallConstraint& c, NotNull constraint); + bool tryDispatch(const TableCheckConstraint& c, NotNull constraint); bool tryDispatch(const FunctionCheckConstraint& c, NotNull constraint); bool tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint); bool tryDispatch(const HasPropConstraint& c, NotNull constraint); @@ -194,16 +232,16 @@ public: bool tryDispatch(const UnpackConstraint& c, NotNull constraint); bool tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force); bool tryDispatch(const ReducePackConstraint& c, NotNull constraint, bool force); - bool tryDispatch(const EqualityConstraint& c, NotNull constraint, bool force); + bool tryDispatch(const EqualityConstraint& c, NotNull constraint); // for a, ... in some_table do // also handles __iter metamethod bool tryDispatchIterableTable(TypeId iteratorTy, const IterableConstraint& c, NotNull constraint, bool force); // for a, ... in next_function, t, ... do - bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint, bool force); + bool tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint); - std::pair, std::optional> lookupTableProp( + TablePropLookupResult lookupTableProp( NotNull constraint, TypeId subjectType, const std::string& propName, @@ -211,7 +249,8 @@ public: bool inConditional = false, bool suppressSimplification = false ); - std::pair, std::optional> lookupTableProp( + + TablePropLookupResult lookupTableProp( NotNull constraint, TypeId subjectType, const std::string& propName, @@ -270,10 +309,10 @@ public: // FIXME: This use of a boolean for the return result is an appalling // interface. bool blockOnPendingTypes(TypeId target, NotNull constraint); - bool blockOnPendingTypes(TypePackId target, NotNull constraint); + bool blockOnPendingTypes(TypePackId targetPack, NotNull constraint); void unblock(NotNull progressed); - void unblock(TypeId progressed, Location location); + void unblock(TypeId ty, Location location); void unblock(TypePackId progressed, Location location); void unblock(const std::vector& types, Location location); void unblock(const std::vector& packs, Location location); @@ -281,18 +320,18 @@ public: /** * @returns true if the TypeId is in a blocked state. */ - bool isBlocked(TypeId ty); + bool isBlocked(TypeId ty) const; /** * @returns true if the TypePackId is in a blocked state. */ - bool isBlocked(TypePackId tp); + bool isBlocked(TypePackId tp) const; /** * Returns whether the constraint is blocked on anything. * @param constraint the constraint to check. */ - bool isBlocked(NotNull constraint); + bool isBlocked(NotNull constraint) const; /** Pushes a new solver constraint to the solver. * @param cv the body of the constraint. @@ -308,7 +347,7 @@ public: * @param location the location where the require is taking place; used for * error locations. **/ - TypeId resolveModule(const ModuleInfo& module, const Location& location); + TypeId resolveModule(const ModuleInfo& info, const Location& location); void reportError(TypeErrorData&& data, const Location& location); void reportError(TypeError e); @@ -379,15 +418,21 @@ public: **/ void reproduceConstraints(NotNull scope, const Location& location, const Substitution& subst); + TypeId simplifyIntersection(NotNull scope, Location location, TypeId left, TypeId right); + TypeId simplifyIntersection(NotNull scope, Location location, std::set parts); + TypeId simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right); + TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); - void throwTimeLimitError(); - void throwUserCancelError(); + void throwTimeLimitError() const; + void throwUserCancelError() const; ToStringOptions opts; + + void fillInDiscriminantTypes(NotNull constraint, const std::vector>& discriminantTypes); }; void dump(NotNull rootScope, struct ToStringOptions& opts); diff --git a/Analysis/include/Luau/DataFlowGraph.h b/Analysis/include/Luau/DataFlowGraph.h index a84561dd..1f28abe9 100644 --- a/Analysis/include/Luau/DataFlowGraph.h +++ b/Analysis/include/Luau/DataFlowGraph.h @@ -6,6 +6,7 @@ #include "Luau/ControlFlow.h" #include "Luau/DenseHash.h" #include "Luau/Def.h" +#include "Luau/NotNull.h" #include "Luau/Symbol.h" #include "Luau/TypedAllocator.h" @@ -35,6 +36,8 @@ struct DataFlowGraph DataFlowGraph& operator=(DataFlowGraph&&) = default; DefId getDef(const AstExpr* expr) const; + // Look up the definition optionally, knowing it may not be present. + std::optional getDefOptional(const AstExpr* expr) const; // Look up for the rvalue def for a compound assignment. std::optional getRValueDefForCompoundAssign(const AstExpr* expr) const; @@ -46,13 +49,13 @@ struct DataFlowGraph const RefinementKey* getRefinementKey(const AstExpr* expr) const; private: - DataFlowGraph() = default; + DataFlowGraph(NotNull defArena, NotNull keyArena); DataFlowGraph(const DataFlowGraph&) = delete; DataFlowGraph& operator=(const DataFlowGraph&) = delete; - DefArena defArena; - RefinementKeyArena keyArena; + NotNull defArena; + NotNull keyArena; DenseHashMap astDefs{nullptr}; @@ -68,7 +71,6 @@ private: DenseHashMap compoundAssignDefs{nullptr}; DenseHashMap astRefinementKeys{nullptr}; - friend struct DataFlowGraphBuilder; }; @@ -105,25 +107,37 @@ struct DataFlowResult const RefinementKey* parent = nullptr; }; +using ScopeStack = std::vector; + struct DataFlowGraphBuilder { - static DataFlowGraph build(AstStatBlock* root, NotNull handle); + static DataFlowGraph build( + AstStatBlock* block, + NotNull defArena, + NotNull keyArena, + NotNull handle + ); private: - DataFlowGraphBuilder() = default; + DataFlowGraphBuilder(NotNull defArena, NotNull keyArena); DataFlowGraphBuilder(const DataFlowGraphBuilder&) = delete; DataFlowGraphBuilder& operator=(const DataFlowGraphBuilder&) = delete; DataFlowGraph graph; - NotNull defArena{&graph.defArena}; - NotNull keyArena{&graph.keyArena}; + NotNull defArena; + NotNull keyArena; struct InternalErrorReporter* handle = nullptr; - DfgScope* moduleScope = nullptr; + /// The arena owning all of the scope allocations for the dataflow graph being built. std::vector> scopes; + /// A stack of scopes used by the visitor to see where we are. + ScopeStack scopeStack; + + DfgScope* currentScope(); + struct FunctionCapture { std::vector captureDefs; @@ -134,81 +148,81 @@ private: DenseHashMap captures{Symbol{}}; void resolveCaptures(); - DfgScope* childScope(DfgScope* scope, DfgScope::ScopeType scopeType = DfgScope::Linear); + DfgScope* makeChildScope(DfgScope::ScopeType scopeType = DfgScope::Linear); void join(DfgScope* p, DfgScope* a, DfgScope* b); void joinBindings(DfgScope* p, const DfgScope& a, const DfgScope& b); void joinProps(DfgScope* p, const DfgScope& a, const DfgScope& b); - DefId lookup(DfgScope* scope, Symbol symbol); - DefId lookup(DfgScope* scope, DefId def, const std::string& key); + DefId lookup(Symbol symbol); + DefId lookup(DefId def, const std::string& key); - ControlFlow visit(DfgScope* scope, AstStatBlock* b); - ControlFlow visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b); + ControlFlow visit(AstStatBlock* b); + ControlFlow visitBlockWithoutChildScope(AstStatBlock* b); - ControlFlow visit(DfgScope* scope, AstStat* s); - ControlFlow visit(DfgScope* scope, AstStatIf* i); - ControlFlow visit(DfgScope* scope, AstStatWhile* w); - ControlFlow visit(DfgScope* scope, AstStatRepeat* r); - ControlFlow visit(DfgScope* scope, AstStatBreak* b); - ControlFlow visit(DfgScope* scope, AstStatContinue* c); - ControlFlow visit(DfgScope* scope, AstStatReturn* r); - ControlFlow visit(DfgScope* scope, AstStatExpr* e); - ControlFlow visit(DfgScope* scope, AstStatLocal* l); - ControlFlow visit(DfgScope* scope, AstStatFor* f); - ControlFlow visit(DfgScope* scope, AstStatForIn* f); - ControlFlow visit(DfgScope* scope, AstStatAssign* a); - ControlFlow visit(DfgScope* scope, AstStatCompoundAssign* c); - ControlFlow visit(DfgScope* scope, AstStatFunction* f); - ControlFlow visit(DfgScope* scope, AstStatLocalFunction* l); - ControlFlow visit(DfgScope* scope, AstStatTypeAlias* t); - ControlFlow visit(DfgScope* scope, AstStatTypeFunction* f); - ControlFlow visit(DfgScope* scope, AstStatDeclareGlobal* d); - ControlFlow visit(DfgScope* scope, AstStatDeclareFunction* d); - ControlFlow visit(DfgScope* scope, AstStatDeclareClass* d); - ControlFlow visit(DfgScope* scope, AstStatError* error); + ControlFlow visit(AstStat* s); + ControlFlow visit(AstStatIf* i); + ControlFlow visit(AstStatWhile* w); + ControlFlow visit(AstStatRepeat* r); + ControlFlow visit(AstStatBreak* b); + ControlFlow visit(AstStatContinue* c); + ControlFlow visit(AstStatReturn* r); + ControlFlow visit(AstStatExpr* e); + ControlFlow visit(AstStatLocal* l); + ControlFlow visit(AstStatFor* f); + ControlFlow visit(AstStatForIn* f); + ControlFlow visit(AstStatAssign* a); + ControlFlow visit(AstStatCompoundAssign* c); + ControlFlow visit(AstStatFunction* f); + ControlFlow visit(AstStatLocalFunction* l); + ControlFlow visit(AstStatTypeAlias* t); + ControlFlow visit(AstStatTypeFunction* f); + ControlFlow visit(AstStatDeclareGlobal* d); + ControlFlow visit(AstStatDeclareFunction* d); + ControlFlow visit(AstStatDeclareClass* d); + ControlFlow visit(AstStatError* error); - DataFlowResult visitExpr(DfgScope* scope, AstExpr* e); - DataFlowResult visitExpr(DfgScope* scope, AstExprGroup* group); - DataFlowResult visitExpr(DfgScope* scope, AstExprLocal* l); - DataFlowResult visitExpr(DfgScope* scope, AstExprGlobal* g); - DataFlowResult visitExpr(DfgScope* scope, AstExprCall* c); - DataFlowResult visitExpr(DfgScope* scope, AstExprIndexName* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprIndexExpr* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprFunction* f); - DataFlowResult visitExpr(DfgScope* scope, AstExprTable* t); - DataFlowResult visitExpr(DfgScope* scope, AstExprUnary* u); - DataFlowResult visitExpr(DfgScope* scope, AstExprBinary* b); - DataFlowResult visitExpr(DfgScope* scope, AstExprTypeAssertion* t); - DataFlowResult visitExpr(DfgScope* scope, AstExprIfElse* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprInterpString* i); - DataFlowResult visitExpr(DfgScope* scope, AstExprError* error); + DataFlowResult visitExpr(AstExpr* e); + DataFlowResult visitExpr(AstExprGroup* group); + DataFlowResult visitExpr(AstExprLocal* l); + DataFlowResult visitExpr(AstExprGlobal* g); + DataFlowResult visitExpr(AstExprCall* c); + DataFlowResult visitExpr(AstExprIndexName* i); + DataFlowResult visitExpr(AstExprIndexExpr* i); + DataFlowResult visitExpr(AstExprFunction* f); + DataFlowResult visitExpr(AstExprTable* t); + DataFlowResult visitExpr(AstExprUnary* u); + DataFlowResult visitExpr(AstExprBinary* b); + DataFlowResult visitExpr(AstExprTypeAssertion* t); + DataFlowResult visitExpr(AstExprIfElse* i); + DataFlowResult visitExpr(AstExprInterpString* i); + DataFlowResult visitExpr(AstExprError* error); - void visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef); - DefId visitLValue(DfgScope* scope, AstExprError* e, DefId incomingDef); + void visitLValue(AstExpr* e, DefId incomingDef); + DefId visitLValue(AstExprLocal* l, DefId incomingDef); + DefId visitLValue(AstExprGlobal* g, DefId incomingDef); + DefId visitLValue(AstExprIndexName* i, DefId incomingDef); + DefId visitLValue(AstExprIndexExpr* i, DefId incomingDef); + DefId visitLValue(AstExprError* e, DefId incomingDef); - void visitType(DfgScope* scope, AstType* t); - void visitType(DfgScope* scope, AstTypeReference* r); - void visitType(DfgScope* scope, AstTypeTable* t); - void visitType(DfgScope* scope, AstTypeFunction* f); - void visitType(DfgScope* scope, AstTypeTypeof* t); - void visitType(DfgScope* scope, AstTypeUnion* u); - void visitType(DfgScope* scope, AstTypeIntersection* i); - void visitType(DfgScope* scope, AstTypeError* error); + void visitType(AstType* t); + void visitType(AstTypeReference* r); + void visitType(AstTypeTable* t); + void visitType(AstTypeFunction* f); + void visitType(AstTypeTypeof* t); + void visitType(AstTypeUnion* u); + void visitType(AstTypeIntersection* i); + void visitType(AstTypeError* error); - void visitTypePack(DfgScope* scope, AstTypePack* p); - void visitTypePack(DfgScope* scope, AstTypePackExplicit* e); - void visitTypePack(DfgScope* scope, AstTypePackVariadic* v); - void visitTypePack(DfgScope* scope, AstTypePackGeneric* g); + void visitTypePack(AstTypePack* p); + void visitTypePack(AstTypePackExplicit* e); + void visitTypePack(AstTypePackVariadic* v); + void visitTypePack(AstTypePackGeneric* g); - void visitTypeList(DfgScope* scope, AstTypeList l); + void visitTypeList(AstTypeList l); - void visitGenerics(DfgScope* scope, AstArray g); - void visitGenericPacks(DfgScope* scope, AstArray g); + void visitGenerics(AstArray g); + void visitGenericPacks(AstArray g); }; } // namespace Luau diff --git a/Analysis/include/Luau/EqSatSimplification.h b/Analysis/include/Luau/EqSatSimplification.h new file mode 100644 index 00000000..16d00849 --- /dev/null +++ b/Analysis/include/Luau/EqSatSimplification.h @@ -0,0 +1,50 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/TypeFwd.h" +#include "Luau/NotNull.h" +#include "Luau/DenseHash.h" + +#include +#include +#include + +namespace Luau +{ +struct TypeArena; +} + +// The EqSat stuff is pretty template heavy, so we go to some lengths to prevent +// the complexity from leaking outside its implementation sources. +namespace Luau::EqSatSimplification +{ + +struct Simplifier; + +using SimplifierPtr = std::unique_ptr; + +SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes); + +} // namespace Luau::EqSatSimplification + +namespace Luau +{ + +struct EqSatSimplificationResult +{ + TypeId result; + + // New type function applications that were created by the reduction phase. + // We return these so that the ConstraintSolver can know to try to reduce + // them. + std::vector newTypeFunctions; +}; + +using EqSatSimplification::newSimplifier; // NOLINT: clang-tidy thinks these are unused. It is incorrect. +using Luau::EqSatSimplification::Simplifier; // NOLINT +using Luau::EqSatSimplification::SimplifierPtr; + +std::optional eqSatSimplify(NotNull simplifier, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/EqSatSimplificationImpl.h b/Analysis/include/Luau/EqSatSimplificationImpl.h new file mode 100644 index 00000000..73019621 --- /dev/null +++ b/Analysis/include/Luau/EqSatSimplificationImpl.h @@ -0,0 +1,376 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" +#include "Luau/Lexer.h" // For Allocator +#include "Luau/NotNull.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" + +namespace Luau +{ +struct TypeFunction; +} + +namespace Luau::EqSatSimplification +{ + +using StringId = uint32_t; +using Id = Luau::EqSat::Id; + +LUAU_EQSAT_UNIT(TNil); +LUAU_EQSAT_UNIT(TBoolean); +LUAU_EQSAT_UNIT(TNumber); +LUAU_EQSAT_UNIT(TString); +LUAU_EQSAT_UNIT(TThread); +LUAU_EQSAT_UNIT(TTopFunction); +LUAU_EQSAT_UNIT(TTopTable); +LUAU_EQSAT_UNIT(TTopClass); +LUAU_EQSAT_UNIT(TBuffer); + +// Used for any type that eqsat can't do anything interesting with. +LUAU_EQSAT_ATOM(TOpaque, TypeId); + +LUAU_EQSAT_ATOM(SBoolean, bool); +LUAU_EQSAT_ATOM(SString, StringId); + +LUAU_EQSAT_ATOM(TFunction, TypeId); + +LUAU_EQSAT_ATOM(TImportedTable, TypeId); + +LUAU_EQSAT_ATOM(TClass, TypeId); + +LUAU_EQSAT_UNIT(TAny); +LUAU_EQSAT_UNIT(TError); +LUAU_EQSAT_UNIT(TUnknown); +LUAU_EQSAT_UNIT(TNever); + +LUAU_EQSAT_NODE_SET(Union); +LUAU_EQSAT_NODE_SET(Intersection); + +LUAU_EQSAT_NODE_ARRAY(Negation, 1); + +LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(TTypeFun, std::shared_ptr); + +LUAU_EQSAT_UNIT(TNoRefine); +LUAU_EQSAT_UNIT(Invalid); + +// enodes are immutable, but types are cyclic. We need a way to tie the knot. +// We handle this by generating TBound nodes at points where we encounter cycles. +// Each TBound has an ordinal that we later map onto the type. +// We use a substitution rule to replace all TBound nodes with their referrent. +LUAU_EQSAT_ATOM(TBound, size_t); + +// Tables are sufficiently unlike other enodes that the Language.h macros won't cut it. +struct TTable +{ + explicit TTable(Id basis); + TTable(Id basis, std::vector propNames_, std::vector propTypes_); + + // All TTables extend some other table. This may be TTopTable. + // + // It will frequently be a TImportedTable, in which case we can reuse things + // like source location and documentation info. + Id getBasis() const; + EqSat::Slice propTypes() const; + // TODO: Also support read-only table props + // TODO: Indexer type, index result type. + + std::vector propNames; + + // The enode interface + EqSat::Slice mutableOperands(); + EqSat::Slice operands() const; + bool operator==(const TTable& rhs) const; + bool operator!=(const TTable& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const TTable& value) const; + }; + +private: + // The first element of this vector is the basis. Subsequent elements are + // property types. As we add other things like read-only properties and + // indexers, the structure of this array is likely to change. + // + // We encode our data in this way so that the operands() method can properly + // return a Slice. + std::vector storage; +}; + +template +using Node = EqSat::Node; + +using EType = EqSat::Language< + TNil, + TBoolean, + TNumber, + TString, + TThread, + TTopFunction, + TTopTable, + TTopClass, + TBuffer, + + TOpaque, + + SBoolean, + SString, + + TFunction, + TTable, + TImportedTable, + TClass, + + TAny, + TError, + TUnknown, + TNever, + + Union, + Intersection, + + Negation, + + TTypeFun, + + Invalid, + TNoRefine, + TBound>; + + +struct StringCache +{ + Allocator allocator; + DenseHashMap strings{{}}; + std::vector views; + + StringId add(std::string_view s); + std::string_view asStringView(StringId id) const; + std::string asString(StringId id) const; +}; + +using EGraph = Luau::EqSat::EGraph; + +struct Simplify +{ + using Data = bool; + + template + Data make(const EGraph&, const T&) const; + + void join(Data& left, const Data& right) const; +}; + +struct Subst +{ + Id eclass; + Id newClass; + + // The node into eclass which is boring, if any + std::optional boringIndex; + + std::string desc; + + Subst(Id eclass, Id newClass, std::string desc = ""); +}; + +struct Simplifier +{ + NotNull arena; + NotNull builtinTypes; + EGraph egraph; + StringCache stringCache; + + // enodes are immutable but types can be cyclic, so we need some way to + // encode the cycle. This map is used to connect TBound nodes to the right + // eclass. + // + // The cyclicIntersection rewrite rule uses this to sense when a cycle can + // be deleted from an intersection or union. + std::unordered_map mappingIdToClass; + + std::vector substs; + + using RewriteRuleFn = void (Simplifier::*)(Id id); + + Simplifier(NotNull arena, NotNull builtinTypes); + + // Utilities + const EqSat::EClass& get(Id id) const; + Id find(Id id) const; + Id add(EType enode); + + template + const Tag* isTag(Id id) const; + + template + const Tag* isTag(const EType& enode) const; + + void subst(Id from, Id to); + void subst(Id from, Id to, const std::string& ruleName); + void subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); + void subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map& forceNodes); + + void unionClasses(std::vector& hereParts, Id there); + + // Rewrite rules + void simplifyUnion(Id id); + void uninhabitedIntersection(Id id); + void intersectWithNegatedClass(Id id); + void intersectWithNegatedAtom(Id id); + void intersectWithNoRefine(Id id); + void cyclicIntersectionOfUnion(Id id); + void cyclicUnionOfIntersection(Id id); + void expandNegation(Id id); + void intersectionOfUnion(Id id); + void intersectTableProperty(Id id); + void uninhabitedTable(Id id); + void unneededTableModification(Id id); + void builtinTypeFunctions(Id id); + void iffyTypeFunctions(Id id); + void strictMetamethods(Id id); +}; + +template +struct QueryIterator +{ + QueryIterator(); + QueryIterator(EGraph* egraph, Id eclass); + + bool operator==(const QueryIterator& other) const; + bool operator!=(const QueryIterator& other) const; + + std::pair operator*() const; + + QueryIterator& operator++(); + QueryIterator& operator++(int); + +private: + EGraph* egraph = nullptr; + Id eclass; + size_t index = 0; +}; + +template +struct Query +{ + EGraph* egraph; + Id eclass; + + Query(EGraph* egraph, Id eclass) + : egraph(egraph) + , eclass(eclass) + { + } + + QueryIterator begin() + { + return QueryIterator{egraph, eclass}; + } + + QueryIterator end() + { + return QueryIterator{}; + } +}; + +template +QueryIterator::QueryIterator() + : egraph(nullptr) + , eclass(Id{0}) + , index(0) +{ +} + +template +QueryIterator::QueryIterator(EGraph* egraph_, Id eclass) + : egraph(egraph_) + , eclass(eclass) + , index(0) +{ + const auto& ecl = (*egraph)[eclass]; + + static constexpr const int idx = EType::VariantTy::getTypeId(); + + for (const auto& enode : ecl.nodes) + { + if (enode.node.index() < idx) + ++index; + else + break; + } + + if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != idx) + { + egraph = nullptr; + index = 0; + } +} + +template +bool QueryIterator::operator==(const QueryIterator& rhs) const +{ + if (egraph == nullptr && rhs.egraph == nullptr) + return true; + + return egraph == rhs.egraph && eclass == rhs.eclass && index == rhs.index; +} + +template +bool QueryIterator::operator!=(const QueryIterator& rhs) const +{ + return !(*this == rhs); +} + +template +std::pair QueryIterator::operator*() const +{ + LUAU_ASSERT(egraph != nullptr); + + EGraph::EClassT& ecl = (*egraph)[eclass]; + + LUAU_ASSERT(index < ecl.nodes.size()); + auto& enode = ecl.nodes[index].node; + Tag* result = enode.template get(); + LUAU_ASSERT(result); + return {result, index}; +} + +// pre-increment +template +QueryIterator& QueryIterator::operator++() +{ + const auto& ecl = (*egraph)[eclass]; + + do + { + ++index; + if (index >= ecl.nodes.size() || ecl.nodes[index].node.index() != EType::VariantTy::getTypeId()) + { + egraph = nullptr; + index = 0; + break; + } + } while (ecl.nodes[index].boring); + + return *this; +} + +// post-increment +template +QueryIterator& QueryIterator::operator++(int) +{ + QueryIterator res = *this; + ++res; + return res; +} + +} // namespace Luau::EqSatSimplification diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index baf3318c..fe9d7924 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -448,6 +448,13 @@ struct UnexpectedTypePackInSubtyping bool operator==(const UnexpectedTypePackInSubtyping& rhs) const; }; +struct UserDefinedTypeFunctionError +{ + std::string message; + + bool operator==(const UserDefinedTypeFunctionError& rhs) const; +}; + using TypeErrorData = Variant< TypeMismatch, UnknownSymbol, @@ -496,7 +503,8 @@ using TypeErrorData = Variant< CheckedFunctionIncorrectArgs, UnexpectedTypeInSubtyping, UnexpectedTypePackInSubtyping, - ExplicitFunctionAnnotationRecommended>; + ExplicitFunctionAnnotationRecommended, + UserDefinedTypeFunctionError>; struct TypeErrorSummary { diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 0fdcce16..d3fc6ad3 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -3,6 +3,7 @@ #include #include +#include namespace Luau { @@ -31,6 +32,13 @@ struct ModuleInfo bool optional = false; }; +struct RequireSuggestion +{ + std::string label; + std::string fullPath; +}; +using RequireSuggestions = std::vector; + struct FileResolver { virtual ~FileResolver() {} @@ -51,6 +59,11 @@ struct FileResolver { return std::nullopt; } + + virtual std::optional getRequireSuggestions(const ModuleName& requirer, const std::optional& pathString) const + { + return std::nullopt; + } }; struct NullFileResolver : FileResolver diff --git a/Analysis/include/Luau/FragmentAutocomplete.h b/Analysis/include/Luau/FragmentAutocomplete.h new file mode 100644 index 00000000..4c32f90b --- /dev/null +++ b/Analysis/include/Luau/FragmentAutocomplete.h @@ -0,0 +1,146 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Parser.h" +#include "Luau/AutocompleteTypes.h" +#include "Luau/DenseHash.h" +#include "Luau/Module.h" +#include "Luau/Frontend.h" + +#include +#include + +namespace Luau +{ +struct FrontendOptions; + +enum class FragmentTypeCheckStatus +{ + SkipAutocomplete, + Success, +}; + +struct FragmentAutocompleteAncestryResult +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; + std::vector ancestry; + AstStat* nearestStatement = nullptr; +}; + +struct FragmentParseResult +{ + std::string fragmentToParse; + AstStatBlock* root = nullptr; + std::vector ancestry; + AstStat* nearestStatement = nullptr; + std::vector commentLocations; + std::unique_ptr alloc = std::make_unique(); +}; + +struct FragmentTypeCheckResult +{ + ModulePtr incrementalModule = nullptr; + ScopePtr freshScope; + std::vector ancestry; +}; + +struct FragmentAutocompleteResult +{ + ModulePtr incrementalModule; + Scope* freshScope; + TypeArena arenaForAutocomplete; + AutocompleteResult acResults; +}; + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos); + +std::optional parseFragment( + const SourceModule& srcModule, + std::string_view src, + const Position& cursorPos, + std::optional fragmentEndPosition +); + +std::pair typecheckFragment( + Frontend& frontend, + const ModuleName& moduleName, + const Position& cursorPos, + std::optional opts, + std::string_view src, + std::optional fragmentEndPosition +); + +FragmentAutocompleteResult fragmentAutocomplete( + Frontend& frontend, + std::string_view src, + const ModuleName& moduleName, + Position cursorPosition, + std::optional opts, + StringCompletionCallback callback, + std::optional fragmentEndPosition = std::nullopt +); + +enum class FragmentAutocompleteStatus +{ + Success, + FragmentTypeCheckFail, + InternalIce +}; + +struct FragmentAutocompleteStatusResult +{ + FragmentAutocompleteStatus status; + std::optional result; +}; + +struct FragmentContext +{ + std::string_view newSrc; + const ParseResult& newAstRoot; + std::optional opts; + std::optional DEPRECATED_fragmentEndPosition; +}; + +/** + * @brief Attempts to compute autocomplete suggestions from the fragment context. + * + * This function computes autocomplete suggestions using outdated frontend typechecking data + * by patching the fragment context of the new script source content. + * + * @param frontend The Luau Frontend data structure, which may contain outdated typechecking data. + * + * @param moduleName The name of the target module, specifying which script the caller wants to request autocomplete for. + * + * @param cursorPosition The position in the script where the caller wants to trigger autocomplete. + * + * @param context The fragment context that this API will use to patch the outdated typechecking data. + * + * @param stringCompletionCB A callback function that provides autocomplete suggestions for string contexts. + * + * @return + * The status indicating whether `fragmentAutocomplete` ran successfully or failed, along with the reason for failure. + * Also includes autocomplete suggestions if the status is successful. + * + * @usage + * FragmentAutocompleteStatusResult acStatusResult; + * if (shouldFragmentAC) + * acStatusResult = Luau::tryFragmentAutocomplete(...); + * + * if (acStatusResult.status != Successful) + * { + * frontend.check(moduleName, options); + * acStatusResult.acResult = Luau::autocomplete(...); + * } + * return convertResultWithContext(acStatusResult.acResult); + */ +FragmentAutocompleteStatusResult tryFragmentAutocomplete( + Frontend& frontend, + const ModuleName& moduleName, + Position cursorPosition, + FragmentContext context, + StringCompletionCallback stringCompletionCB +); + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index d8a40d24..dc443777 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -7,6 +7,7 @@ #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" +#include "Luau/Set.h" #include "Luau/TypeCheckLimits.h" #include "Luau/Variant.h" #include "Luau/AnyTypeSummary.h" @@ -44,21 +45,6 @@ struct LoadDefinitionFileResult std::optional parseMode(const std::vector& hotcomments); -std::vector parsePathExpr(const AstExpr& pathExpr); - -// Exported only for convenient testing. -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& expr); - -/** Try to convert an AST fragment into a ModuleName. - * Returns std::nullopt if the expression cannot be resolved. This will most likely happen in cases where - * the import path involves some dynamic computation that we cannot see into at typechecking time. - * - * Unintuitively, weirdly-formulated modules (like game.Parent.Parent.Parent.Foo) will successfully produce a ModuleName - * as long as it falls within the permitted syntax. This is ok because we will fail to find the module and produce an - * error when we try during typechecking. - */ -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); - struct SourceNode { bool hasDirtySourceModule() const @@ -71,13 +57,32 @@ struct SourceNode return forAutocomplete ? dirtyModuleForAutocomplete : dirtyModule; } + bool hasInvalidModuleDependency(bool forAutocomplete) const + { + return forAutocomplete ? invalidModuleDependencyForAutocomplete : invalidModuleDependency; + } + + void setInvalidModuleDependency(bool value, bool forAutocomplete) + { + if (forAutocomplete) + invalidModuleDependencyForAutocomplete = value; + else + invalidModuleDependency = value; + } + ModuleName name; std::string humanReadableName; DenseHashSet requireSet{{}}; std::vector> requireLocations; + Set dependents{{}}; + bool dirtySourceModule = true; bool dirtyModule = true; bool dirtyModuleForAutocomplete = true; + + bool invalidModuleDependency = true; + bool invalidModuleDependencyForAutocomplete = true; + double autocompleteLimitsMult = 1.0; }; @@ -132,7 +137,7 @@ struct FrontendModuleResolver : ModuleResolver std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; - void setModule(const ModuleName& moduleName, ModulePtr module); + bool setModule(const ModuleName& moduleName, ModulePtr module); void clearModules(); private: @@ -166,9 +171,13 @@ struct Frontend // Parse and typecheck module graph CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + bool allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete = false) const; + bool isDirty(const ModuleName& name, bool forAutocomplete = false) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); + void traverseDependents(const ModuleName& name, std::function processSubtree); + /** Borrow a pointer into the SourceModule cache. * * Returns nullptr if we don't have it. This could mean that the script @@ -209,6 +218,7 @@ struct Frontend ); std::optional getCheckResult(const ModuleName& name, bool accumulateNested, bool forAutocomplete = false); + std::vector getRequiredScripts(const ModuleName& name); private: ModulePtr check( diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index 0fd2817a..73345f98 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -60,7 +60,7 @@ struct ReplaceGenerics : Substitution }; // A substitution which replaces generic functions by monomorphic functions -struct Instantiation : Substitution +struct Instantiation final : Substitution { Instantiation(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope) : Substitution(log, arena) diff --git a/Analysis/include/Luau/Instantiation2.h b/Analysis/include/Luau/Instantiation2.h index c9215fad..ee949388 100644 --- a/Analysis/include/Luau/Instantiation2.h +++ b/Analysis/include/Luau/Instantiation2.h @@ -53,7 +53,7 @@ struct Replacer : Substitution }; // A substitution which replaces generic functions by monomorphic functions -struct Instantiation2 : Substitution +struct Instantiation2 final : Substitution { // Mapping from generic types to free types to be used in instantiation. DenseHashMap genericSubstitutions{nullptr}; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index f909deb8..ebce78cf 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -9,15 +9,24 @@ #include "Luau/Scope.h" #include "Luau/TypeArena.h" #include "Luau/AnyTypeSummary.h" +#include "Luau/DataFlowGraph.h" #include #include #include #include +LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) + namespace Luau { +using LogLuauProc = void (*)(std::string_view); +extern LogLuauProc logLuau; + +void setLogLuau(LogLuauProc ll); +void resetLogLuauProc(); + struct Module; struct AnyTypeSummary; @@ -54,6 +63,7 @@ struct SourceModule } }; +bool isWithinComment(const std::vector& commentLocations, Position pos); bool isWithinComment(const SourceModule& sourceModule, Position pos); bool isWithinComment(const ParseResult& result, Position pos); @@ -67,6 +77,9 @@ struct Module { ~Module(); + // TODO: Clip this when we clip FFlagLuauSolverV2 + bool checkedInNewSolver = false; + ModuleName name; std::string humanReadableName; @@ -132,6 +145,11 @@ struct Module TypePackId returnType = nullptr; std::unordered_map exportedTypeBindings; + // Arenas related to the DFG must persist after the DFG no longer exists, as + // Module objects maintain raw pointers to objects in these arenas. + DefArena defArena; + RefinementKeyArena keyArena; + bool hasModuleScope() const; ScopePtr getModuleScope() const; diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index d892ccd7..59751793 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -20,8 +20,6 @@ struct ModuleResolver virtual ~ModuleResolver() {} /** Compute a ModuleName from an AST fragment. This AST fragment is generally the argument to the require() function. - * - * You probably want to implement this with some variation of pathExprToModuleName. * * @returns The ModuleInfo if the expression is a syntactically legal path. * @returns std::nullopt if we are unable to determine whether or not the expression is a valid path. Type inference will diff --git a/Analysis/include/Luau/NonStrictTypeChecker.h b/Analysis/include/Luau/NonStrictTypeChecker.h index 8e80c762..880d487f 100644 --- a/Analysis/include/Luau/NonStrictTypeChecker.h +++ b/Analysis/include/Luau/NonStrictTypeChecker.h @@ -9,11 +9,14 @@ namespace Luau { struct BuiltinTypes; +struct TypeFunctionRuntime; struct UnifierSharedState; struct TypeCheckLimits; void checkNonStrict( NotNull builtinTypes, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull ice, NotNull unifierState, NotNull dfg, diff --git a/Analysis/include/Luau/Normalize.h b/Analysis/include/Luau/Normalize.h index d844d211..f014c433 100644 --- a/Analysis/include/Luau/Normalize.h +++ b/Analysis/include/Luau/Normalize.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/EqSatSimplification.h" #include "Luau/NotNull.h" #include "Luau/Set.h" #include "Luau/TypeFwd.h" @@ -21,10 +22,22 @@ struct Scope; using ModulePtr = std::shared_ptr; -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); -bool isConsistentSubtype(TypePackId subTy, TypePackId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice); +bool isSubtype( + TypeId subTy, + TypeId superTy, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +); +bool isSubtype( + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +); class TypeIds { @@ -336,6 +349,7 @@ struct NormalizedType }; +using SeenTablePropPairs = Set, TypeIdPairHash>; class Normalizer { @@ -390,7 +404,13 @@ public: void unionTablesWithTable(TypeIds& heres, TypeId there); void unionTables(TypeIds& heres, const TypeIds& theres); NormalizationResult unionNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - NormalizationResult unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars = -1); + NormalizationResult unionNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes, + int ignoreSmallerTyvars = -1 + ); // ------- Negations std::optional negateNormal(const NormalizedType& here); @@ -407,16 +427,26 @@ public: void intersectClassesWithClass(NormalizedClassType& heres, TypeId there); void intersectStrings(NormalizedStringType& here, const NormalizedStringType& there); std::optional intersectionOfTypePacks(TypePackId here, TypePackId there); - std::optional intersectionOfTables(TypeId here, TypeId there, Set& seenSet); - void intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes); + std::optional intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSet); + void intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes); void intersectTables(TypeIds& heres, const TypeIds& theres); std::optional intersectionOfFunctions(TypeId here, TypeId there); void intersectFunctionsWithFunction(NormalizedFunctionType& heress, TypeId there); void intersectFunctions(NormalizedFunctionType& heress, const NormalizedFunctionType& theress); - NormalizationResult intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes); + NormalizationResult intersectTyvarsWithTy( + NormalizedTyvars& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes + ); NormalizationResult intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars = -1); - NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes); - NormalizationResult normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet); + NormalizationResult intersectNormalWithTy(NormalizedType& here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes); + NormalizationResult normalizeIntersections( + const std::vector& intersections, + NormalizedType& outType, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSet + ); // Check for inhabitance NormalizationResult isInhabited(TypeId ty); @@ -426,7 +456,7 @@ public: // Check for intersections being inhabited NormalizationResult isIntersectionInhabited(TypeId left, TypeId right); - NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet); + NormalizationResult isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set& seenSet); // -------- Convert back from a normalized type to a type TypeId typeFromNormal(const NormalizedType& norm); diff --git a/Analysis/include/Luau/OverloadResolution.h b/Analysis/include/Luau/OverloadResolution.h index 9a2974a5..d85d769e 100644 --- a/Analysis/include/Luau/OverloadResolution.h +++ b/Analysis/include/Luau/OverloadResolution.h @@ -2,12 +2,13 @@ #pragma once #include "Luau/Ast.h" -#include "Luau/InsertionOrderedMap.h" -#include "Luau/NotNull.h" -#include "Luau/TypeFwd.h" -#include "Luau/Location.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" +#include "Luau/InsertionOrderedMap.h" +#include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/Subtyping.h" +#include "Luau/TypeFwd.h" namespace Luau { @@ -34,7 +35,9 @@ struct OverloadResolver OverloadResolver( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull reporter, NotNull limits, @@ -43,7 +46,9 @@ struct OverloadResolver NotNull builtinTypes; NotNull arena; + NotNull simplifier; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull scope; NotNull ice; NotNull limits; @@ -108,7 +113,9 @@ struct SolveResult SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter, NotNull limits, NotNull scope, diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index 718a6cc1..beffaa2e 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -11,14 +11,12 @@ namespace Luau { -class AstStat; -class AstExpr; +class AstNode; class AstStatBlock; -struct AstLocal; struct RequireTraceResult { - DenseHashMap exprs{nullptr}; + DenseHashMap exprs{nullptr}; std::vector> requireList; }; diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h index 0e6eff56..4604a2e1 100644 --- a/Analysis/include/Luau/Scope.h +++ b/Analysis/include/Luau/Scope.h @@ -85,12 +85,18 @@ struct Scope void inheritAssignments(const ScopePtr& childScope); void inheritRefinements(const ScopePtr& childScope); + // Track globals that should emit warnings during type checking. + DenseHashSet globalsToWarn{""}; + bool shouldWarnGlobal(std::string name) const; + // For mutually recursive type aliases, it's important that // they use the same types for the same names. // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` // we need that the generic type `T` in both cases is the same, so we use a cache. std::unordered_map typeAliasTypeParameters; std::unordered_map typeAliasTypePackParameters; + + std::optional> interiorFreeTypes; }; // Returns true iff the left scope encloses the right scope. A Scope* equal to diff --git a/Analysis/include/Luau/Simplify.h b/Analysis/include/Luau/Simplify.h index 5b363e96..aab37876 100644 --- a/Analysis/include/Luau/Simplify.h +++ b/Analysis/include/Luau/Simplify.h @@ -19,10 +19,10 @@ struct SimplifyResult DenseHashSet blockedTypes; }; -SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right); SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts); -SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId ty, TypeId discriminant); +SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right); enum class Relation { diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h index 18217a6b..26c4553e 100644 --- a/Analysis/include/Luau/Subtyping.h +++ b/Analysis/include/Luau/Subtyping.h @@ -1,13 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Set.h" +#include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunction.h" #include "Luau/TypeFwd.h" #include "Luau/TypePairHash.h" #include "Luau/TypePath.h" -#include "Luau/TypeFunction.h" -#include "Luau/TypeCheckLimits.h" -#include "Luau/DenseHash.h" #include #include @@ -96,6 +97,22 @@ struct SubtypingEnvironment DenseHashSet upperBound{nullptr}; }; + /* For nested subtyping relationship tests of mapped generic bounds, we keep the outer environment immutable */ + SubtypingEnvironment* parent = nullptr; + + /// Applies `mappedGenerics` to the given type. + /// This is used specifically to substitute for generics in type function instances. + std::optional applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty); + + const TypeId* tryFindSubstitution(TypeId ty) const; + const SubtypingResult* tryFindSubtypingResult(std::pair subAndSuper) const; + + bool containsMappedType(TypeId ty) const; + bool containsMappedPack(TypePackId tp) const; + + GenericBounds& getMappedTypeBounds(TypeId ty); + TypePackId* getMappedPackBounds(TypePackId tp); + /* * When we encounter a generic over the course of a subtyping test, we need * to tentatively map that generic onto a type on the other side. @@ -112,17 +129,15 @@ struct SubtypingEnvironment DenseHashMap substitutions{nullptr}; DenseHashMap, SubtypingResult, TypePairHash> ephemeralCache{{}}; - - /// Applies `mappedGenerics` to the given type. - /// This is used specifically to substitute for generics in type function instances. - std::optional applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty); }; struct Subtyping { NotNull builtinTypes; NotNull arena; + NotNull simplifier; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull iceReporter; TypeCheckLimits limits; @@ -142,7 +157,9 @@ struct Subtyping Subtyping( NotNull builtinTypes, NotNull typeArena, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter ); diff --git a/Analysis/include/Luau/TableLiteralInference.h b/Analysis/include/Luau/TableLiteralInference.h index dd9ecf97..6be1e872 100644 --- a/Analysis/include/Luau/TableLiteralInference.h +++ b/Analysis/include/Luau/TableLiteralInference.h @@ -6,6 +6,8 @@ #include "Luau/NotNull.h" #include "Luau/TypeFwd.h" +#include + namespace Luau { diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index f8001e08..4862e3b4 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -44,6 +44,7 @@ struct ToStringOptions bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. bool hideFunctionSelfArgument = false; // If true, `self: X` will be omitted from the function signature if the function has self + bool useQuestionMarks = true; // If true, use a postfix ? for options, else write them out as unions that include nil. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypes size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); size_t compositeTypesSingleLineLimit = 5; // The number of type elements permitted on a single line when printing type unions/intersections diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 951f89ee..de8665e5 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -65,11 +65,10 @@ T* getMutable(PendingTypePack* pending) // Log of what TypeIds we are rebinding, to be committed later. struct TxnLog { - explicit TxnLog(bool useScopes = false) + explicit TxnLog() : typeVarChanges(nullptr) , typePackChanges(nullptr) , ownedSeen() - , useScopes(useScopes) , sharedSeen(&ownedSeen) { } diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index a43dbff9..890c7078 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -31,6 +31,7 @@ namespace Luau struct TypeArena; struct Scope; using ScopePtr = std::shared_ptr; +struct Module; struct TypeFunction; struct Constraint; @@ -68,12 +69,16 @@ using Name = std::string; // A free type is one whose exact shape has yet to be fully determined. struct FreeType { + // New constructors + explicit FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound); + // This one got promoted to explicit + explicit FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); + explicit FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound); + // Old constructors explicit FreeType(TypeLevel level); explicit FreeType(Scope* scope); FreeType(Scope* scope, TypeLevel level); - FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound); - int index; TypeLevel level; Scope* scope = nullptr; @@ -130,14 +135,14 @@ struct BlockedType BlockedType(); int index; - Constraint* getOwner() const; - void setOwner(Constraint* newOwner); - void replaceOwner(Constraint* newOwner); + const Constraint* getOwner() const; + void setOwner(const Constraint* newOwner); + void replaceOwner(const Constraint* newOwner); private: // The constraint that is intended to unblock this type. Other constraints // should block on this constraint if present. - Constraint* owner = nullptr; + const Constraint* owner = nullptr; }; struct PrimitiveType @@ -278,9 +283,6 @@ struct WithPredicate } }; -using MagicFunction = std::function>(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate)>; - struct MagicFunctionCallContext { NotNull solver; @@ -290,7 +292,6 @@ struct MagicFunctionCallContext TypePackId result; }; -using DcrMagicFunction = std::function; struct MagicRefinementContext { NotNull scope; @@ -307,8 +308,30 @@ struct MagicFunctionTypeCheckContext NotNull checkScope; }; -using DcrMagicRefinement = void (*)(const MagicRefinementContext&); -using DcrMagicFunctionTypeCheck = std::function; +struct MagicFunction +{ + virtual std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) = 0; + + // Callback to allow custom typechecking of builtin function calls whose argument types + // will only be resolved after constraint solving. For example, the arguments to string.format + // have types that can only be decided after parsing the format string and unifying + // with the passed in values, but the correctness of the call can only be decided after + // all the types have been finalized. + virtual bool infer(const MagicFunctionCallContext&) = 0; + virtual void refine(const MagicRefinementContext&) {} + + // If a magic function needs to do its own special typechecking, do it here. + // Returns true if magic typechecking was performed. Return false if the + // default typechecking logic should run. + virtual bool typeCheck(const MagicFunctionTypeCheckContext&) + { + return false; + } + + virtual ~MagicFunction() {} +}; + struct FunctionType { // Global monomorphic function @@ -366,16 +389,7 @@ struct FunctionType Scope* scope = nullptr; TypePackId argTypes; TypePackId retTypes; - MagicFunction magicFunction = nullptr; - DcrMagicFunction dcrMagicFunction = nullptr; - DcrMagicRefinement dcrMagicRefinement = nullptr; - - // Callback to allow custom typechecking of builtin function calls whose argument types - // will only be resolved after constraint solving. For example, the arguments to string.format - // have types that can only be decided after parsing the format string and unifying - // with the passed in values, but the correctness of the call can only be decided after - // all the types have been finalized. - DcrMagicFunctionTypeCheck dcrMagicTypeCheck = nullptr; + std::shared_ptr magic = nullptr; bool hasSelf; // `hasNoFreeOrGenericTypes` should be true if and only if the type does not have any free or generic types present inside it. @@ -598,6 +612,19 @@ struct ClassType } }; +// Data required to initialize a user-defined function and its environment +struct UserDefinedFunctionData +{ + // Store a weak module reference to ensure the lifetime requirements are preserved + std::weak_ptr owner; + + // References to AST elements are owned by the Module allocator which also stores this type + AstStatTypeFunction* definition = nullptr; + + DenseHashMap> environment{""}; + DenseHashMap environment_DEPRECATED{""}; +}; + /** * An instance of a type function that has not yet been reduced to a more concrete * type. The constraint solver receives a constraint to reduce each @@ -612,21 +639,21 @@ struct TypeFunctionInstanceType std::vector typeArguments; std::vector packArguments; - std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs - std::optional userFuncBody; // Body of the user-defined type function; only available for UDTFs + std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs + UserDefinedFunctionData userFuncData; TypeFunctionInstanceType( NotNull function, std::vector typeArguments, std::vector packArguments, - std::optional userFuncName = std::nullopt, - std::optional userFuncBody = std::nullopt + std::optional userFuncName, + UserDefinedFunctionData userFuncData ) : function(function) , typeArguments(typeArguments) , packArguments(packArguments) , userFuncName(userFuncName) - , userFuncBody(userFuncBody) + , userFuncData(userFuncData) { } @@ -643,6 +670,13 @@ struct TypeFunctionInstanceType , packArguments(packArguments) { } + + TypeFunctionInstanceType(NotNull function, std::vector typeArguments, std::vector packArguments) + : function{function} + , typeArguments(typeArguments) + , packArguments(packArguments) + { + } }; /** Represents a pending type alias instantiation. @@ -670,6 +704,11 @@ struct AnyType { }; +// A special, trivial type for the refinement system that is always eliminated from intersections. +struct NoRefineType +{ +}; + // `T | U` struct UnionType { @@ -737,7 +776,7 @@ struct NegationType TypeId ty; }; -using ErrorType = Unifiable::Error; +using ErrorType = Unifiable::Error; using TypeVariant = Unifiable::Variant< TypeId, @@ -758,6 +797,7 @@ using TypeVariant = Unifiable::Variant< UnknownType, NeverType, NegationType, + NoRefineType, TypeFunctionInstanceType>; struct Type final @@ -803,6 +843,13 @@ struct Type final Type& operator=(const TypeVariant& rhs); Type& operator=(TypeVariant&& rhs); + Type(Type&&) = default; + Type& operator=(Type&&) = default; + + Type clone() const; + +private: + Type(const Type&) = default; Type& operator=(const Type& rhs); }; @@ -952,6 +999,7 @@ public: const TypeId unknownType; const TypeId neverType; const TypeId errorType; + const TypeId noRefineType; const TypeId falsyType; const TypeId truthyType; @@ -1159,6 +1207,10 @@ TypeId freshType(NotNull arena, NotNull builtinTypes, S using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); +// A tag to mark a type which doesn't derive directly from the root type as overriding the return of `typeof`. +// Any classes which derive from this type will have typeof return this type. +static constexpr char kTypeofRootTag[] = "typeofRoot"; + void attachTag(TypeId ty, const std::string& tagName); void attachTag(Property& prop, const std::string& tagName); diff --git a/Analysis/include/Luau/TypeArena.h b/Analysis/include/Luau/TypeArena.h index 4f8aea87..ebefa41f 100644 --- a/Analysis/include/Luau/TypeArena.h +++ b/Analysis/include/Luau/TypeArena.h @@ -32,9 +32,13 @@ struct TypeArena TypeId addTV(Type&& tv); - TypeId freshType(TypeLevel level); - TypeId freshType(Scope* scope); - TypeId freshType(Scope* scope, TypeLevel level); + TypeId freshType(NotNull builtins, TypeLevel level); + TypeId freshType(NotNull builtins, Scope* scope); + TypeId freshType(NotNull builtins, Scope* scope, TypeLevel level); + + TypeId freshType_DEPRECATED(TypeLevel level); + TypeId freshType_DEPRECATED(Scope* scope); + TypeId freshType_DEPRECATED(Scope* scope, TypeLevel level); TypePackId freshTypePack(Scope* scope); diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index 0faf036d..0c52b1f1 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -2,15 +2,16 @@ #pragma once -#include "Luau/Error.h" -#include "Luau/NotNull.h" #include "Luau/Common.h" -#include "Luau/TypeUtils.h" +#include "Luau/EqSatSimplification.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/NotNull.h" +#include "Luau/Subtyping.h" #include "Luau/Type.h" #include "Luau/TypeFwd.h" #include "Luau/TypeOrPack.h" -#include "Luau/Normalize.h" -#include "Luau/Subtyping.h" +#include "Luau/TypeUtils.h" namespace Luau { @@ -60,7 +61,9 @@ struct Reasonings void check( NotNull builtinTypes, - NotNull sharedState, + NotNull simplifier, + NotNull typeFunctionRuntime, + NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, @@ -70,6 +73,8 @@ void check( struct TypeChecker2 { NotNull builtinTypes; + NotNull simplifier; + NotNull typeFunctionRuntime; DcrLogger* logger; const NotNull limits; const NotNull ice; @@ -88,6 +93,8 @@ struct TypeChecker2 TypeChecker2( NotNull builtinTypes, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, DcrLogger* logger, @@ -109,14 +116,14 @@ private: std::optional pushStack(AstNode* node); void checkForInternalTypeFunction(TypeId ty, Location location); TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location); - TypePackId lookupPack(AstExpr* expr); + TypePackId lookupPack(AstExpr* expr) const; TypeId lookupType(AstExpr* expr); TypeId lookupAnnotation(AstType* annotation); - std::optional lookupPackAnnotation(AstTypePack* annotation); - TypeId lookupExpectedType(AstExpr* expr); - TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena); + std::optional lookupPackAnnotation(AstTypePack* annotation) const; + TypeId lookupExpectedType(AstExpr* expr) const; + TypePackId lookupExpectedPack(AstExpr* expr, TypeArena& arena) const; TypePackId reconstructPack(AstArray exprs, TypeArena& arena); - Scope* findInnermostScope(Location location); + Scope* findInnermostScope(Location location) const; void visit(AstStat* stat); void visit(AstStatIf* ifStatement); void visit(AstStatWhile* whileStatement); @@ -153,7 +160,7 @@ private: void visit(AstExprVarargs* expr); void visitCall(AstExprCall* call); void visit(AstExprCall* call); - std::optional tryStripUnionFromNil(TypeId ty); + std::optional tryStripUnionFromNil(TypeId ty) const; TypeId stripFromNilAndReport(TypeId ty, const Location& location); void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context, TypeId astIndexExprTy); void visit(AstExprIndexName* indexName, ValueContext context); @@ -168,7 +175,7 @@ private: void visit(AstExprInterpString* interpString); void visit(AstExprError* expr); TypeId flattenPack(TypePackId pack); - void visitGenerics(AstArray generics, AstArray genericPacks); + void visitGenerics(AstArray generics, AstArray genericPacks); void visit(AstType* ty); void visit(AstTypeReference* ty); void visit(AstTypeTable* table); @@ -210,6 +217,9 @@ private: std::vector& errors ); + // Avoid duplicate warnings being emitted for the same global variable. + DenseHashSet warnedGlobals{""}; + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) const; bool isErrorSuppressing(Location loc, TypeId ty); bool isErrorSuppressing(Location loc1, TypeId ty1, Location loc2, TypeId ty2); diff --git a/Analysis/include/Luau/TypeFunction.h b/Analysis/include/Luau/TypeFunction.h index c686f482..1c97550f 100644 --- a/Analysis/include/Luau/TypeFunction.h +++ b/Analysis/include/Luau/TypeFunction.h @@ -1,29 +1,68 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/ConstraintSolver.h" +#include "Luau/Constraint.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Error.h" #include "Luau/NotNull.h" #include "Luau/TypeCheckLimits.h" +#include "Luau/TypeFunctionRuntime.h" #include "Luau/TypeFwd.h" #include #include #include +struct lua_State; + namespace Luau { struct TypeArena; struct TxnLog; +struct ConstraintSolver; class Normalizer; +using StateRef = std::unique_ptr; + +struct TypeFunctionRuntime +{ + TypeFunctionRuntime(NotNull ice, NotNull limits); + ~TypeFunctionRuntime(); + + // Return value is an error message if registration failed + std::optional registerFunction(AstStatTypeFunction* function); + + // For user-defined type functions, we store all generated types and packs for the duration of the typecheck + TypedAllocator typeArena; + TypedAllocator typePackArena; + + NotNull ice; + NotNull limits; + + StateRef state; + + // Set of functions which have their environment table initialized + DenseHashSet initialized{nullptr}; + + // Evaluation of type functions should only be performed in the absence of parse errors in the source module + bool allowEvaluation = true; + + // Output created by 'print' function + std::vector messages; + +private: + void prepareState(); +}; + struct TypeFunctionContext { NotNull arena; NotNull builtins; NotNull scope; + NotNull simplifier; NotNull normalizer; + NotNull typeFunctionRuntime; NotNull ice; NotNull limits; @@ -32,33 +71,26 @@ struct TypeFunctionContext // The constraint being reduced in this run of the reduction const Constraint* constraint; - std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs - std::optional userFuncBody; // Body of the user-defined type function; only available for UDTFs + std::optional userFuncName; // Name of the user-defined type function; only available for UDTFs - TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) - : arena(cs->arena) - , builtins(cs->builtinTypes) - , scope(scope) - , normalizer(cs->normalizer) - , ice(NotNull{&cs->iceReporter}) - , limits(NotNull{&cs->limits}) - , solver(cs.get()) - , constraint(constraint.get()) - { - } + TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint); TypeFunctionContext( NotNull arena, NotNull builtins, NotNull scope, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull ice, NotNull limits ) : arena(arena) , builtins(builtins) , scope(scope) + , simplifier(simplifier) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) , limits(limits) , solver(nullptr) @@ -66,7 +98,17 @@ struct TypeFunctionContext { } - NotNull pushConstraint(ConstraintV&& c); + NotNull pushConstraint(ConstraintV&& c) const; +}; + +enum class Reduction +{ + // The type function is either known to be reducible or the determination is blocked. + MaybeOk, + // The type function is known to be irreducible, but maybe not be erroneous, e.g. when it's over generics or free types. + Irreducible, + // The type function is known to be irreducible, and is definitely erroneous. + Erroneous, }; /// Represents a reduction result, which may have successfully reduced the type, @@ -75,19 +117,25 @@ struct TypeFunctionContext template struct TypeFunctionReductionResult { + /// The result of the reduction, if any. If this is nullopt, the type function /// could not be reduced. std::optional result; - /// Whether the result is uninhabited: whether we know, unambiguously and - /// permanently, whether this type function reduction results in an - /// uninhabitable type. This will trigger an error to be reported. - bool uninhabited; + /// Indicates the status of this reduction: is `Reduction::Irreducible` if + /// the this result indicates the type function is irreducible, and + /// `Reduction::Erroneous` if this result indicates the type function is + /// erroneous. `Reduction::MaybeOk` otherwise. + Reduction reductionStatus; /// Any types that need to be progressed or mutated before the reduction may /// proceed. std::vector blockedTypes; /// Any type packs that need to be progressed or mutated before the /// reduction may proceed. std::vector blockedPacks; + /// A runtime error message from user-defined type functions + std::optional error; + /// Messages printed out from user-defined type functions + std::vector messages; }; template @@ -121,6 +169,7 @@ struct TypePackFunction struct FunctionGraphReductionResult { ErrorVec errors; + ErrorVec messages; DenseHashSet blockedTypes{nullptr}; DenseHashSet blockedPacks{nullptr}; DenseHashSet reducedTypes{nullptr}; @@ -192,6 +241,9 @@ struct BuiltinTypeFunctions TypeFunction indexFunc; TypeFunction rawgetFunc; + TypeFunction setmetatableFunc; + TypeFunction getmetatableFunc; + void addToScope(NotNull arena, NotNull scope) const; }; diff --git a/Analysis/include/Luau/TypeFunctionRuntime.h b/Analysis/include/Luau/TypeFunctionRuntime.h new file mode 100644 index 00000000..e6cc4d26 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntime.h @@ -0,0 +1,298 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Variant.h" +#include "Luau/TypeFwd.h" + +#include +#include +#include +#include + +using lua_State = struct lua_State; + +namespace Luau +{ + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize); + +// Replica of types from Type.h +struct TypeFunctionType; +using TypeFunctionTypeId = const TypeFunctionType*; + +struct TypeFunctionTypePackVar; +using TypeFunctionTypePackId = const TypeFunctionTypePackVar*; + +struct TypeFunctionPrimitiveType +{ + enum Type + { + NilType, + Boolean, + Number, + String, + Thread, + Buffer, + }; + + Type type; + + TypeFunctionPrimitiveType(Type type) + : type(type) + { + } +}; + +struct TypeFunctionBooleanSingleton +{ + bool value = false; +}; + +struct TypeFunctionStringSingleton +{ + std::string value; +}; + +using TypeFunctionSingletonVariant = Variant; + +struct TypeFunctionSingletonType +{ + TypeFunctionSingletonVariant variant; + + explicit TypeFunctionSingletonType(TypeFunctionSingletonVariant variant) + : variant(std::move(variant)) + { + } +}; + +template +const T* get(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->variant) : nullptr; +} + +template +T* getMutable(const TypeFunctionSingletonType* tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->variant) : nullptr; +} + +struct TypeFunctionUnionType +{ + std::vector components; +}; + +struct TypeFunctionIntersectionType +{ + std::vector components; +}; + +struct TypeFunctionAnyType +{ +}; + +struct TypeFunctionUnknownType +{ +}; + +struct TypeFunctionNeverType +{ +}; + +struct TypeFunctionNegationType +{ + TypeFunctionTypeId type; +}; + +struct TypeFunctionTypePack +{ + std::vector head; + std::optional tail; +}; + +struct TypeFunctionVariadicTypePack +{ + TypeFunctionTypeId type; +}; + +struct TypeFunctionGenericTypePack +{ + bool isNamed = false; + + std::string name; +}; + +using TypeFunctionTypePackVariant = Variant; + +struct TypeFunctionTypePackVar +{ + TypeFunctionTypePackVariant type; + + TypeFunctionTypePackVar(TypeFunctionTypePackVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionTypePackVar& rhs) const; +}; + +struct TypeFunctionFunctionType +{ + std::vector generics; + std::vector genericPacks; + + TypeFunctionTypePackId argTypes; + TypeFunctionTypePackId retTypes; +}; + +template +const T* get(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypePackId tv) +{ + LUAU_ASSERT(tv); + + return tv ? get_if(&const_cast(tv)->type) : nullptr; +} + +struct TypeFunctionTableIndexer +{ + TypeFunctionTableIndexer(TypeFunctionTypeId keyType, TypeFunctionTypeId valueType) + : keyType(keyType) + , valueType(valueType) + { + } + + TypeFunctionTypeId keyType; + TypeFunctionTypeId valueType; +}; + +struct TypeFunctionProperty +{ + static TypeFunctionProperty readonly(TypeFunctionTypeId ty); + static TypeFunctionProperty writeonly(TypeFunctionTypeId ty); + static TypeFunctionProperty rw(TypeFunctionTypeId ty); // Shared read-write type. + static TypeFunctionProperty rw(TypeFunctionTypeId read, TypeFunctionTypeId write); // Separate read-write type. + + bool isReadOnly() const; + bool isWriteOnly() const; + + std::optional readTy; + std::optional writeTy; +}; + +struct TypeFunctionTableType +{ + using Name = std::string; + using Props = std::map; + + Props props; + + std::optional indexer; + + // Should always be a TypeFunctionTableType + std::optional metatable; +}; + +struct TypeFunctionClassType +{ + using Name = std::string; + using Props = std::map; + + Props props; + + std::optional indexer; + + std::optional metatable; // metaclass? + + // this was mistaken, and we should actually be keeping separate read/write types here. + std::optional parent_DEPRECATED; + + std::optional readParent; + std::optional writeParent; + + TypeId classTy; + + std::string name_DEPRECATED; +}; + +struct TypeFunctionGenericType +{ + bool isNamed = false; + bool isPack = false; + + std::string name; +}; + +using TypeFunctionTypeVariant = Luau::Variant< + TypeFunctionPrimitiveType, + TypeFunctionAnyType, + TypeFunctionUnknownType, + TypeFunctionNeverType, + TypeFunctionSingletonType, + TypeFunctionUnionType, + TypeFunctionIntersectionType, + TypeFunctionNegationType, + TypeFunctionFunctionType, + TypeFunctionTableType, + TypeFunctionClassType, + TypeFunctionGenericType>; + +struct TypeFunctionType +{ + TypeFunctionTypeVariant type; + + TypeFunctionType(TypeFunctionTypeVariant type) + : type(std::move(type)) + { + } + + bool operator==(const TypeFunctionType& rhs) const; +}; + +template +const T* get(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&tv->type) : nullptr; +} + +template +T* getMutable(TypeFunctionTypeId tv) +{ + LUAU_ASSERT(tv); + + return tv ? Luau::get_if(&const_cast(tv)->type) : nullptr; +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult); + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type); +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type); + +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type); + +bool isTypeUserData(lua_State* L, int idx); +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx); +std::optional optionalTypeUserData(lua_State* L, int idx); + +void registerTypesLibrary(lua_State* L); +void registerTypeUserData(lua_State* L); + +void setTypeFunctionEnvironment(lua_State* L); + +void resetTypeFunctionState(lua_State* L); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h new file mode 100644 index 00000000..040a3092 --- /dev/null +++ b/Analysis/include/Luau/TypeFunctionRuntimeBuilder.h @@ -0,0 +1,50 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypeFunction.h" +#include "Luau/TypeFunctionRuntime.h" + +namespace Luau +{ + +using Kind = Variant; + +template +const T* get(const Kind& kind) +{ + return get_if(&kind); +} + +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& tfkind) +{ + return get_if(&tfkind); +} + +struct TypeFunctionRuntimeBuilderState +{ + NotNull ctx; + + // Mapping of class name to ClassType + // Invariant: users can not create a new class types -> any class types that get deserialized must have been an argument to the type function + // Using this invariant, whenever a ClassType is serialized, we can put it into this map + // whenever a ClassType is deserialized, we can use this map to return the corresponding value + DenseHashMap classesSerialized_DEPRECATED{{}}; + + // List of errors that occur during serialization/deserialization + // At every iteration of serialization/deserialzation, if this list.size() != 0, we halt the process + std::vector errors{}; + + TypeFunctionRuntimeBuilderState(NotNull ctx) + : ctx(ctx) + { + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state); +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 7f2e29b5..2b8dbc3a 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -399,8 +399,8 @@ private: const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, - const AstArray& genericPackNames, + const AstArray& genericNames, + const AstArray& genericPackNames, bool useCache = false ); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 1065b947..8509da03 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -52,7 +52,7 @@ struct GenericTypePack }; using BoundTypePack = Unifiable::Bound; -using ErrorTypePack = Unifiable::Error; +using ErrorTypePack = Unifiable::Error; using TypePackVariant = Unifiable::Variant; diff --git a/Analysis/include/Luau/TypePath.h b/Analysis/include/Luau/TypePath.h index 50c75da4..2af5185d 100644 --- a/Analysis/include/Luau/TypePath.h +++ b/Analysis/include/Luau/TypePath.h @@ -51,6 +51,8 @@ struct Index /// Represents fields of a type or pack that contain a type. enum class TypeField { + /// The table of a metatable type. + Table, /// The metatable of a type. This could be a metatable type, a primitive /// type, a class type, or perhaps even a string singleton type. Metatable, diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 92be19d1..c3bed421 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -40,7 +40,7 @@ struct InConditionalContext TypeContext* typeContext; TypeContext oldValue; - InConditionalContext(TypeContext* c) + explicit InConditionalContext(TypeContext* c) : typeContext(c) , oldValue(*c) { @@ -248,4 +248,45 @@ std::optional follow(std::optional ty) return std::nullopt; } +/** + * Returns whether or not expr is a literal expression, for example: + * - Scalar literals (numbers, booleans, strings, nil) + * - Table literals + * - Lambdas (a "function literal") + */ +bool isLiteral(const AstExpr* expr); + +/** + * Given a table literal and a mapping from expression to type, determine + * whether any literal expression in this table depends on any blocked types. + * This is used as a precondition for bidirectional inference: be warned that + * the behavior of this algorithm is tightly coupled to that of bidirectional + * inference. + * @param expr Expression to search + * @param astTypes Mapping from AST node to TypeID + * @returns A vector of blocked types + */ +std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes); + +/** + * Given a function call and a mapping from expression to type, determine + * whether the type of any argument in said call in depends on a blocked types. + * This is used as a precondition for bidirectional inference: be warned that + * the behavior of this algorithm is tightly coupled to that of bidirectional + * inference. + * @param expr Expression to search + * @param astTypes Mapping from AST node to TypeID + * @returns A vector of blocked types + */ +std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes); + +/** + * Given a scope and a free type, find the closest parent that has a present + * `interiorFreeTypes` and append the given type to said list. This list will + * be generalized when the requiste `GeneralizationConstraint` is resolved. + * @param scope Initial scope this free type was attached to + * @param ty Free type to track. + */ +void trackInteriorFreeType(Scope* scope, TypeId ty); + } // namespace Luau diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 79b3b7de..132eda96 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -3,6 +3,7 @@ #include "Luau/Variant.h" +#include #include namespace Luau @@ -94,19 +95,29 @@ struct Bound Id boundTo; }; +template struct Error { // This constructor has to be public, since it's used in Type and TypePack, // but shouldn't be called directly. Please use errorRecoveryType() instead. - Error(); + explicit Error(); + + explicit Error(Id synthetic) + : synthetic{synthetic} + { + } int index; + // This is used to create an error that can be rendered out using this field + // as appropriate metadata for communicating it to the user. + std::optional synthetic; + private: static int nextIndex; }; template -using Variant = Luau::Variant, Error, Value...>; +using Variant = Luau::Variant, Error, Value...>; } // namespace Luau::Unifiable diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index b0a855d3..8d0f2806 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -93,10 +93,6 @@ struct Unifier Unifier(NotNull normalizer, NotNull scope, const Location& location, Variance variance, TxnLog* parentLog = nullptr); - // Configure the Unifier to test for scope subsumption via embedded Scope - // pointers rather than TypeLevels. - void enableNewSolver(); - // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId subTy, TypeId superTy); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); @@ -169,7 +165,6 @@ private: std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - TxnLog combineLogsIntoIntersection(std::vector logs); TxnLog combineLogsIntoUnion(std::vector logs); public: @@ -179,7 +174,7 @@ public: bool occursCheck(TypePackId needle, TypePackId haystack, bool reversed); bool occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); - Unifier makeChildUnifier(); + std::unique_ptr makeChildUnifier(); void reportError(TypeError err); LUAU_NOINLINE void reportError(Location location, TypeErrorData data); @@ -195,11 +190,6 @@ private: // Available after regular type pack unification errors std::optional firstPackErrorPos; - - // If true, we do a bunch of small things differently to work better with - // the new type inference engine. Most notably, we use the Scope hierarchy - // directly rather than using TypeLevels. - bool useNewSolver = false; }; void promoteTypeLevels(TxnLog& log, const TypeArena* arena, TypeLevel minLevel, Scope* outerScope, bool useScope, TypePackId tp); diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index de69c17c..bc2acbf1 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -49,6 +49,26 @@ struct UnifierSharedState DenseHashSet tempSeenTp{nullptr}; UnifierCounters counters; + + bool reentrantTypeReduction = false; +}; + +struct TypeReductionRentrancyGuard final +{ + explicit TypeReductionRentrancyGuard(NotNull sharedState) + : sharedState{sharedState} + { + sharedState->reentrantTypeReduction = true; + } + ~TypeReductionRentrancyGuard() + { + sharedState->reentrantTypeReduction = false; + } + TypeReductionRentrancyGuard(const TypeReductionRentrancyGuard&) = delete; + TypeReductionRentrancyGuard(TypeReductionRentrancyGuard&&) = delete; + +private: + NotNull sharedState; }; } // namespace Luau diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index e943cced..a9685462 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -10,7 +10,6 @@ #include "Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) -LUAU_FASTFLAG(LuauBoundLazyTypes2) LUAU_FASTFLAG(LuauSolverV2) namespace Luau @@ -86,6 +85,8 @@ struct GenericTypeVisitor { } + virtual ~GenericTypeVisitor() {} + virtual void cycle(TypeId) {} virtual void cycle(TypePackId) {} @@ -133,6 +134,10 @@ struct GenericTypeVisitor { return visit(ty); } + virtual bool visit(TypeId ty, const NoRefineType& nrt) + { + return visit(ty); + } virtual bool visit(TypeId ty, const UnknownType& utv) { return visit(ty); @@ -186,7 +191,7 @@ struct GenericTypeVisitor { return visit(tp); } - virtual bool visit(TypePackId tp, const Unifiable::Error& etp) + virtual bool visit(TypePackId tp, const ErrorTypePack& etp) { return visit(tp); } @@ -345,6 +350,8 @@ struct GenericTypeVisitor } else if (auto atv = get(ty)) visit(ty, *atv); + else if (auto nrt = get(ty)) + visit(ty, *nrt); else if (auto utv = get(ty)) { if (visit(ty, *utv)) @@ -455,7 +462,7 @@ struct GenericTypeVisitor else if (auto gtv = get(tp)) visit(tp, *gtv); - else if (auto etv = get(tp)) + else if (auto etv = get(tp)) visit(tp, *etv); else if (auto pack = get(tp)) diff --git a/Analysis/src/AnyTypeSummary.cpp b/Analysis/src/AnyTypeSummary.cpp index 85f567af..db50e3e9 100644 --- a/Analysis/src/AnyTypeSummary.cpp +++ b/Analysis/src/AnyTypeSummary.cpp @@ -38,7 +38,7 @@ #include -LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2, false); +LUAU_FASTFLAGVARIABLE(StudioReportLuauAny2); LUAU_FASTINTVARIABLE(LuauAnySummaryRecursionLimit, 300); LUAU_FASTFLAG(DebugLuauMagicTypes); @@ -161,7 +161,7 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* typeInfo.push_back(ti); } } - + if (ret->list.size > 1 && !seenTP) { if (containsAny(retScope->returnType)) @@ -177,7 +177,6 @@ void AnyTypeSummary::visit(const Scope* scope, AstStatReturn* ret, const Module* } } } - } void AnyTypeSummary::visit(const Scope* scope, AstStatLocal* local, const Module* module, NotNull builtinTypes) diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index fd90a6ee..c0a6c254 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -425,6 +425,7 @@ struct AstJsonEncoder : public AstVisitor "AstExprFunction", [&]() { + PROP(attributes); PROP(generics); PROP(genericPacks); if (node->self) @@ -881,7 +882,7 @@ struct AstJsonEncoder : public AstVisitor PROP(name); PROP(generics); PROP(genericPacks); - PROP(type); + write("value", node->type); PROP(exported); } ); @@ -894,7 +895,7 @@ struct AstJsonEncoder : public AstVisitor "AstStatDeclareFunction", [&]() { - // TODO: attributes + PROP(attributes); PROP(name); PROP(nameLocation); PROP(params); @@ -1042,6 +1043,7 @@ struct AstJsonEncoder : public AstVisitor "AstTypeFunction", [&]() { + PROP(attributes); PROP(generics); PROP(genericPacks); PROP(argTypes); @@ -1136,6 +1138,42 @@ struct AstJsonEncoder : public AstVisitor ); } + void write(AstAttr::Type type) + { + switch (type) + { + case AstAttr::Type::Checked: + return writeString("checked"); + case AstAttr::Type::Native: + return writeString("native"); + } + } + + void write(class AstAttr* node) + { + writeNode( + node, + "AstAttr", + [&]() + { + write("name", node->type); + } + ); + } + + bool visit(class AstTypeGroup* node) override + { + writeNode( + node, + "AstTypeGroup", + [&]() + { + write("inner", node->type); + } + ); + return false; + } + bool visit(class AstTypeSingletonBool* node) override { writeNode( diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index c8470373..815164d8 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -13,6 +13,8 @@ LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) + namespace Luau { @@ -41,11 +43,26 @@ struct AutocompleteNodeFinder : public AstVisitor bool visit(AstStat* stat) override { - if (stat->location.begin < pos && pos <= stat->location.end) + if (FFlag::LuauExtendStatEndPosWithSemicolon) { - ancestry.push_back(stat); - return true; + // Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal + // to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case + // (no semicolon) we are still part of the AstStatLocal, hence the different comparison check. + if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end)) + { + ancestry.push_back(stat); + return true; + } } + else + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + } + return false; } @@ -509,6 +526,37 @@ static std::optional checkOverloadedDocumentationSymbol( return documentationSymbol; } +static std::optional getMetatableDocumentation( + const Module& module, + AstExpr* parentExpr, + const TableType* mtable, + const AstName& index +) +{ + auto indexIt = mtable->props.find("__index"); + if (indexIt == mtable->props.end()) + return std::nullopt; + + TypeId followed = follow(indexIt->second.type()); + const TableType* ttv = get(followed); + if (!ttv) + return std::nullopt; + + auto propIt = ttv->props.find(index.value); + if (propIt == ttv->props.end()) + return std::nullopt; + + if (FFlag::LuauSolverV2) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + + return std::nullopt; +} + std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) { std::vector ancestry = findAstAncestryOfPosition(source, position); @@ -541,15 +589,29 @@ std::optional getDocumentationSymbolAtPosition(const Source } else if (const ClassType* ctv = get(parentTy)) { - if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) + while (ctv) { - if (FFlag::LuauSolverV2) + if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) { - if (auto ty = propIt->second.readTy) - return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + if (FFlag::LuauSolverV2) + { + if (auto ty = propIt->second.readTy) + return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol); + } + else + return checkOverloadedDocumentationSymbol( + module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol + ); } - else - return checkOverloadedDocumentationSymbol(module, propIt->second.type(), parentExpr, propIt->second.documentationSymbol); + ctv = ctv->parent ? Luau::get(*ctv->parent) : nullptr; + } + } + else if (const PrimitiveType* ptv = get(parentTy); ptv && ptv->metatable) + { + if (auto mtable = get(*ptv->metatable)) + { + if (std::optional docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index)) + return docSymbol; } } } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index ee865edd..bdfa04bf 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -2,1963 +2,23 @@ #include "Luau/Autocomplete.h" #include "Luau/AstQuery.h" -#include "Luau/BuiltinDefinitions.h" +#include "Luau/TimeTrace.h" +#include "Luau/TypeArena.h" +#include "Luau/Module.h" #include "Luau/Frontend.h" -#include "Luau/ToString.h" -#include "Luau/Subtyping.h" -#include "Luau/TypeInfer.h" -#include "Luau/TypePack.h" -#include -#include -#include +#include "AutocompleteCore.h" -LUAU_FASTFLAG(LuauSolverV2); - -static const std::unordered_set kStatementStartingKeywords = - {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; +LUAU_FASTFLAG(LuauSolverV2) namespace Luau { - -static bool alreadyHasParens(const std::vector& nodes) -{ - auto iter = nodes.rbegin(); - while (iter != nodes.rend() && - ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) - { - iter++; - } - - if (iter == nodes.rend() || iter == nodes.rbegin()) - { - return false; - } - - if (AstExprCall* call = (*iter)->as()) - { - return call->func == *(iter - 1); - } - - return false; -} - -static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) -{ - if (alreadyHasParens(nodes)) - { - return ParenthesesRecommendation::None; - } - - auto idxExpr = nodes.back()->as(); - bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; - auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); - - if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) - return ParenthesesRecommendation::CursorInside; - - bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); - return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; -} - -static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) -{ - ParenthesesRecommendation rec = ParenthesesRecommendation::None; - for (Luau::TypeId partId : intersect->parts) - { - if (auto partFunc = Luau::get(partId)) - { - rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); - } - else - { - return ParenthesesRecommendation::None; - } - } - return rec; -} - -static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) -{ - // If element is already type-correct, even a function should be inserted without parenthesis - if (typeCorrect == TypeCorrectKind::Correct) - return ParenthesesRecommendation::None; - - id = Luau::follow(id); - if (auto func = get(id)) - { - return getParenRecommendationForFunc(func, nodes); - } - else if (auto intersect = get(id)) - { - return getParenRecommendationForIntersect(intersect, nodes); - } - return ParenthesesRecommendation::None; -} - -static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) -{ - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; - - // Extra care for first function call argument location - // When we don't have anything inside () yet, we also don't have an AST node to base our lookup - if (AstExprCall* exprCall = expr->as()) - { - if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) - { - auto it = module.astTypes.find(exprCall->func); - - if (!it) - return std::nullopt; - - const FunctionType* ftv = get(follow(*it)); - - if (!ftv) - return std::nullopt; - - auto [head, tail] = flatten(ftv->argTypes); - unsigned index = exprCall->self ? 1 : 0; - - if (index < head.size()) - return head[index]; - - return std::nullopt; - } - } - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; - - return *it; -} - -static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull scope, TypeArena* typeArena, NotNull builtinTypes) -{ - InternalErrorReporter iceReporter; - UnifierSharedState unifierState(&iceReporter); - Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; - - if (FFlag::LuauSolverV2) - { - Subtyping subtyping{builtinTypes, NotNull{typeArena}, NotNull{&normalizer}, NotNull{&iceReporter}}; - - return subtyping.isSubtype(subTy, superTy, scope).isSubtype; - } - else - { - Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); - - // Cost of normalization can be too high for autocomplete response time requirements - unifier.normalize = false; - unifier.checkInhabited = false; - - return unifier.canUnify(subTy, superTy).empty(); - } -} - -static TypeCorrectKind checkTypeCorrectKind( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - AstNode* node, - Position position, - TypeId ty -) -{ - ty = follow(ty); - - LUAU_ASSERT(module.hasModuleScope()); - - NotNull moduleScope{module.getModuleScope().get()}; - - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return TypeCorrectKind::None; - - TypeId expectedType = follow(*typeAtPosition); - - auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType](const FunctionType* ftv) - { - if (std::optional firstRetTy = first(ftv->retTypes)) - return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); - - return false; - }; - - // We also want to suggest functions that return compatible result - if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - else if (const IntersectionType* itv = get(ty)) - { - for (TypeId id : itv->parts) - { - if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - } - } - - return checkTypeMatch(ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; -} - -enum class PropIndexType -{ - Point, - Colon, - Key, -}; - -static void autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId rootTy, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes, - AutocompleteEntryMap& result, - std::unordered_set& seen, - std::optional containingClass = std::nullopt -) -{ - rootTy = follow(rootTy); - ty = follow(ty); - - if (seen.count(ty)) - return; - seen.insert(ty); - - auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) - { - if (indexType == PropIndexType::Key) - return false; - - bool calledWithSelf = indexType == PropIndexType::Colon; - - auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) - { - // Strong match with definition is a success - if (calledWithSelf == ftv->hasSelf) - return true; - - // Calls on classes require strict match between how function is declared and how it's called - if (get(rootTy)) - return false; - - // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all - // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible - if (std::optional firstArgTy = first(ftv->argTypes)) - { - if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) - return calledWithSelf; - } - - return !calledWithSelf; - }; - - if (const FunctionType* ftv = get(type)) - return !isCompatibleCall(ftv); - - // For intersections, any part that is successful makes the whole call successful - if (const IntersectionType* itv = get(type)) - { - for (auto subType : itv->parts) - { - if (const FunctionType* ftv = get(Luau::follow(subType))) - { - if (isCompatibleCall(ftv)) - return false; - } - } - } - - return calledWithSelf; - }; - - auto fillProps = [&](const ClassType::Props& props) - { - for (const auto& [name, prop] : props) - { - // We are walking up the class hierarchy, so if we encounter a property that we have - // already populated, it takes precedence over the property we found just now. - if (result.count(name) == 0 && name != kParseNameError) - { - Luau::TypeId type; - - if (FFlag::LuauSolverV2) - { - if (auto ty = prop.readTy) - type = follow(*ty); - else - continue; - } - else - type = follow(prop.type()); - - TypeCorrectKind typeCorrect = indexType == PropIndexType::Key - ? TypeCorrectKind::Correct - : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); - - ParenthesesRecommendation parens = - indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - {}, - indexType == PropIndexType::Colon - }; - } - } - }; - - auto fillMetatableProps = [&](const TableType* mtable) - { - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) - { - TypeId followed = follow(indexIt->second.type()); - if (get(followed) || get(followed)) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); - } - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retTypes); - if (indexFunctionResult) - autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); - } - } - }; - - if (auto cls = get(ty)) - { - containingClass = containingClass.value_or(cls); - fillProps(cls->props); - if (cls->parent) - autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); - } - else if (auto tbl = get(ty)) - fillProps(tbl->props); - else if (auto mt = get(ty)) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); - - if (auto mtable = get(follow(mt->metatable))) - fillMetatableProps(mtable); - } - else if (auto i = get(ty)) - { - // Complete all properties in every variant - for (TypeId ty : i->parts) - { - AutocompleteEntryMap inner; - std::unordered_set innerSeen = seen; - - autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); - - for (auto& pair : inner) - result.insert(pair); - } - } - else if (auto u = get(ty)) - { - // Complete all properties common to all variants - auto iter = begin(u); - auto endIter = end(u); - - while (iter != endIter) - { - if (isNil(*iter)) - ++iter; - else - break; - } - - if (iter == endIter) - return; - - autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); - - ++iter; - - while (iter != endIter) - { - AutocompleteEntryMap inner; - std::unordered_set innerSeen; - - if (isNil(*iter)) - { - ++iter; - continue; - } - - autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); - - std::unordered_set toRemove; - - for (const auto& [k, v] : result) - { - (void)v; - if (!inner.count(k)) - toRemove.insert(k); - } - - for (const std::string& k : toRemove) - result.erase(k); - - ++iter; - } - } - else if (auto pt = get(ty)) - { - if (pt->metatable) - { - if (auto mtable = get(*pt->metatable)) - fillMetatableProps(mtable); - } - } - else if (get(get(ty))) - { - autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); - } -} - -static void autocompleteKeywords( - const SourceModule& sourceModule, - const std::vector& ancestry, - Position position, - AutocompleteEntryMap& result -) -{ - LUAU_ASSERT(!ancestry.empty()); - - AstNode* node = ancestry.back(); - - if (!node->is() && node->asExpr()) - { - // This is not strictly correct. We should recommend `and` and `or` only after - // another expression, not at the start of a new one. We should only recommend - // `not` at the start of an expression. Detecting either case reliably is quite - // complex, however; this is good enough for now. - - // These are not context-sensitive keywords, so we can unconditionally assign. - result["and"] = {AutocompleteEntryKind::Keyword}; - result["or"] = {AutocompleteEntryKind::Keyword}; - result["not"] = {AutocompleteEntryKind::Keyword}; - } -} - -static void autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes, - AutocompleteEntryMap& result -) -{ - std::unordered_set seen; - autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); -} - -AutocompleteEntryMap autocompleteProps( - const Module& module, - TypeArena* typeArena, - NotNull builtinTypes, - TypeId ty, - PropIndexType indexType, - const std::vector& nodes -) -{ - AutocompleteEntryMap result; - autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); - return result; -} - -AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position position, std::string_view moduleName) -{ - AutocompleteEntryMap result; - - for (ScopePtr scope = findScopeAtPosition(module, position); scope; scope = scope->parent) - { - if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) - { - for (const auto& [name, ty] : it->second) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; - - break; - } - } - - return result; -} - -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) -{ - if (position == node->location.begin || position == node->location.end) - { - if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) - return; - else if (node->is()) - return; - } - - auto formatKey = [addQuotes](const std::string& key) - { - if (addQuotes) - return "\"" + escape(key) + "\""; - - return escape(key); - }; - - ty = follow(ty); - - if (auto ss = get(get(ty))) - { - result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; - } - else if (auto uty = get(ty)) - { - for (auto el : uty) - { - if (auto ss = get(get(el))) - result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; - } - } -}; - -static bool canSuggestInferredType(ScopePtr scope, TypeId ty) -{ - ty = follow(ty); - - // No point in suggesting 'any', invalid to suggest others - if (get(ty) || get(ty) || get(ty) || get(ty)) - return false; - - // No syntax for unnamed tables with a metatable - if (get(ty)) - return false; - - if (const TableType* ttv = get(ty)) - { - if (ttv->name) - return true; - - if (ttv->syntheticName) - return false; - } - - // We might still have a type with cycles or one that is too long, we'll check that later - return true; -} - -// Walk complex type trees to find the element that is being edited -static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); - -static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) -{ - for (size_t i = 0; i < astTypeList.types.size; i++) - { - AstType* type = astTypeList.types.data[i]; - - if (type->location.containsClosed(position)) - { - auto [head, _] = flatten(tp); - - if (i < head.size()) - return findTypeElementAt(type, head[i], position); - } - } - - if (AstTypePack* argTp = astTypeList.tailType) - { - if (auto variadic = argTp->as()) - { - if (variadic->location.containsClosed(position)) - { - auto [_, tail] = flatten(tp); - - if (tail) - { - if (const VariadicTypePack* vtp = get(follow(*tail))) - return findTypeElementAt(variadic->variadicType, vtp->ty, position); - } - } - } - } - - return {}; -} - -static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) -{ - ty = follow(ty); - - if (astType->is()) - return ty; - - if (astType->is()) - return ty; - - if (AstTypeFunction* type = astType->as()) - { - const FunctionType* ftv = get(ty); - - if (!ftv) - return {}; - - if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) - return element; - - if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) - return element; - } - - // It's possible to walk through other types like intrsection and unions if we find value in doing that - return {}; -} - -std::optional getLocalTypeInScopeAt(const Module& module, Position position, AstLocal* local) -{ - if (ScopePtr scope = findScopeAtPosition(module, position)) - { - for (const auto& [name, binding] : scope->bindings) - { - if (name == local) - return binding.typeId; - } - } - - return {}; -} - -template -static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) -{ - ToStringOptions opts; - opts.useLineBreaks = false; - opts.hideTableKind = true; - opts.functionTypeArguments = functionTypeArguments; - opts.scope = scope; - ToStringResult name = toStringDetailed(ty, opts); - - if (name.error || name.invalid || name.cycle || name.truncated) - return std::nullopt; - - return name.name; -} - -static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) -{ - if (!canSuggestInferredType(scope, ty)) - return std::nullopt; - - return tryToStringDetailed(scope, ty, functionTypeArguments); -} - -static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) -{ - std::optional ty; - - if (topType) - ty = findTypeElementAt(topType, inferredType, position); - else - ty = inferredType; - - if (!ty) - return false; - - if (auto name = tryGetTypeNameInScope(scope, *ty)) - { - if (auto it = result.find(*name); it != result.end()) - it->second.typeCorrect = TypeCorrectKind::Correct; - else - result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; - - return true; - } - - return false; -} - -static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) -{ - auto [tpHead, tpTail] = flatten(tp); - - if (index < tpHead.size()) - return tpHead[index]; - - // Infinite tail - if (tpTail) - { - if (const VariadicTypePack* vtp = get(follow(*tpTail))) - return vtp->ty; - } - - return {}; -} - -template -std::optional returnFirstNonnullOptionOfType(const UnionType* utv) -{ - std::optional ret; - for (TypeId subTy : utv) - { - if (isNil(subTy)) - continue; - - if (const T* ftv = get(follow(subTy))) - { - if (ret.has_value()) - { - return std::nullopt; - } - ret = ftv; - } - else - { - return std::nullopt; - } - } - return ret; -} - -static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) -{ - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return std::nullopt; - - TypeId expectedType = follow(*typeAtPosition); - - if (get(expectedType)) - return true; - - if (const IntersectionType* itv = get(expectedType)) - { - return std::all_of( - begin(itv->parts), - end(itv->parts), - [](auto&& ty) - { - return get(Luau::follow(ty)) != nullptr; - } - ); - } - - if (const UnionType* utv = get(expectedType)) - return returnFirstNonnullOptionOfType(utv).has_value(); - - return false; -} - -AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position position, const std::vector& ancestry) -{ - AutocompleteEntryMap result; - - ScopePtr startScope = findScopeAtPosition(module, position); - - for (ScopePtr scope = startScope; scope; scope = scope->parent) - { - for (const auto& [name, ty] : scope->exportedTypeBindings) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Type, - ty.type, - false, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - ty.type->documentationSymbol - }; - } - - for (const auto& [name, ty] : scope->privateTypeBindings) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Type, - ty.type, - false, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - ty.type->documentationSymbol - }; - } - - for (const auto& [name, _] : scope->importedTypeBindings) - { - if (auto binding = scope->linearSearchForBinding(name, true)) - { - if (!result.count(name)) - result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; - } - } - } - - AstNode* parent = nullptr; - AstType* topType = nullptr; // TODO: rename? - - for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) - { - if (AstType* asType = (*it)->asType()) - { - topType = asType; - } - else - { - parent = *it; - break; - } - } - - if (!parent) - return result; - - if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local - { - // Look at which of the variable types we are defining - for (size_t i = 0; i < node->vars.size; i++) - { - AstLocal* var = node->vars.data[i]; - - if (var->annotation && var->annotation->location.containsClosed(position)) - { - if (node->values.size == 0) - break; - - unsigned tailPos = 0; - - // For multiple return values we will try to unpack last function call return type pack - if (i >= node->values.size) - { - tailPos = int(i) - int(node->values.size) + 1; - i = int(node->values.size) - 1; - } - - AstExpr* expr = node->values.data[i]->asExpr(); - - if (!expr) - break; - - TypeId inferredType = nullptr; - - if (AstExprCall* exprCall = expr->as()) - { - if (auto it = module.astTypes.find(exprCall->func)) - { - if (const FunctionType* ftv = get(follow(*it))) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) - inferredType = *ty; - } - } - } - else - { - if (tailPos != 0) - break; - - if (auto it = module.astTypes.find(expr)) - inferredType = *it; - } - - if (inferredType) - tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); - - break; - } - } - } - else if (AstExprFunction* node = parent->as()) - { - // For lookup inside expected function type if that's available - auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* - { - auto it = module.astExpectedTypes.find(expr); - - if (!it) - return nullptr; - - TypeId ty = follow(*it); - - if (const FunctionType* ftv = get(ty)) - return ftv; - - // Handle optional function type - if (const UnionType* utv = get(ty)) - { - return returnFirstNonnullOptionOfType(utv).value_or(nullptr); - } - - return nullptr; - }; - - // Find which argument type we are defining - for (size_t i = 0; i < node->args.size; i++) - { - AstLocal* arg = node->args.data[i]; - - if (arg->annotation && arg->annotation->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - // Otherwise, try to use the type inferred by typechecker - else if (auto inferredType = getLocalTypeInScopeAt(module, position, arg)) - { - tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); - } - - break; - } - } - - if (AstTypePack* argTp = node->varargAnnotation) - { - if (auto variadic = argTp->as()) - { - if (variadic->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - } - } - } - - if (!node->returnAnnotation) - return result; - - for (size_t i = 0; i < node->returnAnnotation->types.size; i++) - { - AstType* ret = node->returnAnnotation->types.data[i]; - - if (ret->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - - // TODO: with additional type information, we could suggest inferred return type here - break; - } - } - - if (AstTypePack* retTp = node->returnAnnotation->tailType) - { - if (auto variadic = retTp->as()) - { - if (variadic->location.containsClosed(position)) - { - if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) - { - if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) - tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); - } - } - } - } - } - - return result; -} - -static bool isInLocalNames(const std::vector& ancestry, Position position) -{ - for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) - { - if (auto statLocal = (*iter)->as()) - { - for (auto var : statLocal->vars) - { - if (var->location.containsClosed(position)) - { - return true; - } - } - } - else if (auto funcExpr = (*iter)->as()) - { - if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) - { - return true; - } - } - else if (auto localFunc = (*iter)->as()) - { - return localFunc->name->location.containsClosed(position); - } - else if (auto block = (*iter)->as()) - { - if (block->body.size > 0) - { - return false; - } - } - else if ((*iter)->asStat()) - { - return false; - } - } - return false; -} - -static bool isIdentifier(AstNode* node) -{ - return node->is() || node->is(); -} - -static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) -{ - // Current set of rules only check for local binding match - if (!symbol.local) - return false; - - for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) - { - if (auto statLocal = (*iter)->as()) - { - for (auto var : statLocal->vars) - { - if (symbol.local == var) - return true; - } - } - } - - return false; -} - -template -T* extractStat(const std::vector& ancestry) -{ - AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; - if (!node) - return nullptr; - - if (T* t = node->as()) - return t; - - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; - if (!parent) - return nullptr; - - AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; - AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; - - if (!grandParent) - return nullptr; - - if (T* t = parent->as(); t && grandParent->is()) - return t; - - if (!greatGrandParent) - return nullptr; - - if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) - return t; - - return nullptr; -} - -static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) -{ - if (symbol.local) - return binding.location.end < pos; - - // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it - return binding.location == Location() || !binding.location.containsClosed(pos); -} - -static AutocompleteEntryMap autocompleteStatement( - const SourceModule& sourceModule, - const Module& module, - const std::vector& ancestry, - Position position -) -{ - // This is inefficient. :( - ScopePtr scope = findScopeAtPosition(module, position); - - AutocompleteEntryMap result; - - if (isInLocalNames(ancestry, position)) - { - autocompleteKeywords(sourceModule, ancestry, position, result); - return result; - } - - while (scope) - { - for (const auto& [name, binding] : scope->bindings) - { - if (!isBindingLegalAtCurrentPosition(name, binding, position)) - continue; - - std::string n = toString(name); - if (!result.count(n)) - result[n] = { - AutocompleteEntryKind::Binding, - binding.typeId, - binding.deprecated, - false, - TypeCorrectKind::None, - std::nullopt, - std::nullopt, - binding.documentationSymbol, - {}, - getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) - }; - } - - scope = scope->parent; - } - - for (const auto& kw : kStatementStartingKeywords) - result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); - - for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) - { - if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstStatIf* statIf = (*it)->as()) - { - bool hasEnd = statIf->thenbody->hasEnd; - if (statIf->elsebody) - { - if (AstStatBlock* elseBlock = statIf->elsebody->as()) - hasEnd = elseBlock->hasEnd; - } - - if (!hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) - result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - - if (ancestry.size() >= 2) - { - AstNode* parent = ancestry.rbegin()[1]; - if (AstStatIf* statIf = parent->as()) - { - if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) - { - result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - - if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - - if (ancestry.size() >= 4) - { - auto iter = ancestry.rbegin(); - if (AstStatIf* statIf = iter[3]->as(); - statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) - { - result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - } - } - - if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) - result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - - return result; -} - -// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) -static bool autocompleteIfElseExpression( - const AstNode* node, - const std::vector& ancestry, - const Position& position, - AutocompleteEntryMap& outResult -) -{ - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; - if (!parent) - return false; - - if (node->is()) - { - // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else - // expression. - return true; - } - - AstExprIfElse* ifElseExpr = parent->as(); - if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) - { - return false; - } - else if (!ifElseExpr->hasThen) - { - outResult["then"] = {AutocompleteEntryKind::Keyword}; - return true; - } - else if (ifElseExpr->trueExpr->location.containsClosed(position)) - { - return false; - } - else if (!ifElseExpr->hasElse) - { - outResult["else"] = {AutocompleteEntryKind::Keyword}; - outResult["elseif"] = {AutocompleteEntryKind::Keyword}; - return true; - } - else - { - return false; - } -} - -static AutocompleteContext autocompleteExpression( - const SourceModule& sourceModule, - const Module& module, - NotNull builtinTypes, - TypeArena* typeArena, - const std::vector& ancestry, - Position position, - AutocompleteEntryMap& result -) -{ - LUAU_ASSERT(!ancestry.empty()); - - AstNode* node = ancestry.rbegin()[0]; - - if (node->is()) - { - if (auto it = module.astTypes.find(node->asExpr())) - autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); - } - else if (autocompleteIfElseExpression(node, ancestry, position, result)) - return AutocompleteContext::Keyword; - else if (node->is()) - return AutocompleteContext::Unknown; - else - { - // This is inefficient. :( - ScopePtr scope = findScopeAtPosition(module, position); - - while (scope) - { - for (const auto& [name, binding] : scope->bindings) - { - if (!isBindingLegalAtCurrentPosition(name, binding, position)) - continue; - - if (isBeingDefined(ancestry, name)) - continue; - - std::string n = toString(name); - if (!result.count(n)) - { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); - - result[n] = { - AutocompleteEntryKind::Binding, - binding.typeId, - binding.deprecated, - false, - typeCorrect, - std::nullopt, - std::nullopt, - binding.documentationSymbol, - {}, - getParenRecommendation(binding.typeId, ancestry, typeCorrect) - }; - } - } - - scope = scope->parent; - } - - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); - TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); - TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); - TypeCorrectKind correctForFunction = - functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; - result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; - result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; - result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; - result["not"] = {AutocompleteEntryKind::Keyword}; - result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; - - if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, node, position, result); - } - - return AutocompleteContext::Expression; -} - -static AutocompleteResult autocompleteExpression( - const SourceModule& sourceModule, - const Module& module, - NotNull builtinTypes, - TypeArena* typeArena, - const std::vector& ancestry, - Position position -) -{ - AutocompleteEntryMap result; - AutocompleteContext context = autocompleteExpression(sourceModule, module, builtinTypes, typeArena, ancestry, position, result); - return {result, ancestry, context}; -} - -static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) -{ - AstExpr* parentExpr = nullptr; - if (auto indexName = funcExpr->as()) - { - parentExpr = indexName->expr; - } - else if (auto indexExpr = funcExpr->as()) - { - parentExpr = indexExpr->expr; - } - else - { - return std::nullopt; - } - - auto parentIt = module->astTypes.find(parentExpr); - if (!parentIt) - { - return std::nullopt; - } - - Luau::TypeId parentType = Luau::follow(*parentIt); - - if (auto parentClass = Luau::get(parentType)) - { - return parentClass; - } - - if (auto parentUnion = Luau::get(parentType)) - { - return returnFirstNonnullOptionOfType(parentUnion); - } - - return std::nullopt; -} - -static bool stringPartOfInterpString(const AstNode* node, Position position) -{ - const AstExprInterpString* interpString = node->as(); - if (!interpString) - { - return false; - } - - for (const AstExpr* expression : interpString->expressions) - { - if (expression->location.containsClosed(position)) - { - return false; - } - } - - return true; -} - -static bool isSimpleInterpolatedString(const AstNode* node) -{ - const AstExprInterpString* interpString = node->as(); - return interpString != nullptr && interpString->expressions.size == 0; -} - -static std::optional getStringContents(const AstNode* node) -{ - if (const AstExprConstantString* string = node->as()) - { - return std::string(string->value.data, string->value.size); - } - else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) - { - LUAU_ASSERT(interpString->strings.size == 1); - return std::string(interpString->strings.data->data, interpString->strings.data->size); - } - else - { - return std::nullopt; - } -} - -static std::optional autocompleteStringParams( - const SourceModule& sourceModule, - const ModulePtr& module, - const std::vector& nodes, - Position position, - StringCompletionCallback callback -) -{ - if (nodes.size() < 2) - { - return std::nullopt; - } - - if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) - { - return std::nullopt; - } - - if (!nodes.back()->is()) - { - if (nodes.back()->location.end == position || nodes.back()->location.begin == position) - { - return std::nullopt; - } - } - - AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); - if (!candidate) - { - return std::nullopt; - } - - // HACK: All current instances of 'magic string' params are the first parameter of their functions, - // so we encode that here rather than putting a useless member on the FunctionType struct. - if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) - { - return std::nullopt; - } - - auto it = module->astTypes.find(candidate->func); - if (!it) - { - return std::nullopt; - } - - std::optional candidateString = getStringContents(nodes.back()); - - auto performCallback = [&](const FunctionType* funcType) -> std::optional - { - for (const std::string& tag : funcType->tags) - { - if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) - { - return ret; - } - } - return std::nullopt; - }; - - auto followedId = Luau::follow(*it); - if (auto functionType = Luau::get(followedId)) - { - return performCallback(functionType); - } - - if (auto intersect = Luau::get(followedId)) - { - for (TypeId part : intersect->parts) - { - if (auto candidateFunctionType = Luau::get(part)) - { - if (std::optional ret = performCallback(candidateFunctionType)) - { - return ret; - } - } - } - } - - return std::nullopt; -} - -static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) -{ - AutocompleteEntryMap ret; - ret["do"] = {AutocompleteEntryKind::Keyword}; - ret["and"] = {AutocompleteEntryKind::Keyword}; - ret["or"] = {AutocompleteEntryKind::Keyword}; - return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; -} - -static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) -{ - std::string result = "function("; - - auto [args, tail] = Luau::flatten(funcTy.argTypes); - - bool first = true; - // Skip the implicit 'self' argument if call is indexed with ':' - for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) - { - if (!first) - result += ", "; - else - first = false; - - std::string name; - if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) - name = funcTy.argNames[argIdx]->name; - else - name = "a" + std::to_string(argIdx); - - if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) - result += name + ": " + *type; - else - result += name; - } - - if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) - { - if (!first) - result += ", "; - - std::optional varArgType; - if (const VariadicTypePack* pack = get(follow(*tail))) - { - if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) - varArgType = std::move(res); - } - - if (varArgType) - result += "...: " + *varArgType; - else - result += "..."; - } - - result += ")"; - - auto [rets, retTail] = Luau::flatten(funcTy.retTypes); - if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) - { - if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) - { - result += ": "; - bool wrap = totalRetSize != 1; - if (wrap) - result += "("; - result += *returnTypes; - if (wrap) - result += ")"; - } - } - result += " end"; - return result; -} - -static std::optional makeAnonymousAutofilled( - const ModulePtr& module, - Position position, - const AstNode* node, - const std::vector& ancestry -) -{ - const AstExprCall* call = node->as(); - if (!call && ancestry.size() > 1) - call = ancestry[ancestry.size() - 2]->as(); - - if (!call) - return std::nullopt; - - if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) - return std::nullopt; - - TypeId* typeIter = module->astTypes.find(call->func); - if (!typeIter) - return std::nullopt; - - const FunctionType* outerFunction = get(follow(*typeIter)); - if (!outerFunction) - return std::nullopt; - - size_t argument = 0; - for (size_t i = 0; i < call->args.size; ++i) - { - if (call->args.data[i]->location.containsClosed(position)) - { - argument = i; - break; - } - } - - if (call->self) - argument++; - - std::optional argType; - auto [args, tail] = flatten(outerFunction->argTypes); - if (argument < args.size()) - argType = args[argument]; - - if (!argType) - return std::nullopt; - - TypeId followed = follow(*argType); - const FunctionType* type = get(followed); - if (!type) - { - if (const UnionType* unionType = get(followed)) - { - if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) - type = *nonnullFunction; - } - } - - if (!type) - return std::nullopt; - - const ScopePtr scope = findScopeAtPosition(*module, position); - if (!scope) - return std::nullopt; - - AutocompleteEntry entry; - entry.kind = AutocompleteEntryKind::GeneratedFunction; - entry.typeCorrect = TypeCorrectKind::Correct; - entry.type = argType; - entry.insertText = makeAnonymous(scope, *type); - return std::make_optional(std::move(entry)); -} - -static AutocompleteResult autocomplete( - const SourceModule& sourceModule, - const ModulePtr& module, - NotNull builtinTypes, - TypeArena* typeArena, - Scope* globalScope, - Position position, - StringCompletionCallback callback -) -{ - if (isWithinComment(sourceModule, position)) - return {}; - - std::vector ancestry = findAncestryAtPositionForAutocomplete(sourceModule, position); - LUAU_ASSERT(!ancestry.empty()); - AstNode* node = ancestry.back(); - - AstExprConstantNil dummy{Location{}}; - AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; - - // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node - if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) - { - ancestry.pop_back(); - - node = ancestry.back(); - parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; - } - - if (auto indexName = node->as()) - { - auto it = module->astTypes.find(indexName->expr); - if (!it) - return {}; - - TypeId ty = follow(*it); - PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - - return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; - } - else if (auto typeReference = node->as()) - { - if (typeReference->prefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; - else - return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; - } - else if (node->is()) - { - return {autocompleteTypeNames(*module, position, ancestry), ancestry, AutocompleteContext::Type}; - } - else if (AstStatLocal* statLocal = node->as()) - { - if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) - return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; - else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else - return {}; - } - - else if (AstStatFor* statFor = extractStat(ancestry)) - { - if (!statFor->hasDo || position < statFor->doLocation.begin) - { - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - return {}; - } - - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) - { - if (!statForIn->hasIn || position <= statForIn->inLocation.begin) - { - AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; - if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) - { - // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or - // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer - // any suggestions. - return {}; - } - - return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - - if (!statForIn->hasDo || position <= statForIn->doLocation.begin) - { - LUAU_ASSERT(statForIn->values.size > 0); - AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; - - if (lastExpr->location.containsClosed(position)) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (position > lastExpr->location.end) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - return {}; // Not sure what this means - } - } - else if (AstStatForIn* statForIn = extractStat(ancestry)) - { - // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. - // ex "for f in f do" - if (!statForIn->hasDo) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) - { - if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) - { - return autocompleteWhileLoopKeywords(ancestry); - } - - if (!statWhile->hasDo || position < statWhile->doLocation.begin) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - - if (statWhile->hasDo && position > statWhile->doLocation.end) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - } - - else if (AstStatWhile* statWhile = extractStat(ancestry); - (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && - !statWhile->condition->location.containsClosed(position))) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) - { - return { - {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, - ancestry, - AutocompleteContext::Keyword - }; - } - else if (AstStatIf* statIf = parent->as(); statIf && node->is()) - { - if (statIf->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) - return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - else if (AstStatIf* statIf = extractStat(ancestry); statIf && - (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && - (statIf->condition && !statIf->condition->location.containsClosed(position))) - { - AutocompleteEntryMap ret; - ret["then"] = {AutocompleteEntryKind::Keyword}; - ret["and"] = {AutocompleteEntryKind::Keyword}; - ret["or"] = {AutocompleteEntryKind::Keyword}; - return {std::move(ret), ancestry, AutocompleteContext::Keyword}; - } - else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - else if (AstExprTable* exprTable = parent->as(); - exprTable && (node->is() || node->is() || node->is())) - { - for (const auto& [kind, key, value] : exprTable->items) - { - // If item doesn't have a key, maybe the value is actually the key - if (key ? key == node : node->is() && value == node) - { - if (auto it = module->astExpectedTypes.find(exprTable)) - { - auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - - if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); - - if (!key) - { - // If there is "no key," it may be that the user - // intends for the current token to be the key, but - // has yet to type the `=` sign. - // - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); - } - } - - // Remove keys that are already completed - for (const auto& item : exprTable->items) - { - if (!item.key) - continue; - - if (auto stringKey = item.key->as()) - result.erase(std::string(stringKey->value.data, stringKey->value.size)); - } - - // If we know for sure that a key is being written, do not offer general expression suggestions - if (!key) - autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); - - return {result, ancestry, AutocompleteContext::Property}; - } - - break; - } - } - } - else if (AstExprTable* exprTable = node->as()) - { - AutocompleteEntryMap result; - - if (auto it = module->astExpectedTypes.find(exprTable)) - { - result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); - - // If the key type is a union of singleton strings, - // suggest those too. - if (auto ttv = get(follow(*it)); ttv && ttv->indexer) - { - autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); - } - - // Remove keys that are already completed - for (const auto& item : exprTable->items) - { - if (!item.key) - continue; - - if (auto stringKey = item.key->as()) - result.erase(std::string(stringKey->value.data, stringKey->value.size)); - } - } - - // Also offer general expression suggestions - autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position, result); - - return {result, ancestry, AutocompleteContext::Property}; - } - else if (isIdentifier(node) && (parent->is() || parent->is())) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - - if (std::optional ret = autocompleteStringParams(sourceModule, module, ancestry, position, callback)) - { - return {*ret, ancestry, AutocompleteContext::String}; - } - else if (node->is() || isSimpleInterpolatedString(node)) - { - AutocompleteEntryMap result; - - if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, node, position, result); - - if (ancestry.size() >= 2) - { - if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) - { - if (auto it = module->astTypes.find(idxExpr->expr)) - autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); - } - else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) - { - if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) - { - if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, node, position, result); - } - } - } - - return {result, ancestry, AutocompleteContext::String}; - } - else if (stringPartOfInterpString(node, position)) - { - // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we - // can't know what to format to - AutocompleteEntryMap map; - return {map, ancestry, AutocompleteContext::String}; - } - - if (node->is()) - return {}; - - if (node->asExpr()) - { - AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) - ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); - return ret; - } - else if (node->asStat()) - return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; - - return {}; -} - AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { + LUAU_TIMETRACE_SCOPE("Luau::autocomplete", "Autocomplete"); + LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str()); + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; @@ -1980,7 +40,13 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; - return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); + if (isWithinComment(*sourceModule, position)) + return {}; + + std::vector ancestry = findAncestryAtPositionForAutocomplete(*sourceModule, position); + LUAU_ASSERT(!ancestry.empty()); + ScopePtr startScope = findScopeAtPosition(*module, position); + return autocomplete_(module, builtinTypes, &typeArena, ancestry, globalScope, startScope, position, frontend.fileResolver, callback); } } // namespace Luau diff --git a/Analysis/src/AutocompleteCore.cpp b/Analysis/src/AutocompleteCore.cpp new file mode 100644 index 00000000..faabdf47 --- /dev/null +++ b/Analysis/src/AutocompleteCore.cpp @@ -0,0 +1,2057 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "AutocompleteCore.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/AutocompleteTypes.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/FileResolver.h" +#include "Luau/Frontend.h" +#include "Luau/TimeTrace.h" +#include "Luau/ToString.h" +#include "Luau/Subtyping.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" + +#include +#include +#include + +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) + +LUAU_FASTFLAGVARIABLE(LuauAutocompleteRefactorsForIncrementalAutocomplete) + +LUAU_FASTFLAGVARIABLE(LuauAutocompleteUsesModuleForTypeCompatibility) + +static const std::unordered_set kStatementStartingKeywords = + {"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; + +namespace Luau +{ + +static bool alreadyHasParens(const std::vector& nodes) +{ + auto iter = nodes.rbegin(); + while (iter != nodes.rend() && + ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) + { + iter++; + } + + if (iter == nodes.rend() || iter == nodes.rbegin()) + { + return false; + } + + if (AstExprCall* call = (*iter)->as()) + { + return call->func == *(iter - 1); + } + + return false; +} + +static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionType* func, const std::vector& nodes) +{ + if (alreadyHasParens(nodes)) + { + return ParenthesesRecommendation::None; + } + + auto idxExpr = nodes.back()->as(); + bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; + auto [argTypes, argVariadicPack] = Luau::flatten(func->argTypes); + + if (argVariadicPack.has_value() && isVariadic(*argVariadicPack)) + return ParenthesesRecommendation::CursorInside; + + bool noArgFunction = argTypes.empty() || (hasImplicitSelf && argTypes.size() == 1); + return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; +} + +static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionType* intersect, const std::vector& nodes) +{ + ParenthesesRecommendation rec = ParenthesesRecommendation::None; + for (Luau::TypeId partId : intersect->parts) + { + if (auto partFunc = Luau::get(partId)) + { + rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); + } + else + { + return ParenthesesRecommendation::None; + } + } + return rec; +} + +static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) +{ + // If element is already type-correct, even a function should be inserted without parenthesis + if (typeCorrect == TypeCorrectKind::Correct) + return ParenthesesRecommendation::None; + + id = Luau::follow(id); + if (auto func = get(id)) + { + return getParenRecommendationForFunc(func, nodes); + } + else if (auto intersect = get(id)) + { + return getParenRecommendationForIntersect(intersect, nodes); + } + return ParenthesesRecommendation::None; +} + +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionType* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static bool checkTypeMatch( + const Module& module, + TypeId subTy, + TypeId superTy, + NotNull scope, + TypeArena* typeArena, + NotNull builtinTypes +) +{ + InternalErrorReporter iceReporter; + UnifierSharedState unifierState(&iceReporter); + SimplifierPtr simplifier = newSimplifier(NotNull{typeArena}, builtinTypes); + Normalizer normalizer{typeArena, builtinTypes, NotNull{&unifierState}}; + if (FFlag::LuauAutocompleteUsesModuleForTypeCompatibility) + { + if (module.checkedInNewSolver) + { + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&iceReporter}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Subtyping subtyping{ + builtinTypes, + NotNull{typeArena}, + NotNull{simplifier.get()}, + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull{&iceReporter} + }; + + return subtyping.isSubtype(subTy, superTy, scope).isSubtype; + } + else + { + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); + + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + return unifier.canUnify(subTy, superTy).empty(); + } + } + else + { + if (FFlag::LuauSolverV2) + { + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&iceReporter}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime + + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + Subtyping subtyping{ + builtinTypes, + NotNull{typeArena}, + NotNull{simplifier.get()}, + NotNull{&normalizer}, + NotNull{&typeFunctionRuntime}, + NotNull{&iceReporter} + }; + + return subtyping.isSubtype(subTy, superTy, scope).isSubtype; + } + else + { + Unifier unifier(NotNull{&normalizer}, scope, Location(), Variance::Covariant); + + // Cost of normalization can be too high for autocomplete response time requirements + unifier.normalize = false; + unifier.checkInhabited = false; + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = FInt::LuauTypeInferIterationLimit; + + return unifier.canUnify(subTy, superTy).empty(); + } + } +} + +static TypeCorrectKind checkTypeCorrectKind( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + AstNode* node, + Position position, + TypeId ty +) +{ + ty = follow(ty); + + LUAU_ASSERT(module.hasModuleScope()); + + NotNull moduleScope{module.getModuleScope().get()}; + + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return TypeCorrectKind::None; + + TypeId expectedType = follow(*typeAtPosition); + + auto checkFunctionType = [typeArena, builtinTypes, moduleScope, &expectedType, &module](const FunctionType* ftv) + { + if (std::optional firstRetTy = first(ftv->retTypes)) + return checkTypeMatch(module, *firstRetTy, expectedType, moduleScope, typeArena, builtinTypes); + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionType* ftv = get(ty); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + else if (const IntersectionType* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + id = follow(id); + + if (const FunctionType* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } + } + + return checkTypeMatch(module, ty, expectedType, moduleScope, typeArena, builtinTypes) ? TypeCorrectKind::Correct : TypeCorrectKind::None; +} + +enum class PropIndexType +{ + Point, + Colon, + Key, +}; + +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId rootTy, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result, + std::unordered_set& seen, + std::optional containingClass = std::nullopt +) +{ + rootTy = follow(rootTy); + ty = follow(ty); + + if (seen.count(ty)) + return; + seen.insert(ty); + + auto isWrongIndexer = [typeArena, builtinTypes, &module, rootTy, indexType](Luau::TypeId type) + { + if (indexType == PropIndexType::Key) + return false; + + bool calledWithSelf = indexType == PropIndexType::Colon; + + auto isCompatibleCall = [typeArena, builtinTypes, &module, rootTy, calledWithSelf](const FunctionType* ftv) + { + // Strong match with definition is a success + if (calledWithSelf == ftv->hasSelf) + return true; + + // Calls on classes require strict match between how function is declared and how it's called + if (get(rootTy)) + return false; + + // When called with ':', but declared without 'self', it is invalid if a function has incompatible first argument or no arguments at all + // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(module, rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena, builtinTypes)) + return calledWithSelf; + } + + return !calledWithSelf; + }; + + if (const FunctionType* ftv = get(type)) + return !isCompatibleCall(ftv); + + // For intersections, any part that is successful makes the whole call successful + if (const IntersectionType* itv = get(type)) + { + for (auto subType : itv->parts) + { + if (const FunctionType* ftv = get(Luau::follow(subType))) + { + if (isCompatibleCall(ftv)) + return false; + } + } + } + + return calledWithSelf; + }; + + auto fillProps = [&](const ClassType::Props& props) + { + for (const auto& [name, prop] : props) + { + // We are walking up the class hierarchy, so if we encounter a property that we have + // already populated, it takes precedence over the property we found just now. + if (result.count(name) == 0 && name != kParseNameError) + { + Luau::TypeId type; + + if (FFlag::LuauSolverV2) + { + if (auto ty = prop.readTy) + type = follow(*ty); + else + continue; + } + else + type = follow(prop.type()); + + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key + ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, builtinTypes, nodes.back(), {{}, {}}, type); + + ParenthesesRecommendation parens = + indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); + + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens, + {}, + indexType == PropIndexType::Colon + }; + } + } + }; + + auto fillMetatableProps = [&](const TableType* mtable) + { + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + TypeId followed = follow(indexIt->second.type()); + if (get(followed) || get(followed)) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, followed, indexType, nodes, result, seen); + } + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retTypes); + if (indexFunctionResult) + autocompleteProps(module, typeArena, builtinTypes, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } + } + }; + + if (auto cls = get(ty)) + { + containingClass = containingClass.value_or(cls); + fillProps(cls->props); + if (cls->parent) + autocompleteProps(module, typeArena, builtinTypes, rootTy, *cls->parent, indexType, nodes, result, seen, containingClass); + } + else if (auto tbl = get(ty)) + fillProps(tbl->props); + else if (auto mt = get(ty)) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, mt->table, indexType, nodes, result, seen); + + if (auto mtable = get(follow(mt->metatable))) + fillMetatableProps(mtable); + } + else if (auto i = get(ty)) + { + // Complete all properties in every variant + for (TypeId ty : i->parts) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen = seen; + + autocompleteProps(module, typeArena, builtinTypes, rootTy, ty, indexType, nodes, inner, innerSeen); + + for (auto& pair : inner) + result.insert(pair); + } + } + else if (auto u = get(ty)) + { + // Complete all properties common to all variants + auto iter = begin(u); + auto endIter = end(u); + + while (iter != endIter) + { + if (isNil(*iter)) + ++iter; + else + break; + } + + if (iter == endIter) + return; + + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, result, seen); + + ++iter; + + while (iter != endIter) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen; + + if (isNil(*iter)) + { + ++iter; + continue; + } + + autocompleteProps(module, typeArena, builtinTypes, rootTy, *iter, indexType, nodes, inner, innerSeen); + + std::unordered_set toRemove; + + for (const auto& [k, v] : result) + { + (void)v; + if (!inner.count(k)) + toRemove.insert(k); + } + + for (const std::string& k : toRemove) + result.erase(k); + + ++iter; + } + } + else if (auto pt = get(ty)) + { + if (pt->metatable) + { + if (auto mtable = get(*pt->metatable)) + fillMetatableProps(mtable); + } + } + else if (get(get(ty))) + { + autocompleteProps(module, typeArena, builtinTypes, rootTy, builtinTypes->stringType, indexType, nodes, result, seen); + } +} + +static void autocompleteKeywords(const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.back(); + + if (!node->is() && node->asExpr()) + { + // This is not strictly correct. We should recommend `and` and `or` only after + // another expression, not at the start of a new one. We should only recommend + // `not` at the start of an expression. Detecting either case reliably is quite + // complex, however; this is good enough for now. + + // These are not context-sensitive keywords, so we can unconditionally assign. + result["and"] = {AutocompleteEntryKind::Keyword}; + result["or"] = {AutocompleteEntryKind::Keyword}; + result["not"] = {AutocompleteEntryKind::Keyword}; + } +} + +static void autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes, + AutocompleteEntryMap& result +) +{ + std::unordered_set seen; + autocompleteProps(module, typeArena, builtinTypes, ty, ty, indexType, nodes, result, seen); +} + +AutocompleteEntryMap autocompleteProps( + const Module& module, + TypeArena* typeArena, + NotNull builtinTypes, + TypeId ty, + PropIndexType indexType, + const std::vector& nodes +) +{ + AutocompleteEntryMap result; + autocompleteProps(module, typeArena, builtinTypes, ty, indexType, nodes, result); + return result; +} + +AutocompleteEntryMap autocompleteModuleTypes(const Module& module, const ScopePtr& scopeAtPosition, Position position, std::string_view moduleName) +{ + AutocompleteEntryMap result; + ScopePtr startScope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + for (ScopePtr& scope = startScope; scope; scope = scope->parent) + { + if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) + { + for (const auto& [name, ty] : it->second) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; + + break; + } + } + + return result; +} + +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) +{ + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->isQuoted()) + return; + else if (node->is()) + return; + } + + auto formatKey = [addQuotes](const std::string& key) + { + if (addQuotes) + return "\"" + escape(key) + "\""; + + return escape(key); + }; + + ty = follow(ty); + + if (auto ss = get(get(ty))) + { + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + else if (auto uty = get(ty)) + { + for (auto el : uty) + { + if (auto ss = get(get(el))) + result[formatKey(ss->value)] = AutocompleteEntry{AutocompleteEntryKind::String, ty, false, false, TypeCorrectKind::Correct}; + } + } +}; + +static bool canSuggestInferredType(ScopePtr scope, TypeId ty) +{ + ty = follow(ty); + + // No point in suggesting 'any', invalid to suggest others + if (get(ty) || get(ty) || get(ty) || get(ty)) + return false; + + // No syntax for unnamed tables with a metatable + if (get(ty)) + return false; + + if (const TableType* ttv = get(ty)) + { + if (ttv->name) + return true; + + if (ttv->syntheticName) + return false; + } + + // We might still have a type with cycles or one that is too long, we'll check that later + return true; +} + +// Walk complex type trees to find the element that is being edited +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); + +static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) +{ + for (size_t i = 0; i < astTypeList.types.size; i++) + { + AstType* type = astTypeList.types.data[i]; + + if (type->location.containsClosed(position)) + { + auto [head, _] = flatten(tp); + + if (i < head.size()) + return findTypeElementAt(type, head[i], position); + } + } + + if (AstTypePack* argTp = astTypeList.tailType) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + auto [_, tail] = flatten(tp); + + if (tail) + { + if (const VariadicTypePack* vtp = get(follow(*tail))) + return findTypeElementAt(variadic->variadicType, vtp->ty, position); + } + } + } + } + + return {}; +} + +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) +{ + ty = follow(ty); + + if (astType->is()) + return ty; + + if (astType->is()) + return ty; + + if (AstTypeFunction* type = astType->as()) + { + const FunctionType* ftv = get(ty); + + if (!ftv) + return {}; + + if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) + return element; + + if (auto element = findTypeElementAt(type->returnTypes, ftv->retTypes, position)) + return element; + } + + // It's possible to walk through other types like intrsection and unions if we find value in doing that + return {}; +} + +std::optional getLocalTypeInScopeAt(const Module& module, const ScopePtr& scopeAtPosition, Position position, AstLocal* local) +{ + if (ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position)) + { + for (const auto& [name, binding] : scope->bindings) + { + if (name == local) + return binding.typeId; + } + } + + return {}; +} + +template +static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) +{ + ToStringOptions opts; + opts.useLineBreaks = false; + opts.hideTableKind = true; + opts.functionTypeArguments = functionTypeArguments; + opts.scope = scope; + ToStringResult name = toStringDetailed(ty, opts); + + if (name.error || name.invalid || name.cycle || name.truncated) + return std::nullopt; + + return name.name; +} + +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool functionTypeArguments = false) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + return tryToStringDetailed(scope, ty, functionTypeArguments); +} + +static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) +{ + std::optional ty; + + if (topType) + ty = findTypeElementAt(topType, inferredType, position); + else + ty = inferredType; + + if (!ty) + return false; + + if (auto name = tryGetTypeNameInScope(scope, *ty)) + { + if (auto it = result.find(*name); it != result.end()) + it->second.typeCorrect = TypeCorrectKind::Correct; + else + result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; + + return true; + } + + return false; +} + +static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) +{ + auto [tpHead, tpTail] = flatten(tp); + + if (index < tpHead.size()) + return tpHead[index]; + + // Infinite tail + if (tpTail) + { + if (const VariadicTypePack* vtp = get(follow(*tpTail))) + return vtp->ty; + } + + return {}; +} + +template +std::optional returnFirstNonnullOptionOfType(const UnionType* utv) +{ + std::optional ret; + for (TypeId subTy : utv) + { + if (isNil(subTy)) + continue; + + if (const T* ftv = get(follow(subTy))) + { + if (ret.has_value()) + { + return std::nullopt; + } + ret = ftv; + } + else + { + return std::nullopt; + } + } + return ret; +} + +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) +{ + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return std::nullopt; + + TypeId expectedType = follow(*typeAtPosition); + + if (get(expectedType)) + return true; + + if (const IntersectionType* itv = get(expectedType)) + { + return std::all_of( + begin(itv->parts), + end(itv->parts), + [](auto&& ty) + { + return get(Luau::follow(ty)) != nullptr; + } + ); + } + + if (const UnionType* utv = get(expectedType)) + return returnFirstNonnullOptionOfType(utv).has_value(); + + return false; +} + +AutocompleteEntryMap autocompleteTypeNames( + const Module& module, + const ScopePtr& scopeAtPosition, + Position& position, + const std::vector& ancestry +) +{ + AutocompleteEntryMap result; + + ScopePtr startScope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + for (ScopePtr scope = startScope; scope; scope = scope->parent) + { + for (const auto& [name, ty] : scope->exportedTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; + } + + for (const auto& [name, ty] : scope->privateTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Type, + ty.type, + false, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + ty.type->documentationSymbol + }; + } + + for (const auto& [name, _] : scope->importedTypeBindings) + { + if (auto binding = scope->linearSearchForBinding(name, true)) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; + } + } + } + + AstNode* parent = nullptr; + AstType* topType = nullptr; // TODO: rename? + + for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) + { + if (AstType* asType = (*it)->asType()) + { + topType = asType; + } + else + { + parent = *it; + break; + } + } + + if (!parent) + return result; + + if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local + { + // Look at which of the variable types we are defining + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + if (var->annotation && var->annotation->location.containsClosed(position)) + { + if (node->values.size == 0) + break; + + unsigned tailPos = 0; + + // For multiple return values we will try to unpack last function call return type pack + if (i >= node->values.size) + { + tailPos = int(i) - int(node->values.size) + 1; + i = int(node->values.size) - 1; + } + + AstExpr* expr = node->values.data[i]->asExpr(); + + if (!expr) + break; + + TypeId inferredType = nullptr; + + if (AstExprCall* exprCall = expr->as()) + { + if (auto it = module.astTypes.find(exprCall->func)) + { + if (const FunctionType* ftv = get(follow(*it))) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, tailPos)) + inferredType = *ty; + } + } + } + else + { + if (tailPos != 0) + break; + + if (auto it = module.astTypes.find(expr)) + inferredType = *it; + } + + if (inferredType) + tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); + + break; + } + } + } + else if (AstExprFunction* node = parent->as()) + { + // For lookup inside expected function type if that's available + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionType* + { + auto it = module.astExpectedTypes.find(expr); + + if (!it) + return nullptr; + + TypeId ty = follow(*it); + + if (const FunctionType* ftv = get(ty)) + return ftv; + + // Handle optional function type + if (const UnionType* utv = get(ty)) + { + return returnFirstNonnullOptionOfType(utv).value_or(nullptr); + } + + return nullptr; + }; + + // Find which argument type we are defining + for (size_t i = 0; i < node->args.size; i++) + { + AstLocal* arg = node->args.data[i]; + + if (arg->annotation && arg->annotation->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + // Otherwise, try to use the type inferred by typechecker + else if (auto inferredType = getLocalTypeInScopeAt(module, scopeAtPosition, position, arg)) + { + tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); + } + + break; + } + } + + if (AstTypePack* argTp = node->varargAnnotation) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + + if (!node->returnAnnotation) + return result; + + for (size_t i = 0; i < node->returnAnnotation->types.size; i++) + { + AstType* ret = node->returnAnnotation->types.data[i]; + + if (ret->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + + // TODO: with additional type information, we could suggest inferred return type here + break; + } + } + + if (AstTypePack* retTp = node->returnAnnotation->tailType) + { + if (auto variadic = retTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionType* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + } + + return result; +} + +static bool isInLocalNames(const std::vector& ancestry, Position position) +{ + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (var->location.containsClosed(position)) + { + return true; + } + } + } + else if (auto funcExpr = (*iter)->as()) + { + if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) + { + return true; + } + } + else if (auto localFunc = (*iter)->as()) + { + return localFunc->name->location.containsClosed(position); + } + else if (auto block = (*iter)->as()) + { + if (block->body.size > 0) + { + return false; + } + } + else if ((*iter)->asStat()) + { + return false; + } + } + return false; +} + +static bool isIdentifier(AstNode* node) +{ + return node->is() || node->is(); +} + +static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) +{ + // Current set of rules only check for local binding match + if (!symbol.local) + return false; + + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (symbol.local == var) + return true; + } + } + } + + return false; +} + +template +T* extractStat(const std::vector& ancestry) +{ + AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; + if (!node) + return nullptr; + + if (T* t = node->as()) + return t; + + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return nullptr; + + AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; + AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; + + if (!grandParent) + return nullptr; + + if (T* t = parent->as(); t && grandParent->is()) + return t; + + if (!greatGrandParent) + return nullptr; + + if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) + return t; + + return nullptr; +} + +static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos) +{ + if (symbol.local) + return binding.location.end < pos; + + // Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it + return binding.location == Location() || !binding.location.containsClosed(pos); +} + +static AutocompleteEntryMap autocompleteStatement( + const Module& module, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position& position +) +{ + // This is inefficient. :( + ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + AutocompleteEntryMap result; + + if (isInLocalNames(ancestry, position)) + { + autocompleteKeywords(ancestry, position, result); + return result; + } + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(name, binding, position)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + TypeCorrectKind::None, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None) + }; + } + + scope = scope->parent; + } + + for (const auto& kw : kStatementStartingKeywords) + result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) + { + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstStatIf* statIf = (*it)->as()) + { + bool hasEnd = statIf->thenbody->hasEnd; + if (statIf->elsebody) + { + if (AstStatBlock* elseBlock = statIf->elsebody->as()) + hasEnd = elseBlock->hasEnd; + } + + if (!hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->body->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 2) + { + AstNode* parent = ancestry.rbegin()[1]; + if (AstStatIf* statIf = parent->as()) + { + if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 4) + { + auto iter = ancestry.rbegin(); + if (AstStatIf* statIf = iter[3]->as(); + statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->body->hasEnd) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + return result; +} + +// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) +static bool autocompleteIfElseExpression( + const AstNode* node, + const std::vector& ancestry, + const Position& position, + AutocompleteEntryMap& outResult +) +{ + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return false; + + if (node->is()) + { + // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else + // expression. + return true; + } + + AstExprIfElse* ifElseExpr = parent->as(); + if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasThen) + { + outResult["then"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else if (ifElseExpr->trueExpr->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasElse) + { + outResult["else"] = {AutocompleteEntryKind::Keyword}; + outResult["elseif"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else + { + return false; + } +} + +static AutocompleteContext autocompleteExpression( + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position position, + AutocompleteEntryMap& result +) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.rbegin()[0]; + + if (node->is()) + { + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, builtinTypes, *it, PropIndexType::Point, ancestry, result); + } + else if (autocompleteIfElseExpression(node, ancestry, position, result)) + return AutocompleteContext::Keyword; + else if (node->is()) + return AutocompleteContext::Unknown; + else + { + // This is inefficient. :( + ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(module, position); + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(name, binding, position)) + continue; + + if (isBeingDefined(ancestry, name)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + { + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, binding.typeId); + + result[n] = { + AutocompleteEntryKind::Binding, + binding.typeId, + binding.deprecated, + false, + typeCorrect, + std::nullopt, + std::nullopt, + binding.documentationSymbol, + {}, + getParenRecommendation(binding.typeId, ancestry, typeCorrect) + }; + } + } + + scope = scope->parent; + } + + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->nilType); + TypeCorrectKind correctForTrue = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->trueType); + TypeCorrectKind correctForFalse = checkTypeCorrectKind(module, typeArena, builtinTypes, node, position, builtinTypes->falseType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForTrue}; + result["false"] = {AutocompleteEntryKind::Keyword, builtinTypes->booleanType, false, false, correctForFalse}; + result["nil"] = {AutocompleteEntryKind::Keyword, builtinTypes->nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + + if (auto ty = findExpectedTypeAt(module, node, position)) + autocompleteStringSingleton(*ty, true, node, position, result); + } + + return AutocompleteContext::Expression; +} + +static AutocompleteResult autocompleteExpression( + const Module& module, + NotNull builtinTypes, + TypeArena* typeArena, + const std::vector& ancestry, + const ScopePtr& scopeAtPosition, + Position position +) +{ + AutocompleteEntryMap result; + AutocompleteContext context = autocompleteExpression(module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + return {result, ancestry, context}; +} + +static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) +{ + AstExpr* parentExpr = nullptr; + if (auto indexName = funcExpr->as()) + { + parentExpr = indexName->expr; + } + else if (auto indexExpr = funcExpr->as()) + { + parentExpr = indexExpr->expr; + } + else + { + return std::nullopt; + } + + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) + { + return std::nullopt; + } + + Luau::TypeId parentType = Luau::follow(*parentIt); + + if (auto parentClass = Luau::get(parentType)) + { + return parentClass; + } + + if (auto parentUnion = Luau::get(parentType)) + { + return returnFirstNonnullOptionOfType(parentUnion); + } + + return std::nullopt; +} + +static bool stringPartOfInterpString(const AstNode* node, Position position) +{ + const AstExprInterpString* interpString = node->as(); + if (!interpString) + { + return false; + } + + for (const AstExpr* expression : interpString->expressions) + { + if (expression->location.containsClosed(position)) + { + return false; + } + } + + return true; +} + +static bool isSimpleInterpolatedString(const AstNode* node) +{ + const AstExprInterpString* interpString = node->as(); + return interpString != nullptr && interpString->expressions.size == 0; +} + +static std::optional getStringContents(const AstNode* node) +{ + if (const AstExprConstantString* string = node->as()) + { + return std::string(string->value.data, string->value.size); + } + else if (const AstExprInterpString* interpString = node->as(); interpString && interpString->expressions.size == 0) + { + LUAU_ASSERT(interpString->strings.size == 1); + return std::string(interpString->strings.data->data, interpString->strings.data->size); + } + else + { + return std::nullopt; + } +} + +static std::optional convertRequireSuggestionsToAutocompleteEntryMap(std::optional suggestions) +{ + if (!suggestions) + return std::nullopt; + + AutocompleteEntryMap result; + for (const RequireSuggestion& suggestion : *suggestions) + { + AutocompleteEntry entry = {AutocompleteEntryKind::RequirePath}; + entry.insertText = std::move(suggestion.fullPath); + result[std::move(suggestion.label)] = std::move(entry); + } + return result; +} + +static std::optional autocompleteStringParams( + const ModulePtr& module, + const std::vector& nodes, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) +{ + if (nodes.size() < 2) + { + return std::nullopt; + } + + if (!nodes.back()->is() && !isSimpleInterpolatedString(nodes.back()) && !nodes.back()->is()) + { + return std::nullopt; + } + + if (!nodes.back()->is()) + { + if (nodes.back()->location.end == position || nodes.back()->location.begin == position) + { + return std::nullopt; + } + } + + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); + if (!candidate) + { + return std::nullopt; + } + + // HACK: All current instances of 'magic string' params are the first parameter of their functions, + // so we encode that here rather than putting a useless member on the FunctionType struct. + if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) + { + return std::nullopt; + } + + auto it = module->astTypes.find(candidate->func); + if (!it) + { + return std::nullopt; + } + + std::optional candidateString = getStringContents(nodes.back()); + + auto performCallback = [&](const FunctionType* funcType) -> std::optional + { + for (const std::string& tag : funcType->tags) + { + if (tag == kRequireTagName && fileResolver) + { + return convertRequireSuggestionsToAutocompleteEntryMap(fileResolver->getRequireSuggestions(module->name, candidateString)); + } + if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func), candidateString)) + { + return ret; + } + } + return std::nullopt; + }; + + auto followedId = Luau::follow(*it); + if (auto functionType = Luau::get(followedId)) + { + return performCallback(functionType); + } + + if (auto intersect = Luau::get(followedId)) + { + for (TypeId part : intersect->parts) + { + if (auto candidateFunctionType = Luau::get(part)) + { + if (std::optional ret = performCallback(candidateFunctionType)) + { + return ret; + } + } + } + } + + return std::nullopt; +} + +static AutocompleteResult autocompleteWhileLoopKeywords(std::vector ancestry) +{ + AutocompleteEntryMap ret; + ret["do"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), std::move(ancestry), AutocompleteContext::Keyword}; +} + +static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) +{ + std::string result = "function("; + + auto [args, tail] = Luau::flatten(funcTy.argTypes); + + bool first = true; + // Skip the implicit 'self' argument if call is indexed with ':' + for (size_t argIdx = 0; argIdx < args.size(); ++argIdx) + { + if (!first) + result += ", "; + else + first = false; + + std::string name; + if (argIdx < funcTy.argNames.size() && funcTy.argNames[argIdx]) + name = funcTy.argNames[argIdx]->name; + else + name = "a" + std::to_string(argIdx); + + if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) + result += name + ": " + *type; + else + result += name; + } + + if (tail && (Luau::isVariadic(*tail) || Luau::get(Luau::follow(*tail)))) + { + if (!first) + result += ", "; + + std::optional varArgType; + if (const VariadicTypePack* pack = get(follow(*tail))) + { + if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) + varArgType = std::move(res); + } + + if (varArgType) + result += "...: " + *varArgType; + else + result += "..."; + } + + result += ")"; + + auto [rets, retTail] = Luau::flatten(funcTy.retTypes); + if (const size_t totalRetSize = rets.size() + (retTail ? 1 : 0); totalRetSize > 0) + { + if (std::optional returnTypes = tryToStringDetailed(scope, funcTy.retTypes, true)) + { + result += ": "; + bool wrap = totalRetSize != 1; + if (wrap) + result += "("; + result += *returnTypes; + if (wrap) + result += ")"; + } + } + result += " end"; + return result; +} + +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, + const ScopePtr& scopeAtPosition, + Position position, + const AstNode* node, + const std::vector& ancestry +) +{ + const AstExprCall* call = node->as(); + if (!call && ancestry.size() > 1) + call = ancestry[ancestry.size() - 2]->as(); + + if (!call) + return std::nullopt; + + if (!call->location.containsClosed(position) || call->func->location.containsClosed(position)) + return std::nullopt; + + TypeId* typeIter = module->astTypes.find(call->func); + if (!typeIter) + return std::nullopt; + + const FunctionType* outerFunction = get(follow(*typeIter)); + if (!outerFunction) + return std::nullopt; + + size_t argument = 0; + for (size_t i = 0; i < call->args.size; ++i) + { + if (call->args.data[i]->location.containsClosed(position)) + { + argument = i; + break; + } + } + + if (call->self) + argument++; + + std::optional argType; + auto [args, tail] = flatten(outerFunction->argTypes); + if (argument < args.size()) + argType = args[argument]; + + if (!argType) + return std::nullopt; + + TypeId followed = follow(*argType); + const FunctionType* type = get(followed); + if (!type) + { + if (const UnionType* unionType = get(followed)) + { + if (std::optional nonnullFunction = returnFirstNonnullOptionOfType(unionType)) + type = *nonnullFunction; + } + } + + if (!type) + return std::nullopt; + + const ScopePtr scope = FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete ? scopeAtPosition : findScopeAtPosition(*module, position); + if (!scope) + return std::nullopt; + + AutocompleteEntry entry; + entry.kind = AutocompleteEntryKind::GeneratedFunction; + entry.typeCorrect = TypeCorrectKind::Correct; + entry.type = argType; + entry.insertText = makeAnonymous(scope, *type); + return std::make_optional(std::move(entry)); +} + +AutocompleteResult autocomplete_( + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + std::vector& ancestry, + Scope* globalScope, // [TODO] This is unused argument, do we really need this? + const ScopePtr& scopeAtPosition, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +) +{ + LUAU_TIMETRACE_SCOPE("Luau::autocomplete_", "AutocompleteCore"); + AstNode* node = ancestry.back(); + + AstExprConstantNil dummy{Location{}}; + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; + + // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node + if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) + { + ancestry.pop_back(); + + node = ancestry.back(); + parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : &dummy; + } + + if (auto indexName = node->as()) + { + auto it = module->astTypes.find(indexName->expr); + if (!it) + return {}; + + TypeId ty = follow(*it); + PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; + + return {autocompleteProps(*module, typeArena, builtinTypes, ty, indexType, ancestry), ancestry, AutocompleteContext::Property}; + } + else if (auto typeReference = node->as()) + { + if (typeReference->prefix) + return {autocompleteModuleTypes(*module, scopeAtPosition, position, typeReference->prefix->value), ancestry, AutocompleteContext::Type}; + else + return {autocompleteTypeNames(*module, scopeAtPosition, position, ancestry), ancestry, AutocompleteContext::Type}; + } + else if (node->is()) + { + return {autocompleteTypeNames(*module, scopeAtPosition, position, ancestry), ancestry, AutocompleteContext::Type}; + } + else if (AstStatLocal* statLocal = node->as()) + { + if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Unknown}; + else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else + return {}; + } + + else if (AstStatFor* statFor = extractStat(ancestry)) + { + if (!statFor->hasDo || position < statFor->doLocation.begin) + { + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + return {}; + } + + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) + { + if (!statForIn->hasIn || position <= statForIn->inLocation.begin) + { + AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; + if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) + { + // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or + // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer + // any suggestions. + return {}; + } + + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + + if (!statForIn->hasDo || position <= statForIn->doLocation.begin) + { + LUAU_ASSERT(statForIn->values.size > 0); + AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; + + if (lastExpr->location.containsClosed(position)) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (position > lastExpr->location.end) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + + return {}; // Not sure what this means + } + } + else if (AstStatForIn* statForIn = extractStat(ancestry)) + { + // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. + // ex "for f in f do" + if (!statForIn->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) + { + if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) + { + return autocompleteWhileLoopKeywords(ancestry); + } + + if (!statWhile->hasDo || position < statWhile->doLocation.begin) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + + if (statWhile->hasDo && position > statWhile->doLocation.end) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + } + + else if (AstStatWhile* statWhile = extractStat(ancestry); + (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && + !statWhile->condition->location.containsClosed(position))) + { + return autocompleteWhileLoopKeywords(ancestry); + } + else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) + { + return { + {{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + ancestry, + AutocompleteContext::Keyword + }; + } + else if (AstStatIf* statIf = parent->as(); statIf && node->is()) + { + if (statIf->condition->is()) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; + } + else if (AstStatIf* statIf = extractStat(ancestry); statIf && + (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) && + (statIf->condition && !statIf->condition->location.containsClosed(position))) + { + AutocompleteEntryMap ret; + ret["then"] = {AutocompleteEntryKind::Keyword}; + ret["and"] = {AutocompleteEntryKind::Keyword}; + ret["or"] = {AutocompleteEntryKind::Keyword}; + return {std::move(ret), ancestry, AutocompleteContext::Keyword}; + } + else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) + return autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + else if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + else if (AstExprTable* exprTable = parent->as(); + exprTable && (node->is() || node->is() || node->is())) + { + for (const auto& [kind, key, value] : exprTable->items) + { + // If item doesn't have a key, maybe the value is actually the key + if (key ? key == node : node->is() && value == node) + { + if (auto it = module->astExpectedTypes.find(exprTable)) + { + auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); + + if (!key) + { + // If there is "no key," it may be that the user + // intends for the current token to be the key, but + // has yet to type the `=` sign. + // + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + + // If we know for sure that a key is being written, do not offer general expression suggestions + if (!key) + autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } + + break; + } + } + } + else if (AstExprTable* exprTable = node->as()) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(exprTable)) + { + result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); + + // If the key type is a union of singleton strings, + // suggest those too. + if (auto ttv = get(follow(*it)); ttv && ttv->indexer) + { + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); + } + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + } + + // Also offer general expression suggestions + autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position, result); + + return {result, ancestry, AutocompleteContext::Property}; + } + else if (isIdentifier(node) && (parent->is() || parent->is())) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + + if (std::optional ret = autocompleteStringParams(module, ancestry, position, fileResolver, callback)) + { + return {*ret, ancestry, AutocompleteContext::String}; + } + else if (node->is() || isSimpleInterpolatedString(node)) + { + AutocompleteEntryMap result; + + if (auto it = module->astExpectedTypes.find(node->asExpr())) + autocompleteStringSingleton(*it, false, node, position, result); + + if (ancestry.size() >= 2) + { + if (auto idxExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (auto it = module->astTypes.find(idxExpr->expr)) + autocompleteProps(*module, typeArena, builtinTypes, follow(*it), PropIndexType::Point, ancestry, result); + } + else if (auto binExpr = ancestry.at(ancestry.size() - 2)->as()) + { + if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) + { + if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) + autocompleteStringSingleton(*it, false, node, position, result); + } + } + } + + return {result, ancestry, AutocompleteContext::String}; + } + else if (stringPartOfInterpString(node, position)) + { + // We're not a simple interpolated string, we're something like `a{"b"}@1`, and we + // can't know what to format to + AutocompleteEntryMap map; + return {map, ancestry, AutocompleteContext::String}; + } + + if (node->is()) + return {}; + + if (node->asExpr()) + { + AutocompleteResult ret = autocompleteExpression(*module, builtinTypes, typeArena, ancestry, scopeAtPosition, position); + if (std::optional generated = makeAnonymousAutofilled(module, scopeAtPosition, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; + } + else if (node->asStat()) + return {autocompleteStatement(*module, ancestry, scopeAtPosition, position), ancestry, AutocompleteContext::Statement}; + + return {}; +} + + +} // namespace Luau diff --git a/Analysis/src/AutocompleteCore.h b/Analysis/src/AutocompleteCore.h new file mode 100644 index 00000000..d4264da2 --- /dev/null +++ b/Analysis/src/AutocompleteCore.h @@ -0,0 +1,27 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/AutocompleteTypes.h" + +namespace Luau +{ +struct Module; +struct FileResolver; + +using ModulePtr = std::shared_ptr; +using ModuleName = std::string; + + +AutocompleteResult autocomplete_( + const ModulePtr& module, + NotNull builtinTypes, + TypeArena* typeArena, + std::vector& ancestry, + Scope* globalScope, + const ScopePtr& scopeAtPosition, + Position position, + FileResolver* fileResolver, + StringCompletionCallback callback +); + +} // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 21ae0f11..2a93195f 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -2,6 +2,9 @@ #include "Luau/BuiltinDefinitions.h" #include "Luau/Ast.h" +#include "Luau/Clone.h" +#include "Luau/DenseHash.h" +#include "Luau/Error.h" #include "Luau/Frontend.h" #include "Luau/Symbol.h" #include "Luau/Common.h" @@ -25,47 +28,93 @@ * about a function that takes any number of values, but where each value must have some specific type. */ -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAGVARIABLE(LuauDCRMagicFunctionTypeChecker, false); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauStringFormatErrorSuppression) +LUAU_FASTFLAGVARIABLE(LuauTableCloneClonesType3) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAGVARIABLE(LuauFreezeIgnorePersistent) +LUAU_FASTFLAGVARIABLE(LuauFollowTableFreeze) namespace Luau { -static std::optional> magicFunctionSelect( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionSetMetaTable( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionAssert( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionPack( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); -static std::optional> magicFunctionRequire( - TypeChecker& typechecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -); +struct MagicSelect final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; +struct MagicSetMetatable final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; -static bool dcrMagicFunctionSelect(MagicFunctionCallContext context); -static bool dcrMagicFunctionRequire(MagicFunctionCallContext context); -static bool dcrMagicFunctionPack(MagicFunctionCallContext context); +struct MagicAssert final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicPack final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicRequire final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicClone final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFreeze final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFormat final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; + bool typeCheck(const MagicFunctionTypeCheckContext& ctx) override; +}; + +struct MagicMatch final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicGmatch final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; + +struct MagicFind final : MagicFunction +{ + std::optional> + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) override; + bool infer(const MagicFunctionCallContext& ctx) override; +}; TypeId makeUnion(TypeArena& arena, std::vector&& types) { @@ -160,34 +209,10 @@ TypeId makeFunction( return arena.addType(std::move(ftv)); } -void attachMagicFunction(TypeId ty, MagicFunction fn) +void attachMagicFunction(TypeId ty, std::shared_ptr magic) { if (auto ftv = getMutable(ty)) - ftv->magicFunction = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicFunction(TypeId ty, DcrMagicFunction fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicFunction = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicRefinement(TypeId ty, DcrMagicRefinement fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicRefinement = fn; - else - LUAU_ASSERT(!"Got a non functional type"); -} - -void attachDcrMagicFunctionTypeCheck(TypeId ty, DcrMagicFunctionTypeCheck fn) -{ - if (auto ftv = getMutable(ty)) - ftv->dcrMagicTypeCheck = fn; + ftv->magic = std::move(magic); else LUAU_ASSERT(!"Got a non functional type"); } @@ -293,6 +318,28 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC addGlobalBinding(globals, "string", it->second.type(), "@luau"); + // Setup 'vector' metatable + if (auto it = globals.globalScope->exportedTypeBindings.find("vector"); it != globals.globalScope->exportedTypeBindings.end()) + { + TypeId vectorTy = it->second.type; + ClassType* vectorCls = getMutable(vectorTy); + + vectorCls->metatable = arena.addType(TableType{{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + TableType* metatableTy = Luau::getMutable(vectorCls->metatable); + + metatableTy->props["__add"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; + metatableTy->props["__sub"] = {makeFunction(arena, vectorTy, {vectorTy}, {vectorTy})}; + metatableTy->props["__unm"] = {makeFunction(arena, vectorTy, {}, {vectorTy})}; + + std::initializer_list mulOverloads{ + makeFunction(arena, vectorTy, {vectorTy}, {vectorTy}), + makeFunction(arena, vectorTy, {builtinTypes->numberType}, {vectorTy}), + }; + metatableTy->props["__mul"] = {makeIntersection(arena, mulOverloads)}; + metatableTy->props["__div"] = {makeIntersection(arena, mulOverloads)}; + metatableTy->props["__idiv"] = {makeIntersection(arena, mulOverloads)}; + } + // next(t: Table, i: K?) -> (K?, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); @@ -363,7 +410,7 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC } } - attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "assert"), std::make_shared()); if (FFlag::LuauSolverV2) { @@ -379,9 +426,8 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC addGlobalBinding(globals, "assert", assertTy, "@luau"); } - attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), std::make_shared()); + attachMagicFunction(getGlobalBinding(globals, "select"), std::make_shared()); if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { @@ -394,8 +440,10 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC // but it'll be ok for now. TypeId genericTy = arena.addType(GenericType{"T"}); TypePackId thePack = arena.addTypePack({genericTy}); + TypeId idTyWithMagic = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); + ttv->props["freeze"] = makeProperty(idTyWithMagic, "@luau/global/table.freeze"); + TypeId idTy = arena.addType(FunctionType{{genericTy}, {}, thePack, thePack}); - ttv->props["freeze"] = makeProperty(idTy, "@luau/global/table.freeze"); ttv->props["clone"] = makeProperty(idTy, "@luau/global/table.clone"); } else @@ -410,12 +458,15 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC ttv->props["foreach"].deprecated = true; ttv->props["foreachi"].deprecated = true; - attachMagicFunction(ttv->props["pack"].type(), magicFunctionPack); - attachDcrMagicFunction(ttv->props["pack"].type(), dcrMagicFunctionPack); + attachMagicFunction(ttv->props["pack"].type(), std::make_shared()); + if (FFlag::LuauTableCloneClonesType3) + attachMagicFunction(ttv->props["clone"].type(), std::make_shared()); + attachMagicFunction(ttv->props["freeze"].type(), std::make_shared()); } - attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); + TypeId requireTy = getGlobalBinding(globals, "require"); + attachTag(requireTy, kRequireTagName); + attachMagicFunction(requireTy, std::make_shared()); } static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) @@ -454,7 +505,7 @@ static std::vector parseFormatString(NotNull builtinTypes, return result; } -std::optional> magicFunctionFormat( +std::optional> MagicFormat::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -504,7 +555,7 @@ std::optional> magicFunctionFormat( return WithPredicate{arena.addTypePack({typechecker.stringType})}; } -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +bool MagicFormat::infer(const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -548,7 +599,7 @@ static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) return true; } -static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext context) +bool MagicFormat::typeCheck(const MagicFunctionTypeCheckContext& context) { AstExprConstantString* fmt = nullptr; if (auto index = context.callSite->func->as(); index && context.callSite->self) @@ -563,7 +614,10 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex fmt = context.callSite->args.data[0]->as(); if (!fmt) - return; + { + context.typechecker->reportError(CountMismatch{1, std::nullopt, 0, CountMismatch::Arg, true, "string.format"}, context.callSite->location); + return true; + } std::vector expected = parseFormatString(context.builtinTypes, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(context.arguments); @@ -579,12 +633,33 @@ static void dcrMagicFunctionTypeCheckFormat(MagicFunctionTypeCheckContext contex Location location = context.callSite->args.data[i + (calledWithSelf ? 0 : paramOffset)]->location; // use subtyping instead here SubtypingResult result = context.typechecker->subtyping->isSubtype(actualTy, expectedTy, context.checkScope); + if (!result.isSubtype) { - Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); - context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + if (FFlag::LuauStringFormatErrorSuppression) + { + switch (shouldSuppressErrors(NotNull{&context.typechecker->normalizer}, actualTy)) + { + case ErrorSuppression::Suppress: + break; + case ErrorSuppression::NormalizationFailed: + break; + case ErrorSuppression::DoNotSuppress: + Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); + + if (!reasonings.suppressed) + context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + } + } + else + { + Reasonings reasonings = context.typechecker->explainReasonings(actualTy, expectedTy, location, result); + context.typechecker->reportError(TypeMismatch{expectedTy, actualTy, reasonings.toString()}, location); + } } } + + return true; } static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) @@ -647,7 +722,7 @@ static std::vector parsePatternString(NotNull builtinTypes return result; } -static std::optional> magicFunctionGmatch( +std::optional> MagicGmatch::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -683,7 +758,7 @@ static std::optional> magicFunctionGmatch( return WithPredicate{arena.addTypePack({iteratorType})}; } -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +bool MagicGmatch::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -716,7 +791,7 @@ static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) return true; } -static std::optional> magicFunctionMatch( +std::optional> MagicMatch::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -756,7 +831,7 @@ static std::optional> magicFunctionMatch( return WithPredicate{returnList}; } -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +bool MagicMatch::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -792,7 +867,7 @@ static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) return true; } -static std::optional> magicFunctionFind( +std::optional> MagicFind::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -850,7 +925,7 @@ static std::optional> magicFunctionFind( return WithPredicate{returnList}; } -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +bool MagicFind::infer(const MagicFunctionCallContext& context) { const auto& [params, tail] = flatten(context.arguments); @@ -927,12 +1002,9 @@ TypeId makeStringMetatable(NotNull builtinTypes) FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, variadicTailPack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; formatFTV.isCheckedFunction = true; const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - if (FFlag::LuauDCRMagicFunctionTypeChecker) - attachDcrMagicFunctionTypeCheck(formatFn, dcrMagicFunctionTypeCheckFormat); + attachMagicFunction(formatFn, std::make_shared()); const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}, /* checked */ true); @@ -946,16 +1018,14 @@ TypeId makeStringMetatable(NotNull builtinTypes) makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}, /* checked */ false); const TypeId gmatchFunc = makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}, /* checked */ true); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + attachMagicFunction(gmatchFunc, std::make_shared()); FunctionType matchFuncTy{ arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}) }; matchFuncTy.isCheckedFunction = true; const TypeId matchFunc = arena->addType(matchFuncTy); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + attachMagicFunction(matchFunc, std::make_shared()); FunctionType findFuncTy{ arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), @@ -963,8 +1033,7 @@ TypeId makeStringMetatable(NotNull builtinTypes) }; findFuncTy.isCheckedFunction = true; const TypeId findFunc = arena->addType(findFuncTy); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + attachMagicFunction(findFunc, std::make_shared()); // string.byte : string -> number? -> number? -> ...number FunctionType stringDotByte{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList}; @@ -1025,7 +1094,7 @@ TypeId makeStringMetatable(NotNull builtinTypes) return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -static std::optional> magicFunctionSelect( +std::optional> MagicSelect::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1070,7 +1139,7 @@ static std::optional> magicFunctionSelect( return std::nullopt; } -static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) +bool MagicSelect::infer(const MagicFunctionCallContext& context) { if (context.callSite->args.size <= 0) { @@ -1115,7 +1184,7 @@ static bool dcrMagicFunctionSelect(MagicFunctionCallContext context) return false; } -static std::optional> magicFunctionSetMetaTable( +std::optional> MagicSetMetatable::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1197,7 +1266,12 @@ static std::optional> magicFunctionSetMetaTable( return WithPredicate{arena.addTypePack({target})}; } -static std::optional> magicFunctionAssert( +bool MagicSetMetatable::infer(const MagicFunctionCallContext&) +{ + return false; +} + +std::optional> MagicAssert::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1231,7 +1305,12 @@ static std::optional> magicFunctionAssert( return WithPredicate{arena.addTypePack(TypePack{std::move(head), tail})}; } -static std::optional> magicFunctionPack( +bool MagicAssert::infer(const MagicFunctionCallContext&) +{ + return false; +} + +std::optional> MagicPack::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1274,7 +1353,7 @@ static std::optional> magicFunctionPack( return WithPredicate{arena.addTypePack({packedTable})}; } -static bool dcrMagicFunctionPack(MagicFunctionCallContext context) +bool MagicPack::infer(const MagicFunctionCallContext& context) { TypeArena* arena = context.solver->arena; @@ -1314,6 +1393,162 @@ static bool dcrMagicFunctionPack(MagicFunctionCallContext context) return true; } +std::optional> MagicClone::handleOldSolver( + TypeChecker& typechecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate +) +{ + LUAU_ASSERT(FFlag::LuauTableCloneClonesType3); + + auto [paramPack, _predicates] = withPredicate; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + const auto& [paramTypes, paramTail] = flatten(paramPack); + if (paramTypes.empty() || expr.args.size == 0) + { + typechecker.reportError(expr.argLocation, CountMismatch{1, std::nullopt, 0}); + return std::nullopt; + } + + TypeId inputType = follow(paramTypes[0]); + + if (!get(inputType)) + return std::nullopt; + + CloneState cloneState{typechecker.builtinTypes}; + TypeId resultType = shallowClone(inputType, arena, cloneState); + + TypePackId clonedTypePack = arena.addTypePack({resultType}); + return WithPredicate{clonedTypePack}; +} + +bool MagicClone::infer(const MagicFunctionCallContext& context) +{ + LUAU_ASSERT(FFlag::LuauTableCloneClonesType3); + + TypeArena* arena = context.solver->arena; + + const auto& [paramTypes, paramTail] = flatten(context.arguments); + if (paramTypes.empty() || context.callSite->args.size == 0) + { + context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation); + return false; + } + + TypeId inputType = follow(paramTypes[0]); + + if (!get(inputType)) + return false; + + CloneState cloneState{context.solver->builtinTypes}; + TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent); + + if (auto tableType = getMutable(resultType)) + { + tableType->scope = context.constraint->scope.get(); + } + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(context.constraint->scope.get(), resultType); + + TypePackId clonedTypePack = arena->addTypePack({resultType}); + asMutable(context.result)->ty.emplace(clonedTypePack); + + return true; +} + +static std::optional freezeTable(TypeId inputType, const MagicFunctionCallContext& context) +{ + TypeArena* arena = context.solver->arena; + if (FFlag::LuauFollowTableFreeze) + inputType = follow(inputType); + if (auto mt = get(inputType)) + { + std::optional frozenTable = freezeTable(mt->table, context); + + if (!frozenTable) + return std::nullopt; + + TypeId resultType = arena->addType(MetatableType{*frozenTable, mt->metatable, mt->syntheticName}); + + return resultType; + } + + if (get(inputType)) + { + // Clone the input type, this will become our final result type after we mutate it. + CloneState cloneState{context.solver->builtinTypes}; + TypeId resultType = shallowClone(inputType, *arena, cloneState, /* ignorePersistent */ FFlag::LuauFreezeIgnorePersistent); + auto tableTy = getMutable(resultType); + // `clone` should not break this. + LUAU_ASSERT(tableTy); + tableTy->state = TableState::Sealed; + + // We'll mutate the table to make every property type read-only. + for (auto iter = tableTy->props.begin(); iter != tableTy->props.end();) + { + if (iter->second.isWriteOnly()) + iter = tableTy->props.erase(iter); + else + { + iter->second.writeTy = std::nullopt; + iter++; + } + } + + return resultType; + } + + context.solver->reportError(TypeMismatch{context.solver->builtinTypes->tableType, inputType}, context.callSite->argLocation); + return std::nullopt; +} + +std::optional> MagicFreeze:: + handleOldSolver(struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, WithPredicate) +{ + return std::nullopt; +} + +bool MagicFreeze::infer(const MagicFunctionCallContext& context) +{ + TypeArena* arena = context.solver->arena; + const DataFlowGraph* dfg = context.solver->dfg.get(); + Scope* scope = context.constraint->scope.get(); + + const auto& [paramTypes, paramTail] = extendTypePack(*arena, context.solver->builtinTypes, context.arguments, 1); + if (paramTypes.empty() || context.callSite->args.size == 0) + { + context.solver->reportError(CountMismatch{1, std::nullopt, 0}, context.callSite->argLocation); + return false; + } + + TypeId inputType = follow(paramTypes[0]); + + AstExpr* targetExpr = context.callSite->args.data[0]; + std::optional resultDef = dfg->getDefOptional(targetExpr); + std::optional resultTy = resultDef ? scope->lookup(*resultDef) : std::nullopt; + + std::optional frozenType = freezeTable(inputType, context); + + if (!frozenType) + { + if (resultTy) + asMutable(*resultTy)->ty.emplace(context.solver->builtinTypes->errorType); + asMutable(context.result)->ty.emplace(context.solver->builtinTypes->errorTypePack); + + return true; + } + + if (resultTy) + asMutable(*resultTy)->ty.emplace(*frozenType); + asMutable(context.result)->ty.emplace(arena->addTypePack({*frozenType})); + + return true; +} + static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) { // require(foo.parent.bar) will technically work, but it depends on legacy goop that @@ -1336,7 +1571,7 @@ static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) return good; } -static std::optional> magicFunctionRequire( +std::optional> MagicRequire::handleOldSolver( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, @@ -1382,7 +1617,7 @@ static bool checkRequirePathDcr(NotNull solver, AstExpr* expr) return good; } -static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) +bool MagicRequire::infer(const MagicFunctionCallContext& context) { if (context.callSite->args.size != 1) { @@ -1405,4 +1640,52 @@ static bool dcrMagicFunctionRequire(MagicFunctionCallContext context) return false; } +bool matchSetMetatable(const AstExprCall& call) +{ + const char* smt = "setmetatable"; + + if (call.args.size != 2) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != smt) + return false; + + return true; +} + +bool matchTableFreeze(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprIndexName* index = call.func->as(); + if (!index || index->index != "freeze") + return false; + + const AstExprGlobal* global = index->expr->as(); + if (!global || global->name != "table") + return false; + + return true; +} + +bool matchAssert(const AstExprCall& call) +{ + if (call.args.size < 1) + return false; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != "assert") + return false; + + return true; +} + +bool shouldTypestateForFirstArgument(const AstExprCall& call) +{ + // TODO: magic function for setmetatable and assert and then add them + return matchTableFreeze(call); +} + } // namespace Luau diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 4af3e7f8..6309fa7c 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -7,6 +7,7 @@ #include "Luau/Unifiable.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFreezeIgnorePersistent) // For each `Luau::clone` call, we will clone only up to N amount of types _and_ packs, as controlled by this limit. LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) @@ -38,14 +39,26 @@ class TypeCloner NotNull types; NotNull packs; + TypeId forceTy = nullptr; + TypePackId forceTp = nullptr; + int steps = 0; public: - TypeCloner(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + TypeCloner( + NotNull arena, + NotNull builtinTypes, + NotNull types, + NotNull packs, + TypeId forceTy, + TypePackId forceTp + ) : arena(arena) , builtinTypes(builtinTypes) , types(types) , packs(packs) + , forceTy(forceTy) + , forceTp(forceTp) { } @@ -112,7 +125,7 @@ private: ty = follow(ty, FollowOption::DisableLazyTypeThunks); if (auto it = types->find(ty); it != types->end()) return it->second; - else if (ty->persistent) + else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy)) return ty; return std::nullopt; } @@ -122,7 +135,7 @@ private: tp = follow(tp); if (auto it = packs->find(tp); it != packs->end()) return it->second; - else if (tp->persistent) + else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp)) return tp; return std::nullopt; } @@ -140,7 +153,7 @@ private: } } -private: +public: TypeId shallowClone(TypeId ty) { // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. @@ -148,7 +161,7 @@ private: if (auto clone = find(ty)) return *clone; - else if (ty->persistent) + else if (ty->persistent && (!FFlag::LuauFreezeIgnorePersistent || ty != forceTy)) return ty; TypeId target = arena->addType(ty->ty); @@ -174,7 +187,7 @@ private: if (auto clone = find(tp)) return *clone; - else if (tp->persistent) + else if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || tp != forceTp)) return tp; TypePackId target = arena->addTypePack(tp->ty); @@ -189,6 +202,7 @@ private: return target; } +private: Property shallowClone(const Property& p) { if (FFlag::LuauSolverV2) @@ -256,8 +270,7 @@ private: LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?"); } - // ErrorType and ErrorTypePack is an alias to this type. - void cloneChildren(Unifiable::Error* t) + void cloneChildren(ErrorType* t) { // noop. } @@ -359,6 +372,11 @@ private: // noop. } + void cloneChildren(NoRefineType* t) + { + // noop. + } + void cloneChildren(UnionType* t) { for (TypeId& ty : t->options) @@ -422,6 +440,11 @@ private: t->boundTo = shallowClone(t->boundTo); } + void cloneChildren(ErrorTypePack* t) + { + // noop. + } + void cloneChildren(VariadicTypePack* t) { t->ty = shallowClone(t->ty); @@ -448,12 +471,46 @@ private: } // namespace +TypePackId shallowClone(TypePackId tp, TypeArena& dest, CloneState& cloneState, bool ignorePersistent) +{ + if (tp->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent)) + return tp; + + TypeCloner cloner{ + NotNull{&dest}, + cloneState.builtinTypes, + NotNull{&cloneState.seenTypes}, + NotNull{&cloneState.seenTypePacks}, + nullptr, + FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? tp : nullptr + }; + + return cloner.shallowClone(tp); +} + +TypeId shallowClone(TypeId typeId, TypeArena& dest, CloneState& cloneState, bool ignorePersistent) +{ + if (typeId->persistent && (!FFlag::LuauFreezeIgnorePersistent || !ignorePersistent)) + return typeId; + + TypeCloner cloner{ + NotNull{&dest}, + cloneState.builtinTypes, + NotNull{&cloneState.seenTypes}, + NotNull{&cloneState.seenTypePacks}, + FFlag::LuauFreezeIgnorePersistent && ignorePersistent ? typeId : nullptr, + nullptr + }; + + return cloner.shallowClone(typeId); +} + TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) { if (tp->persistent) return tp; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; return cloner.clone(tp); } @@ -462,13 +519,13 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; return cloner.clone(typeId); } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; TypeFun copy = typeFun; @@ -493,4 +550,18 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) return copy; } +Binding clone(const Binding& binding, TypeArena& dest, CloneState& cloneState) +{ + TypeCloner cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}, nullptr, nullptr}; + + Binding b; + b.deprecated = binding.deprecated; + b.deprecatedSuggestion = binding.deprecatedSuggestion; + b.documentationSymbol = binding.documentationSymbol; + b.location = binding.location; + b.typeId = cloner.clone(binding.typeId); + + return b; +} + } // namespace Luau diff --git a/Analysis/src/Constraint.cpp b/Analysis/src/Constraint.cpp index a62879fa..e62a3f18 100644 --- a/Analysis/src/Constraint.cpp +++ b/Analysis/src/Constraint.cpp @@ -3,6 +3,8 @@ #include "Luau/Constraint.h" #include "Luau/VisitType.h" +LUAU_FASTFLAG(DebugLuauGreedyGeneralization) + namespace Luau { @@ -46,6 +48,20 @@ struct ReferenceCountInitializer : TypeOnceVisitor // ClassTypes never contain free types. return false; } + + bool visit(TypeId, const TypeFunctionInstanceType&) override + { + // We do not consider reference counted types that are inside a type + // function to be part of the reachable reference counted types. + // Otherwise, code can be constructed in just the right way such + // that two type functions both claim to mutate a free type, which + // prevents either type function from trying to generalize it, so + // we potentially get stuck. + // + // The default behavior here is `true` for "visit the child types" + // of this type, hence: + return false; + } }; bool isReferenceCountedType(const TypeId typ) @@ -97,6 +113,11 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const { rci.traverse(fchc->argsPack); } + else if (auto fcc = get(*this); fcc && FFlag::DebugLuauGreedyGeneralization) + { + rci.traverse(fcc->fn); + rci.traverse(fcc->argsPack); + } else if (auto ptc = get(*this)) { rci.traverse(ptc->freeType); @@ -104,7 +125,8 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const else if (auto hpc = get(*this)) { rci.traverse(hpc->resultType); - // `HasPropConstraints` should not mutate `subjectType`. + if (FFlag::DebugLuauGreedyGeneralization) + rci.traverse(hpc->subjectType); } else if (auto hic = get(*this)) { @@ -132,6 +154,10 @@ DenseHashSet Constraint::getMaybeMutatedFreeTypes() const { rci.traverse(rpc->tp); } + else if (auto tcc = get(*this)) + { + rci.traverse(tcc->exprType); + } return types; } diff --git a/Analysis/src/ConstraintGenerator.cpp b/Analysis/src/ConstraintGenerator.cpp index face6825..962d11fa 100644 --- a/Analysis/src/ConstraintGenerator.cpp +++ b/Analysis/src/ConstraintGenerator.cpp @@ -2,13 +2,15 @@ #include "Luau/ConstraintGenerator.h" #include "Luau/Ast.h" -#include "Luau/Def.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" #include "Luau/Constraint.h" #include "Luau/ControlFlow.h" #include "Luau/DcrLogger.h" +#include "Luau/Def.h" #include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" +#include "Luau/NotNull.h" #include "Luau/RecursionCounter.h" #include "Luau/Refinement.h" #include "Luau/Scope.h" @@ -26,9 +28,18 @@ #include #include -LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTFLAG(DebugLuauLogSolverToJson); -LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTINT(LuauCheckRecursionLimit) +LUAU_FASTFLAG(DebugLuauLogSolverToJson) +LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(DebugLuauGreedyGeneralization) + +LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAGVARIABLE(LuauDeferBidirectionalInferenceForTableAssignment) +LUAU_FASTFLAGVARIABLE(LuauUngeneralizedTypesForRecursiveFunctions) + +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) +LUAU_FASTFLAGVARIABLE(LuauInferLocalTypesInMultipleAssignments) namespace Luau { @@ -53,20 +64,6 @@ static std::optional matchRequire(const AstExprCall& call) return call.args.data[0]; } -static bool matchSetmetatable(const AstExprCall& call) -{ - const char* smt = "setmetatable"; - - if (call.args.size != 2) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != smt) - return false; - - return true; -} - struct TypeGuard { bool isTypeof; @@ -74,13 +71,11 @@ struct TypeGuard std::string type; }; -static std::optional matchTypeGuard(const AstExprBinary* binary) +static std::optional matchTypeGuard(const AstExprBinary::Op op, AstExpr* left, AstExpr* right) { - if (binary->op != AstExprBinary::CompareEq && binary->op != AstExprBinary::CompareNe) + if (op != AstExprBinary::CompareEq && op != AstExprBinary::CompareNe) return std::nullopt; - AstExpr* left = binary->left; - AstExpr* right = binary->right; if (right->is()) std::swap(left, right); @@ -109,18 +104,6 @@ static std::optional matchTypeGuard(const AstExprBinary* binary) }; } -static bool matchAssert(const AstExprCall& call) -{ - if (call.args.size < 1) - return false; - - const AstExprGlobal* funcAsGlobal = call.func->as(); - if (!funcAsGlobal || funcAsGlobal->name != "assert") - return false; - - return true; -} - namespace { @@ -191,6 +174,8 @@ bool hasFreeType(TypeId ty) ConstraintGenerator::ConstraintGenerator( ModulePtr module, NotNull normalizer, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull moduleResolver, NotNull builtinTypes, NotNull ice, @@ -206,6 +191,8 @@ ConstraintGenerator::ConstraintGenerator( , rootScope(nullptr) , dfg(dfg) , normalizer(normalizer) + , simplifier(simplifier) + , typeFunctionRuntime(typeFunctionRuntime) , moduleResolver(moduleResolver) , ice(ice) , globalScope(globalScope) @@ -244,8 +231,17 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) Checkpoint end = checkpoint(this); TypeId result = arena->addType(BlockedType{}); - NotNull genConstraint = - addConstraint(scope, block->location, GeneralizationConstraint{result, moduleFnTy, std::move(interiorTypes.back())}); + NotNull genConstraint = addConstraint( + scope, + block->location, + GeneralizationConstraint{ + result, moduleFnTy, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + } + ); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + scope->interiorFreeTypes = std::move(interiorTypes.back()); + getMutable(result)->setOwner(genConstraint); forEachConstraint( start, @@ -273,7 +269,7 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) d = follow(d); if (d == ty) continue; - domainTy = simplifyUnion(builtinTypes, arena, domainTy, d).result; + domainTy = simplifyUnion(scope, Location{}, domainTy, d); } LUAU_ASSERT(get(ty)); @@ -281,9 +277,52 @@ void ConstraintGenerator::visitModuleRoot(AstStatBlock* block) } } +void ConstraintGenerator::visitFragmentRoot(const ScopePtr& resumeScope, AstStatBlock* block) +{ + // We prepopulate global data in the resumeScope to avoid writing data into the old modules scopes + prepopulateGlobalScopeForFragmentTypecheck(globalScope, resumeScope, block); + // Pre + // We need to pop the interior types, + interiorTypes.emplace_back(); + visitBlockWithoutChildScope(resumeScope, block); + // Post + interiorTypes.pop_back(); + + fillInInferredBindings(resumeScope, block); + + if (logger) + logger->captureGenerationModule(module); + + for (const auto& [ty, domain] : localTypes) + { + // FIXME: This isn't the most efficient thing. + TypeId domainTy = builtinTypes->neverType; + for (TypeId d : domain) + { + d = follow(d); + if (d == ty) + continue; + domainTy = simplifyUnion(resumeScope, resumeScope->location, domainTy, d); + } + + LUAU_ASSERT(get(ty)); + asMutable(ty)->ty.emplace(domainTy); + } +} + + TypeId ConstraintGenerator::freshType(const ScopePtr& scope) { - return Luau::freshType(arena, builtinTypes, scope.get()); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + auto ft = Luau::freshType(arena, builtinTypes, scope.get()); + interiorTypes.back().push_back(ft); + return ft; + } + else + { + return Luau::freshType(arena, builtinTypes, scope.get()); + } } TypePackId ConstraintGenerator::freshTypePack(const ScopePtr& scope) @@ -643,17 +682,10 @@ void ConstraintGenerator::applyRefinements(const ScopePtr& scope, Location locat addConstraint(scope, location, c); } -ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +void ConstraintGenerator::checkAliases(const ScopePtr& scope, AstStatBlock* block) { - RecursionCounter counter{&recursionCount}; - - if (recursionCount >= FInt::LuauCheckRecursionLimit) - { - reportCodeTooComplex(block->location); - return ControlFlow::None; - } - std::unordered_map aliasDefinitionLocations; + std::unordered_map classDefinitionLocations; // In order to enable mutually-recursive type aliases, we need to // populate the type bindings before we actually check any of the @@ -729,22 +761,140 @@ ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& sco quantifiedTypeParams.push_back(genericTy); } - TypeId typeFunctionTy = arena->addType(TypeFunctionInstanceType{ - NotNull{&builtinTypeFunctions().userFunc}, - std::move(typeParams), - {}, - function->name, - function->body, - }); + if (std::optional error = typeFunctionRuntime->registerFunction(function)) + reportError(function->location, GenericError{*error}); + + UserDefinedFunctionData udtfData; + + udtfData.owner = module; + udtfData.definition = function; + + TypeId typeFunctionTy = arena->addType( + TypeFunctionInstanceType{NotNull{&builtinTypeFunctions().userFunc}, std::move(typeParams), {}, function->name, udtfData} + ); TypeFun typeFunction{std::move(quantifiedTypeParams), typeFunctionTy}; // Set type bindings and definition locations for this user-defined type function - scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + if (function->exported) + scope->exportedTypeBindings[function->name.value] = std::move(typeFunction); + else + scope->privateTypeBindings[function->name.value] = std::move(typeFunction); + aliasDefinitionLocations[function->name.value] = function->location; } + else if (auto classDeclaration = stat->as()) + { + if (scope->exportedTypeBindings.count(classDeclaration->name.value)) + { + auto it = classDefinitionLocations.find(classDeclaration->name.value); + LUAU_ASSERT(it != classDefinitionLocations.end()); + reportError(classDeclaration->location, DuplicateTypeDefinition{classDeclaration->name.value, it->second}); + continue; + } + + // A class might have no name if the code is syntactically + // illegal. We mustn't prepopulate anything in this case. + if (classDeclaration->name == kParseNameError) + continue; + + ScopePtr defnScope = childScope(classDeclaration, scope); + + TypeId initialType = arena->addType(BlockedType{}); + TypeFun initialFun{initialType}; + scope->exportedTypeBindings[classDeclaration->name.value] = std::move(initialFun); + + classDefinitionLocations[classDeclaration->name.value] = classDeclaration->location; + } } + // Additional pass for user-defined type functions to fill in their environments completely + for (AstStat* stat : block->body) + { + if (auto function = stat->as()) + { + // Find the type function we have already created + TypeFunctionInstanceType* mainTypeFun = nullptr; + + if (auto it = scope->privateTypeBindings.find(function->name.value); it != scope->privateTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + + if (!mainTypeFun) + { + if (auto it = scope->exportedTypeBindings.find(function->name.value); it != scope->exportedTypeBindings.end()) + mainTypeFun = getMutable(it->second.type); + } + + // Fill it with all visible type functions + if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + size_t level = 0; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + { + for (auto& [name, tf] : curr->privateTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level); + } + + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment[name] = std::make_pair(ty->userFuncData.definition, level); + } + + level++; + } + } + else if (mainTypeFun) + { + UserDefinedFunctionData& userFuncData = mainTypeFun->userFuncData; + + for (Scope* curr = scope.get(); curr; curr = curr->parent.get()) + { + for (auto& [name, tf] : curr->privateTypeBindings) + { + if (userFuncData.environment_DEPRECATED.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition; + } + + for (auto& [name, tf] : curr->exportedTypeBindings) + { + if (userFuncData.environment_DEPRECATED.find(name)) + continue; + + if (auto ty = get(tf.type); ty && ty->userFuncData.definition) + userFuncData.environment_DEPRECATED[name] = ty->userFuncData.definition; + } + } + } + } + } +} + +ControlFlow ConstraintGenerator::visitBlockWithoutChildScope(const ScopePtr& scope, AstStatBlock* block) +{ + RecursionCounter counter{&recursionCount}; + + if (recursionCount >= FInt::LuauCheckRecursionLimit) + { + reportCodeTooComplex(block->location); + return ControlFlow::None; + } + + checkAliases(scope, block); + std::optional firstControlFlow; for (AstStat* stat : block->body) { @@ -873,37 +1023,49 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat TypePackId rvaluePack = checkPack(scope, statLocal->values, expectedTypes).tp; Checkpoint end = checkpoint(this); - if (hasAnnotation) + if (FFlag::LuauInferLocalTypesInMultipleAssignments) { + std::vector deferredTypes; + auto [head, tail] = flatten(rvaluePack); + for (size_t i = 0; i < statLocal->vars.size; ++i) { LUAU_ASSERT(get(assignees[i])); TypeIds* localDomain = localTypes.find(assignees[i]); LUAU_ASSERT(localDomain); - localDomain->insert(annotatedTypes[i]); + + if (statLocal->vars.data[i]->annotation) + { + localDomain->insert(annotatedTypes[i]); + } + else + { + if (i < head.size()) + { + localDomain->insert(head[i]); + } + else if (tail) + { + deferredTypes.push_back(arena->addType(BlockedType{})); + localDomain->insert(deferredTypes.back()); + } + else + { + localDomain->insert(builtinTypes->nilType); + } + } } - TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); - addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); - } - else - { - std::vector valueTypes; - valueTypes.reserve(statLocal->vars.size); - - auto [head, tail] = flatten(rvaluePack); - - if (head.size() >= statLocal->vars.size) + if (hasAnnotation) { - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(head[i]); + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); + addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); } - else - { - for (size_t i = 0; i < statLocal->vars.size; ++i) - valueTypes.push_back(arena->addType(BlockedType{})); - auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + if (!deferredTypes.empty()) + { + LUAU_ASSERT(tail); + NotNull uc = addConstraint(scope, statLocal->location, UnpackConstraint{deferredTypes, *tail}); forEachConstraint( start, @@ -911,20 +1073,69 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat this, [&uc](const ConstraintPtr& runBefore) { - uc->dependencies.push_back(NotNull{runBefore.get()}); + uc->dependencies.emplace_back(runBefore.get()); } ); - for (TypeId t : valueTypes) + for (TypeId t : deferredTypes) getMutable(t)->setOwner(uc); } - - for (size_t i = 0; i < statLocal->vars.size; ++i) + } + else + { + if (hasAnnotation) { - LUAU_ASSERT(get(assignees[i])); - TypeIds* localDomain = localTypes.find(assignees[i]); - LUAU_ASSERT(localDomain); - localDomain->insert(valueTypes[i]); + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(annotatedTypes[i]); + } + + TypePackId annotatedPack = arena->addTypePack(std::move(annotatedTypes)); + addConstraint(scope, statLocal->location, PackSubtypeConstraint{rvaluePack, annotatedPack}); + } + else + { + std::vector valueTypes; + valueTypes.reserve(statLocal->vars.size); + + auto [head, tail] = flatten(rvaluePack); + + if (head.size() >= statLocal->vars.size) + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(head[i]); + } + else + { + for (size_t i = 0; i < statLocal->vars.size; ++i) + valueTypes.push_back(arena->addType(BlockedType{})); + + auto uc = addConstraint(scope, statLocal->location, UnpackConstraint{valueTypes, rvaluePack}); + + forEachConstraint( + start, + end, + this, + [&uc](const ConstraintPtr& runBefore) + { + uc->dependencies.push_back(NotNull{runBefore.get()}); + } + ); + + for (TypeId t : valueTypes) + getMutable(t)->setOwner(uc); + } + + for (size_t i = 0; i < statLocal->vars.size; ++i) + { + LUAU_ASSERT(get(assignees[i])); + TypeIds* localDomain = localTypes.find(assignees[i]); + LUAU_ASSERT(localDomain); + localDomain->insert(valueTypes[i]); + } } } @@ -937,10 +1148,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatLocal* stat addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); else if (const AstExprCall* call = value->as()) { - if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") - { + if (matchSetMetatable(*call)) addConstraint(scope, value->location, NameConstraint{*firstValueType, var->name.value, /*synthetic*/ true}); - } } } @@ -1172,6 +1381,28 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f FunctionSignature sig = checkFunctionSignature(scope, function->func, /* expectedType */ std::nullopt, function->name->location); bool sigFullyDefined = !hasFreeType(sig.signature); + DefId def = dfg->getDef(function->name); + + if (FFlag::LuauUngeneralizedTypesForRecursiveFunctions) + { + if (AstExprLocal* localName = function->name->as()) + { + sig.bodyScope->bindings[localName->local] = Binding{sig.signature, localName->location}; + sig.bodyScope->lvalueTypes[def] = sig.signature; + sig.bodyScope->rvalueRefinements[def] = sig.signature; + } + else if (AstExprGlobal* globalName = function->name->as()) + { + sig.bodyScope->bindings[globalName->name] = Binding{sig.signature, globalName->location}; + sig.bodyScope->lvalueTypes[def] = sig.signature; + sig.bodyScope->rvalueRefinements[def] = sig.signature; + } + else if (AstExprIndexName* indexName = function->name->as()) + { + sig.bodyScope->rvalueRefinements[def] = sig.signature; + } + } + checkFunctionBody(sig.bodyScope, function->func); Checkpoint end = checkpoint(this); @@ -1205,7 +1436,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatFunction* f ); } - DefId def = dfg->getDef(function->name); std::optional existingFunctionTy = follow(lookup(scope, function->name->location, def)); if (AstExprLocal* localName = function->name->as()) @@ -1333,8 +1563,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatAssign* ass ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatCompoundAssign* assign) { - AstExprBinary binop = AstExprBinary{assign->location, assign->op, assign->var, assign->value}; - TypeId resultTy = check(scope, &binop).ty; + TypeId resultTy = checkAstExprBinary(scope, assign->location, assign->op, assign->var, assign->value, std::nullopt).ty; module->astCompoundAssignResultTypes[assign] = resultTy; TypeId lhsType = check(scope, assign->var).ty; @@ -1448,20 +1677,6 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeAlias* ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatTypeFunction* function) { - // If a type function with the same name was already defined, we skip over - auto bindingIt = scope->privateTypeBindings.find(function->name.value); - if (bindingIt == scope->privateTypeBindings.end()) - return ControlFlow::None; - - TypeFun typeFunction = bindingIt->second; - - // Adding typeAliasExpansionConstraint on user-defined type function for the constraint solver - if (auto typeFunctionTy = get(typeFunction.type)) - { - TypeId expansionTy = arena->addType(PendingExpansionType{{}, function->name, typeFunctionTy->typeArguments, typeFunctionTy->packArguments}); - addConstraint(scope, function->location, TypeAliasExpansionConstraint{/* target */ expansionTy}); - } - return ControlFlow::None; } @@ -1492,6 +1707,11 @@ static bool isMetamethod(const Name& name) ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) { + // If a class with the same name was already defined, we skip over + auto bindingIt = scope->exportedTypeBindings.find(declaredClass->name.value); + if (bindingIt == scope->exportedTypeBindings.end()) + return ControlFlow::None; + std::optional superTy = std::make_optional(builtinTypes->classType); if (declaredClass->superName) { @@ -1506,7 +1726,7 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas // We don't have generic classes, so this assertion _should_ never be hit. LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); - superTy = lookupType->type; + superTy = follow(lookupType->type); if (!get(follow(*superTy))) { @@ -1529,7 +1749,8 @@ ControlFlow ConstraintGenerator::visit(const ScopePtr& scope, AstStatDeclareClas ctv->metatable = metaTy; - scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + TypeId classBindTy = bindingIt->second.type; + emplaceType(asMutable(classBindTy), classTy); if (declaredClass->indexer) { @@ -1833,7 +2054,7 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* Checkpoint argEndCheckpoint = checkpoint(this); - if (matchSetmetatable(*call)) + if (matchSetMetatable(*call)) { TypePack argTailPack; if (argTail && args.size() < 2) @@ -1908,72 +2129,80 @@ InferencePack ConstraintGenerator::checkPack(const ScopePtr& scope, AstExprCall* return InferencePack{arena->addTypePack({resultTy}), {refinementArena.variadic(returnRefinements)}}; } - else + + if (shouldTypestateForFirstArgument(*call) && call->args.size > 0 && isLValue(call->args.data[0])) { - if (matchAssert(*call) && !argumentRefinements.empty()) - applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); + AstExpr* targetExpr = call->args.data[0]; + auto resultTy = arena->addType(BlockedType{}); - // TODO: How do expectedTypes play into this? Do they? - TypePackId rets = arena->addTypePack(BlockedTypePack{}); - TypePackId argPack = addTypePack(std::move(args), argTail); - FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); - - /* - * To make bidirectional type checking work, we need to solve these constraints in a particular order: - * - * 1. Solve the function type - * 2. Propagate type information from the function type to the argument types - * 3. Solve the argument types - * 4. Solve the call - */ - - NotNull checkConstraint = addConstraint( - scope, - call->func->location, - FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}} - ); - - forEachConstraint( - funcBeginCheckpoint, - funcEndCheckpoint, - this, - [checkConstraint](const ConstraintPtr& constraint) - { - checkConstraint->dependencies.emplace_back(constraint.get()); - } - ); - - NotNull callConstraint = addConstraint( - scope, - call->func->location, - FunctionCallConstraint{ - fnType, - argPack, - rets, - call, - std::move(discriminantTypes), - &module->astOverloadResolvedTypes, - } - ); - - getMutable(rets)->owner = callConstraint.get(); - - callConstraint->dependencies.push_back(checkConstraint); - - forEachConstraint( - argBeginCheckpoint, - argEndCheckpoint, - this, - [checkConstraint, callConstraint](const ConstraintPtr& constraint) - { - constraint->dependencies.emplace_back(checkConstraint); - - callConstraint->dependencies.emplace_back(constraint.get()); - } - ); - - return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; + if (auto def = dfg->getDefOptional(targetExpr)) + { + scope->lvalueTypes[*def] = resultTy; + scope->rvalueRefinements[*def] = resultTy; + } } + + if (matchAssert(*call) && !argumentRefinements.empty()) + applyRefinements(scope, call->args.data[0]->location, argumentRefinements[0]); + + // TODO: How do expectedTypes play into this? Do they? + TypePackId rets = arena->addTypePack(BlockedTypePack{}); + TypePackId argPack = addTypePack(std::move(args), argTail); + FunctionType ftv(TypeLevel{}, scope.get(), argPack, rets, std::nullopt, call->self); + + /* + * To make bidirectional type checking work, we need to solve these constraints in a particular order: + * + * 1. Solve the function type + * 2. Propagate type information from the function type to the argument types + * 3. Solve the argument types + * 4. Solve the call + */ + + NotNull checkConstraint = addConstraint( + scope, call->func->location, FunctionCheckConstraint{fnType, argPack, call, NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}} + ); + + forEachConstraint( + funcBeginCheckpoint, + funcEndCheckpoint, + this, + [checkConstraint](const ConstraintPtr& constraint) + { + checkConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + NotNull callConstraint = addConstraint( + scope, + call->func->location, + FunctionCallConstraint{ + fnType, + argPack, + rets, + call, + std::move(discriminantTypes), + &module->astOverloadResolvedTypes, + } + ); + + getMutable(rets)->owner = callConstraint.get(); + + callConstraint->dependencies.push_back(checkConstraint); + + forEachConstraint( + argBeginCheckpoint, + argEndCheckpoint, + this, + [checkConstraint, callConstraint](const ConstraintPtr& constraint) + { + constraint->dependencies.emplace_back(checkConstraint); + + callConstraint->dependencies.emplace_back(constraint.get()); + } + ); + + return InferencePack{rets, {refinementArena.variadic(returnRefinements)}}; } Inference ConstraintGenerator::check(const ScopePtr& scope, AstExpr* expr, std::optional expectedType, bool forceSingleton, bool generalize) @@ -2050,7 +2279,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantStrin if (forceSingleton) return Inference{arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}})}; - FreeType ft = FreeType{scope.get()}; + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; ft.lowerBound = arena->addType(SingletonType{StringSingleton{std::string{string->value.data, string->value.size}}}); ft.upperBound = builtinTypes->stringType; const TypeId freeTy = arena->addType(ft); @@ -2064,7 +2294,8 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprConstantBool* if (forceSingleton) return Inference{singletonType}; - FreeType ft = FreeType{scope.get()}; + FreeType ft = + FFlag::LuauFreeTypesMustHaveBounds ? FreeType{scope.get(), builtinTypes->neverType, builtinTypes->unknownType} : FreeType{scope.get()}; ft.lowerBound = singletonType; ft.upperBound = builtinTypes->booleanType; const TypeId freeTy = arena->addType(ft); @@ -2226,8 +2457,17 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprFunction* fun Checkpoint endCheckpoint = checkpoint(this); TypeId generalizedTy = arena->addType(BlockedType{}); - NotNull gc = - addConstraint(sig.signatureScope, func->location, GeneralizationConstraint{generalizedTy, sig.signature, std::move(interiorTypes.back())}); + NotNull gc = addConstraint( + sig.signatureScope, + func->location, + GeneralizationConstraint{ + generalizedTy, sig.signature, FFlag::LuauTrackInteriorFreeTypesOnScope ? std::vector{} : std::move(interiorTypes.back()) + } + ); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + sig.signatureScope->interiorFreeTypes = std::move(interiorTypes.back()); + getMutable(generalizedTy)->setOwner(gc); interiorTypes.pop_back(); @@ -2288,63 +2528,75 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprUnary* unary) Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binary, std::optional expectedType) { - auto [leftType, rightType, refinement] = checkBinary(scope, binary, expectedType); + return checkAstExprBinary(scope, binary->location, binary->op, binary->left, binary->right, expectedType); +} - switch (binary->op) +Inference ConstraintGenerator::checkAstExprBinary( + const ScopePtr& scope, + const Location& location, + AstExprBinary::Op op, + AstExpr* left, + AstExpr* right, + std::optional expectedType +) +{ + auto [leftType, rightType, refinement] = checkBinary(scope, op, left, right, expectedType); + + switch (op) { case AstExprBinary::Op::Add: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().addFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().addFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Sub: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().subFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().subFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mul: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().mulFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().mulFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Div: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().divFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().divFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::FloorDiv: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().idivFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().idivFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Pow: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().powFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().powFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Mod: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().modFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().modFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Concat: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().concatFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().concatFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::And: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().andFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().andFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Or: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().orFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().orFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLt: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().ltFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().ltFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGe: @@ -2354,13 +2606,13 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar {rightType, leftType}, // lua decided that `__ge(a, b)` is instead just `__lt(b, a)` {}, scope, - binary->location + location ); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareLe: { - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().leFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().leFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareGt: @@ -2370,15 +2622,15 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar {rightType, leftType}, // lua decided that `__gt(a, b)` is instead just `__le(b, a)` {}, scope, - binary->location + location ); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::CompareEq: case AstExprBinary::Op::CompareNe: { - DefId leftDef = dfg->getDef(binary->left); - DefId rightDef = dfg->getDef(binary->right); + DefId leftDef = dfg->getDef(left); + DefId rightDef = dfg->getDef(right); bool leftSubscripted = containsSubscriptedDefinition(leftDef); bool rightSubscripted = containsSubscriptedDefinition(rightDef); @@ -2387,11 +2639,11 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprBinary* binar // we cannot add nil in this case because then we will blindly accept comparisons that we should not. } else if (leftSubscripted) - leftType = makeUnion(scope, binary->location, leftType, builtinTypes->nilType); + leftType = makeUnion(scope, location, leftType, builtinTypes->nilType); else if (rightSubscripted) - rightType = makeUnion(scope, binary->location, rightType, builtinTypes->nilType); + rightType = makeUnion(scope, location, rightType, builtinTypes->nilType); - TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().eqFunc, {leftType, rightType}, {}, scope, binary->location); + TypeId resultType = createTypeFunctionInstance(builtinTypeFunctions().eqFunc, {leftType, rightType}, {}, scope, location); return Inference{resultType, std::move(refinement)}; } case AstExprBinary::Op::Op__Count: @@ -2437,44 +2689,46 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprInterpString* std::tuple ConstraintGenerator::checkBinary( const ScopePtr& scope, - AstExprBinary* binary, + AstExprBinary::Op op, + AstExpr* left, + AstExpr* right, std::optional expectedType ) { - if (binary->op == AstExprBinary::And) + if (op == AstExprBinary::And) { std::optional relaxedExpectedLhs; if (expectedType) relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); - auto [leftType, leftRefinement] = check(scope, binary->left, relaxedExpectedLhs); + auto [leftType, leftRefinement] = check(scope, left, relaxedExpectedLhs); - ScopePtr rightScope = childScope(binary->right, scope); - applyRefinements(rightScope, binary->right->location, leftRefinement); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + ScopePtr rightScope = childScope(right, scope); + applyRefinements(rightScope, right->location, leftRefinement); + auto [rightType, rightRefinement] = check(rightScope, right, expectedType); return {leftType, rightType, refinementArena.conjunction(leftRefinement, rightRefinement)}; } - else if (binary->op == AstExprBinary::Or) + else if (op == AstExprBinary::Or) { std::optional relaxedExpectedLhs; if (expectedType) relaxedExpectedLhs = arena->addType(UnionType{{builtinTypes->falsyType, *expectedType}}); - auto [leftType, leftRefinement] = check(scope, binary->left, relaxedExpectedLhs); + auto [leftType, leftRefinement] = check(scope, left, relaxedExpectedLhs); - ScopePtr rightScope = childScope(binary->right, scope); - applyRefinements(rightScope, binary->right->location, refinementArena.negation(leftRefinement)); - auto [rightType, rightRefinement] = check(rightScope, binary->right, expectedType); + ScopePtr rightScope = childScope(right, scope); + applyRefinements(rightScope, right->location, refinementArena.negation(leftRefinement)); + auto [rightType, rightRefinement] = check(rightScope, right, expectedType); return {leftType, rightType, refinementArena.disjunction(leftRefinement, rightRefinement)}; } - else if (auto typeguard = matchTypeGuard(binary)) + else if (auto typeguard = matchTypeGuard(op, left, right)) { - TypeId leftType = check(scope, binary->left).ty; - TypeId rightType = check(scope, binary->right).ty; + TypeId leftType = check(scope, left).ty; + TypeId rightType = check(scope, right).ty; const RefinementKey* key = dfg->getRefinementKey(typeguard->target); if (!key) @@ -2511,29 +2765,29 @@ std::tuple ConstraintGenerator::checkBinary( TypeId ty = follow(typeFun->type); // We're only interested in the root class of any classes. - if (auto ctv = get(ty); ctv && ctv->parent == builtinTypes->classType) + if (auto ctv = get(ty); ctv && (ctv->parent == builtinTypes->classType || hasTag(ty, kTypeofRootTag))) discriminantTy = ty; } RefinementId proposition = refinementArena.proposition(key, discriminantTy); - if (binary->op == AstExprBinary::CompareEq) + if (op == AstExprBinary::CompareEq) return {leftType, rightType, proposition}; - else if (binary->op == AstExprBinary::CompareNe) + else if (op == AstExprBinary::CompareNe) return {leftType, rightType, refinementArena.negation(proposition)}; else ice->ice("matchTypeGuard should only return a Some under `==` or `~=`!"); } - else if (binary->op == AstExprBinary::CompareEq || binary->op == AstExprBinary::CompareNe) + else if (op == AstExprBinary::CompareEq || op == AstExprBinary::CompareNe) { // We are checking a binary expression of the form a op b // Just because a op b is epxected to return a bool, doesn't mean a, b are expected to be bools too - TypeId leftType = check(scope, binary->left, {}, true).ty; - TypeId rightType = check(scope, binary->right, {}, true).ty; + TypeId leftType = check(scope, left, {}, true).ty; + TypeId rightType = check(scope, right, {}, true).ty; - RefinementId leftRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->left), rightType); - RefinementId rightRefinement = refinementArena.proposition(dfg->getRefinementKey(binary->right), leftType); + RefinementId leftRefinement = refinementArena.proposition(dfg->getRefinementKey(left), rightType); + RefinementId rightRefinement = refinementArena.proposition(dfg->getRefinementKey(right), leftType); - if (binary->op == AstExprBinary::CompareNe) + if (op == AstExprBinary::CompareNe) { leftRefinement = refinementArena.negation(leftRefinement); rightRefinement = refinementArena.negation(rightRefinement); @@ -2543,8 +2797,8 @@ std::tuple ConstraintGenerator::checkBinary( } else { - TypeId leftType = check(scope, binary->left).ty; - TypeId rightType = check(scope, binary->right).ty; + TypeId leftType = check(scope, left).ty; + TypeId rightType = check(scope, right).ty; return {leftType, rightType, nullptr}; } } @@ -2561,7 +2815,13 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExpr* expr, Type visitLValue(scope, e, rhsType); else if (auto e = expr->as()) { - // Nothing? + // If we end up with some sort of error expression in an lvalue + // position, at least go and check the expressions so that when + // we visit them later, there aren't any invalid assumptions. + for (auto subExpr : e->expressions) + { + check(scope, subExpr); + } } else ice->ice("Unexpected lvalue expression", expr->location); @@ -2593,7 +2853,7 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprLocal* local case ErrorSuppression::DoNotSuppress: break; case ErrorSuppression::Suppress: - ty = simplifyUnion(builtinTypes, arena, *ty, builtinTypes->errorType).result; + ty = simplifyUnion(scope, local->location, *ty, builtinTypes->errorType); break; case ErrorSuppression::NormalizationFailed: reportError(local->local->annotation->location, NormalizationTooComplex{}); @@ -2621,6 +2881,11 @@ void ConstraintGenerator::visitLValue(const ScopePtr& scope, AstExprGlobal* glob DefId def = dfg->getDef(global); rootScope->lvalueTypes[def] = rhsType; + // Sketchy: We're specifically looking for BlockedTypes that were + // initially created by ConstraintGenerator::prepopulateGlobalScope. + if (auto bt = get(follow(*annotatedTy)); bt && !bt->getOwner()) + emplaceType(asMutable(*annotatedTy), rhsType); + addConstraint(scope, global->location, SubtypeConstraint{rhsType, *annotatedTy}); } } @@ -2674,6 +2939,7 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, ttv->state = TableState::Unsealed; ttv->definitionModuleName = module->name; + ttv->definitionLocation = expr->location; ttv->scope = scope.get(); interiorTypes.back().push_back(ty); @@ -2739,11 +3005,47 @@ Inference ConstraintGenerator::check(const ScopePtr& scope, AstExprTable* expr, if (expectedType) { - Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; - std::vector toBlock; - matchLiteralType( - NotNull{&module->astTypes}, NotNull{&module->astExpectedTypes}, builtinTypes, arena, NotNull{&unifier}, *expectedType, ty, expr, toBlock - ); + if (FFlag::LuauDeferBidirectionalInferenceForTableAssignment) + { + addConstraint( + scope, + expr->location, + TableCheckConstraint{ + *expectedType, + ty, + expr, + NotNull{&module->astTypes}, + NotNull{&module->astExpectedTypes}, + } + ); + } + else + { + Unifier2 unifier{arena, builtinTypes, NotNull{scope.get()}, ice}; + std::vector toBlock; + // This logic is incomplete as we want to re-run this + // _after_ blocked types have resolved, but this + // allows us to do some bidirectional inference. + toBlock = findBlockedTypesIn(expr, NotNull{&module->astTypes}); + if (toBlock.empty()) + { + matchLiteralType( + NotNull{&module->astTypes}, + NotNull{&module->astExpectedTypes}, + builtinTypes, + arena, + NotNull{&unifier}, + *expectedType, + ty, + expr, + toBlock + ); + // The visitor we ran prior should ensure that there are no + // blocked types that we would encounter while matching on + // this expression. + LUAU_ASSERT(toBlock.empty()); + } + } } return Inference{ty}; @@ -2934,6 +3236,9 @@ ConstraintGenerator::FunctionSignature ConstraintGenerator::checkFunctionSignatu if (expectedType && get(*expectedType)) bindFreeType(*expectedType, actualFunctionType); + if (FFlag::DebugLuauGreedyGeneralization) + scopeToFunction[signatureScope.get()] = actualFunctionType; + return { /* signature */ actualFunctionType, /* signatureScope */ signatureScope, @@ -2949,225 +3254,262 @@ void ConstraintGenerator::checkFunctionBody(const ScopePtr& scope, AstExprFuncti addConstraint(scope, fn->location, PackSubtypeConstraint{builtinTypes->emptyTypePack, scope->returnType}); } +TypeId ConstraintGenerator::resolveReferenceType( + const ScopePtr& scope, + AstType* ty, + AstTypeReference* ref, + bool inTypeArguments, + bool replaceErrorWithFresh +) +{ + TypeId result = nullptr; + + if (FFlag::DebugLuauMagicTypes) + { + if (ref->name == "_luau_ice") + ice->ice("_luau_ice encountered", ty->location); + else if (ref->name == "_luau_print") + { + if (ref->parameters.size != 1 || !ref->parameters.data[0].type) + { + reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); + module->astResolvedTypes[ty] = builtinTypes->errorRecoveryType(); + return builtinTypes->errorRecoveryType(); + } + else + return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); + } + } + + std::optional alias; + + if (ref->prefix.has_value()) + { + alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); + } + else + { + alias = scope->lookupType(ref->name.value); + } + + if (alias.has_value()) + { + // If the alias is not generic, we don't need to set up a blocked type and an instantiation constraint + if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty() && !ref->hasParameterList) + { + result = alias->type; + } + else + { + std::vector parameters; + std::vector packParameters; + + for (const AstTypeOrPack& p : ref->parameters) + { + // We do not enforce the ordering of types vs. type packs here; + // that is done in the parser. + if (p.type) + { + parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); + } + else if (p.typePack) + { + TypePackId tp = resolveTypePack(scope, p.typePack, /*inTypeArguments*/ true); + + // If we need more regular types, we can use single element type packs to fill those in + if (parameters.size() < alias->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + parameters.push_back(*first(tp)); + else + packParameters.push_back(tp); + } + else + { + // This indicates a parser bug: one of these two pointers + // should be set. + LUAU_ASSERT(false); + } + } + + result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); + + // If we're not in a type argument context, we need to create a constraint that expands this. + // The dispatching of the above constraint will queue up additional constraints for nested + // type function applications. + if (!inTypeArguments) + addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); + } + } + else + { + result = builtinTypes->errorRecoveryType(); + if (replaceErrorWithFresh) + result = freshType(scope); + } + + return result; +} + +TypeId ConstraintGenerator::resolveTableType(const ScopePtr& scope, AstType* ty, AstTypeTable* tab, bool inTypeArguments, bool replaceErrorWithFresh) +{ + TableType::Props props; + std::optional indexer; + + for (const AstTableProp& prop : tab->props) + { + TypeId propTy = resolveType(scope, prop.type, inTypeArguments); + + Property& p = props[prop.name.value]; + p.typeLocation = prop.location; + + switch (prop.access) + { + case AstTableAccess::ReadWrite: + p.readTy = propTy; + p.writeTy = propTy; + break; + case AstTableAccess::Read: + p.readTy = propTy; + break; + case AstTableAccess::Write: + reportError(*prop.accessLocation, GenericError{"write keyword is illegal here"}); + p.readTy = propTy; + p.writeTy = propTy; + break; + default: + ice->ice("Unexpected property access " + std::to_string(int(prop.access))); + break; + } + } + + if (AstTableIndexer* astIndexer = tab->indexer) + { + if (astIndexer->access == AstTableAccess::Read) + reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); + else if (astIndexer->access == AstTableAccess::Write) + reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); + else if (astIndexer->access == AstTableAccess::ReadWrite) + { + indexer = TableIndexer{ + resolveType(scope, astIndexer->indexType, inTypeArguments), + resolveType(scope, astIndexer->resultType, inTypeArguments), + }; + } + else + ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); + } + + TypeId tableTy = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + TableType* ttv = getMutable(tableTy); + + ttv->definitionModuleName = module->name; + ttv->definitionLocation = tab->location; + + return tableTy; +} + +TypeId ConstraintGenerator::resolveFunctionType( + const ScopePtr& scope, + AstType* ty, + AstTypeFunction* fn, + bool inTypeArguments, + bool replaceErrorWithFresh +) +{ + bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; + ScopePtr signatureScope = nullptr; + + std::vector genericTypes; + std::vector genericTypePacks; + + // If we don't have generics, we do not need to generate a child scope + // for the generic bindings to live on. + if (hasGenerics) + { + signatureScope = childScope(fn, scope); + + std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); + std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); + + for (const auto& [name, g] : genericDefinitions) + { + genericTypes.push_back(g.ty); + } + + for (const auto& [name, g] : genericPackDefinitions) + { + genericTypePacks.push_back(g.tp); + } + } + else + { + // To eliminate the need to branch on hasGenerics below, we say that + // the signature scope is the parent scope if we don't have + // generics. + signatureScope = scope; + } + + TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); + TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); + + // TODO: FunctionType needs a pointer to the scope so that we know + // how to quantify/instantiate it. + FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; + ftv.isCheckedFunction = fn->isCheckedFunction(); + + // This replicates the behavior of the appropriate FunctionType + // constructors. + ftv.generics = std::move(genericTypes); + ftv.genericPacks = std::move(genericTypePacks); + + ftv.argNames.reserve(fn->argNames.size); + for (const auto& el : fn->argNames) + { + if (el) + { + const auto& [name, location] = *el; + ftv.argNames.push_back(FunctionArgument{name.value, location}); + } + else + { + ftv.argNames.push_back(std::nullopt); + } + } + + return arena->addType(std::move(ftv)); +} + TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool inTypeArguments, bool replaceErrorWithFresh) { TypeId result = nullptr; if (auto ref = ty->as()) { - if (FFlag::DebugLuauMagicTypes) - { - if (ref->name == "_luau_ice") - ice->ice("_luau_ice encountered", ty->location); - else if (ref->name == "_luau_print") - { - if (ref->parameters.size != 1 || !ref->parameters.data[0].type) - { - reportError(ty->location, GenericError{"_luau_print requires one generic parameter"}); - module->astResolvedTypes[ty] = builtinTypes->errorRecoveryType(); - return builtinTypes->errorRecoveryType(); - } - else - return resolveType(scope, ref->parameters.data[0].type, inTypeArguments); - } - } - - std::optional alias; - - if (ref->prefix.has_value()) - { - alias = scope->lookupImportedType(ref->prefix->value, ref->name.value); - } - else - { - alias = scope->lookupType(ref->name.value); - } - - if (alias.has_value()) - { - // If the alias is not generic, we don't need to set up a blocked - // type and an instantiation constraint. - if (alias.has_value() && alias->typeParams.empty() && alias->typePackParams.empty()) - { - result = alias->type; - } - else - { - std::vector parameters; - std::vector packParameters; - - for (const AstTypeOrPack& p : ref->parameters) - { - // We do not enforce the ordering of types vs. type packs here; - // that is done in the parser. - if (p.type) - { - parameters.push_back(resolveType(scope, p.type, /* inTypeArguments */ true)); - } - else if (p.typePack) - { - TypePackId tp = resolveTypePack(scope, p.typePack, /*inTypeArguments*/ true); - - // If we need more regular types, we can use single element type packs to fill those in - if (parameters.size() < alias->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) - parameters.push_back(*first(tp)); - else - packParameters.push_back(tp); - } - else - { - // This indicates a parser bug: one of these two pointers - // should be set. - LUAU_ASSERT(false); - } - } - - result = arena->addType(PendingExpansionType{ref->prefix, ref->name, parameters, packParameters}); - - // If we're not in a type argument context, we need to create a constraint that expands this. - // The dispatching of the above constraint will queue up additional constraints for nested - // type function applications. - if (!inTypeArguments) - addConstraint(scope, ty->location, TypeAliasExpansionConstraint{/* target */ result}); - } - } - else - { - result = builtinTypes->errorRecoveryType(); - if (replaceErrorWithFresh) - result = freshType(scope); - } + result = resolveReferenceType(scope, ty, ref, inTypeArguments, replaceErrorWithFresh); } else if (auto tab = ty->as()) { - TableType::Props props; - std::optional indexer; - - for (const AstTableProp& prop : tab->props) - { - // TODO: Recursion limit. - TypeId propTy = resolveType(scope, prop.type, inTypeArguments); - - Property& p = props[prop.name.value]; - p.typeLocation = prop.location; - - switch (prop.access) - { - case AstTableAccess::ReadWrite: - p.readTy = propTy; - p.writeTy = propTy; - break; - case AstTableAccess::Read: - p.readTy = propTy; - break; - case AstTableAccess::Write: - reportError(*prop.accessLocation, GenericError{"write keyword is illegal here"}); - p.readTy = propTy; - p.writeTy = propTy; - break; - default: - ice->ice("Unexpected property access " + std::to_string(int(prop.access))); - break; - } - } - - if (AstTableIndexer* astIndexer = tab->indexer) - { - if (astIndexer->access == AstTableAccess::Read) - reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"read keyword is illegal here"}); - else if (astIndexer->access == AstTableAccess::Write) - reportError(astIndexer->accessLocation.value_or(Location{}), GenericError{"write keyword is illegal here"}); - else if (astIndexer->access == AstTableAccess::ReadWrite) - { - // TODO: Recursion limit. - indexer = TableIndexer{ - resolveType(scope, astIndexer->indexType, inTypeArguments), - resolveType(scope, astIndexer->resultType, inTypeArguments), - }; - } - else - ice->ice("Unexpected property access " + std::to_string(int(astIndexer->access))); - } - - result = arena->addType(TableType{props, indexer, scope->level, scope.get(), TableState::Sealed}); + result = resolveTableType(scope, ty, tab, inTypeArguments, replaceErrorWithFresh); } else if (auto fn = ty->as()) { - // TODO: Recursion limit. - bool hasGenerics = fn->generics.size > 0 || fn->genericPacks.size > 0; - ScopePtr signatureScope = nullptr; - - std::vector genericTypes; - std::vector genericTypePacks; - - // If we don't have generics, we do not need to generate a child scope - // for the generic bindings to live on. - if (hasGenerics) - { - signatureScope = childScope(fn, scope); - - std::vector> genericDefinitions = createGenerics(signatureScope, fn->generics); - std::vector> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); - - for (const auto& [name, g] : genericDefinitions) - { - genericTypes.push_back(g.ty); - } - - for (const auto& [name, g] : genericPackDefinitions) - { - genericTypePacks.push_back(g.tp); - } - } - else - { - // To eliminate the need to branch on hasGenerics below, we say that - // the signature scope is the parent scope if we don't have - // generics. - signatureScope = scope; - } - - TypePackId argTypes = resolveTypePack(signatureScope, fn->argTypes, inTypeArguments, replaceErrorWithFresh); - TypePackId returnTypes = resolveTypePack(signatureScope, fn->returnTypes, inTypeArguments, replaceErrorWithFresh); - - // TODO: FunctionType needs a pointer to the scope so that we know - // how to quantify/instantiate it. - FunctionType ftv{TypeLevel{}, scope.get(), {}, {}, argTypes, returnTypes}; - ftv.isCheckedFunction = fn->isCheckedFunction(); - - // This replicates the behavior of the appropriate FunctionType - // constructors. - ftv.generics = std::move(genericTypes); - ftv.genericPacks = std::move(genericTypePacks); - - ftv.argNames.reserve(fn->argNames.size); - for (const auto& el : fn->argNames) - { - if (el) - { - const auto& [name, location] = *el; - ftv.argNames.push_back(FunctionArgument{name.value, location}); - } - else - { - ftv.argNames.push_back(std::nullopt); - } - } - - result = arena->addType(std::move(ftv)); + result = resolveFunctionType(scope, ty, fn, inTypeArguments, replaceErrorWithFresh); } else if (auto tof = ty->as()) { - // TODO: Recursion limit. TypeId exprType = check(scope, tof->expr).ty; result = exprType; } else if (auto unionAnnotation = ty->as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (unionAnnotation->types.size == 1) + return resolveType(scope, unionAnnotation->types.data[0], inTypeArguments); + } + std::vector parts; for (AstType* part : unionAnnotation->types) { - // TODO: Recursion limit. parts.push_back(resolveType(scope, part, inTypeArguments)); } @@ -3175,15 +3517,24 @@ TypeId ConstraintGenerator::resolveType(const ScopePtr& scope, AstType* ty, bool } else if (auto intersectionAnnotation = ty->as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (intersectionAnnotation->types.size == 1) + return resolveType(scope, intersectionAnnotation->types.data[0], inTypeArguments); + } + std::vector parts; for (AstType* part : intersectionAnnotation->types) { - // TODO: Recursion limit. parts.push_back(resolveType(scope, part, inTypeArguments)); } result = arena->addType(IntersectionType{parts}); } + else if (auto typeGroupAnnotation = ty->as()) + { + result = resolveType(scope, typeGroupAnnotation->type, inTypeArguments); + } else if (auto boolAnnotation = ty->as()) { if (boolAnnotation->value) @@ -3265,33 +3616,34 @@ TypePackId ConstraintGenerator::resolveTypePack(const ScopePtr& scope, const Ast std::vector> ConstraintGenerator::createGenerics( const ScopePtr& scope, - AstArray generics, + AstArray generics, bool useCache, bool addTypes ) { std::vector> result; - for (const auto& generic : generics) + for (const auto* generic : generics) { TypeId genericTy = nullptr; - if (auto it = scope->parent->typeAliasTypeParameters.find(generic.name.value); useCache && it != scope->parent->typeAliasTypeParameters.end()) + if (auto it = scope->parent->typeAliasTypeParameters.find(generic->name.value); + useCache && it != scope->parent->typeAliasTypeParameters.end()) genericTy = it->second; else { - genericTy = arena->addType(GenericType{scope.get(), generic.name.value}); - scope->parent->typeAliasTypeParameters[generic.name.value] = genericTy; + genericTy = arena->addType(GenericType{scope.get(), generic->name.value}); + scope->parent->typeAliasTypeParameters[generic->name.value] = genericTy; } std::optional defaultTy = std::nullopt; - if (generic.defaultValue) - defaultTy = resolveType(scope, generic.defaultValue, /* inTypeArguments */ false); + if (generic->defaultValue) + defaultTy = resolveType(scope, generic->defaultValue, /* inTypeArguments */ false); if (addTypes) - scope->privateTypeBindings[generic.name.value] = TypeFun{genericTy}; + scope->privateTypeBindings[generic->name.value] = TypeFun{genericTy}; - result.push_back({generic.name.value, GenericTypeDefinition{genericTy, defaultTy}}); + result.emplace_back(generic->name.value, GenericTypeDefinition{genericTy, defaultTy}); } return result; @@ -3299,34 +3651,34 @@ std::vector> ConstraintGenerator::createG std::vector> ConstraintGenerator::createGenericPacks( const ScopePtr& scope, - AstArray generics, + AstArray generics, bool useCache, bool addTypes ) { std::vector> result; - for (const auto& generic : generics) + for (const auto* generic : generics) { TypePackId genericTy; - if (auto it = scope->parent->typeAliasTypePackParameters.find(generic.name.value); + if (auto it = scope->parent->typeAliasTypePackParameters.find(generic->name.value); useCache && it != scope->parent->typeAliasTypePackParameters.end()) genericTy = it->second; else { - genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic.name.value}}); - scope->parent->typeAliasTypePackParameters[generic.name.value] = genericTy; + genericTy = arena->addTypePack(TypePackVar{GenericTypePack{scope.get(), generic->name.value}}); + scope->parent->typeAliasTypePackParameters[generic->name.value] = genericTy; } std::optional defaultTy = std::nullopt; - if (generic.defaultValue) - defaultTy = resolveTypePack(scope, generic.defaultValue, /* inTypeArguments */ false); + if (generic->defaultValue) + defaultTy = resolveTypePack(scope, generic->defaultValue, /* inTypeArguments */ false); if (addTypes) - scope->privateTypePackBindings[generic.name.value] = genericTy; + scope->privateTypePackBindings[generic->name.value] = genericTy; - result.push_back({generic.name.value, GenericTypePackDefinition{genericTy, defaultTy}}); + result.emplace_back(generic->name.value, GenericTypePackDefinition{genericTy, defaultTy}); } return result; @@ -3384,6 +3736,65 @@ TypeId ConstraintGenerator::makeIntersect(const ScopePtr& scope, Location locati return resultType; } +struct FragmentTypeCheckGlobalPrepopulator : AstVisitor +{ + const NotNull globalScope; + const NotNull currentScope; + const NotNull dfg; + const NotNull arena; + + FragmentTypeCheckGlobalPrepopulator( + NotNull globalScope, + NotNull currentScope, + NotNull dfg, + NotNull arena + ) + : globalScope(globalScope) + , currentScope(currentScope) + , dfg(dfg) + , arena(arena) + { + } + + bool visit(AstExprGlobal* global) override + { + if (auto ty = globalScope->lookup(global->name)) + { + DefId def = dfg->getDef(global); + // We only want to write into the current scope the type of the global + currentScope->lvalueTypes[def] = *ty; + } + else if (auto ty = currentScope->lookup(global->name)) + { + // We are trying to create a binding for a brand new function, so we actually do have to write it into the scope. + DefId def = dfg->getDef(global); + // We only want to write into the current scope the type of the global + currentScope->lvalueTypes[def] = *ty; + } + + return true; + } + + bool visit(AstStatFunction* function) override + { + if (AstExprGlobal* g = function->name->as()) + { + if (auto ty = globalScope->lookup(g->name)) + { + currentScope->bindings[g->name] = Binding{*ty}; + } + else + { + // Hasn't existed since a previous typecheck + TypeId bt = arena->addType(BlockedType{}); + currentScope->bindings[g->name] = Binding{bt}; + } + } + + return true; + } +}; + struct GlobalPrepopulator : AstVisitor { const NotNull globalScope; @@ -3408,6 +3819,23 @@ struct GlobalPrepopulator : AstVisitor return true; } + bool visit(AstStatAssign* assign) override + { + for (const Luau::AstExpr* expr : assign->vars) + { + if (const AstExprGlobal* g = expr->as()) + { + if (!globalScope->lookup(g->name)) + globalScope->globalsToWarn.insert(g->name.value); + + TypeId bt = arena->addType(BlockedType{}); + globalScope->bindings[g->name] = Binding{bt, g->location}; + } + } + + return true; + } + bool visit(AstStatFunction* function) override { if (AstExprGlobal* g = function->name->as()) @@ -3430,6 +3858,14 @@ struct GlobalPrepopulator : AstVisitor } }; +void ConstraintGenerator::prepopulateGlobalScopeForFragmentTypecheck(const ScopePtr& globalScope, const ScopePtr& resumeScope, AstStatBlock* program) +{ + FragmentTypeCheckGlobalPrepopulator gp{NotNull{globalScope.get()}, NotNull{resumeScope.get()}, dfg, arena}; + if (prepareModuleScope) + prepareModuleScope(module->name, resumeScope); + program->visit(&gp); +} + void ConstraintGenerator::prepopulateGlobalScope(const ScopePtr& globalScope, AstStatBlock* program) { GlobalPrepopulator gp{NotNull{globalScope.get()}, arena, dfg}; @@ -3581,6 +4017,11 @@ TypeId ConstraintGenerator::createTypeFunctionInstance( return result; } +TypeId ConstraintGenerator::simplifyUnion(const ScopePtr& scope, Location location, TypeId left, TypeId right) +{ + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; +} + std::vector> borrowConstraints(const std::vector& constraints) { std::vector> result; diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index d978ea1b..aae536e5 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -27,10 +27,16 @@ #include #include -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies, false) -LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings, false); -LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500); +LUAU_FASTFLAGVARIABLE(DebugLuauAssertOnForcedConstraint) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverIncludeDependencies) +LUAU_FASTFLAGVARIABLE(DebugLuauLogBindings) +LUAU_FASTINTVARIABLE(LuauSolverRecursionLimit, 500) +LUAU_FASTFLAGVARIABLE(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAGVARIABLE(LuauTrackInteriorFreeTablesOnScope) +LUAU_FASTFLAGVARIABLE(LuauPrecalculateMutatedFreeTypes2) +LUAU_FASTFLAGVARIABLE(DebugLuauGreedyGeneralization) namespace Luau { @@ -68,7 +74,7 @@ size_t HashBlockedConstraintId::operator()(const BlockedConstraintId& bci) const { if (auto blocked = get(ty)) { - Constraint* owner = blocked->getOwner(); + const Constraint* owner = blocked->getOwner(); LUAU_ASSERT(owner); return owner == constraint; } @@ -583,22 +589,30 @@ struct InstantiationQueuer : TypeOnceVisitor ConstraintSolver::ConstraintSolver( NotNull normalizer, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull rootScope, std::vector> constraints, + NotNull> scopeToFunction, ModuleName moduleName, NotNull moduleResolver, std::vector requireCycles, DcrLogger* logger, + NotNull dfg, TypeCheckLimits limits ) : arena(normalizer->arena) , builtinTypes(normalizer->builtinTypes) , normalizer(normalizer) + , simplifier(simplifier) + , typeFunctionRuntime(typeFunctionRuntime) , constraints(std::move(constraints)) + , scopeToFunction(scopeToFunction) , rootScope(rootScope) , currentModuleName(std::move(moduleName)) + , dfg(dfg) , moduleResolver(moduleResolver) - , requireCycles(requireCycles) + , requireCycles(std::move(requireCycles)) , logger(logger) , limits(std::move(limits)) { @@ -606,15 +620,35 @@ ConstraintSolver::ConstraintSolver( for (NotNull c : this->constraints) { - unsolvedConstraints.push_back(c); + unsolvedConstraints.emplace_back(c); - // initialize the reference counts for the free types in this constraint. - for (auto ty : c->getMaybeMutatedFreeTypes()) + if (FFlag::LuauPrecalculateMutatedFreeTypes2) { - // increment the reference count for `ty` - auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); - refCount += 1; + auto maybeMutatedTypesPerConstraint = c->getMaybeMutatedFreeTypes(); + for (auto ty : maybeMutatedTypesPerConstraint) + { + auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); + refCount += 1; + + if (FFlag::DebugLuauGreedyGeneralization) + { + auto [it, fresh] = mutatedFreeTypeToConstraint.try_emplace(ty, DenseHashSet{nullptr}); + it->second.insert(c.get()); + } + } + maybeMutatedFreeTypes.emplace(c, maybeMutatedTypesPerConstraint); } + else + { + // initialize the reference counts for the free types in this constraint. + for (auto ty : c->getMaybeMutatedFreeTypes()) + { + // increment the reference count for `ty` + auto [refCount, _] = unresolvedConstraints.try_insert(ty, 0); + refCount += 1; + } + } + for (NotNull dep : c->dependencies) { @@ -695,6 +729,9 @@ void ConstraintSolver::run() // Set current Constraint currentConstraintRef = c.get(); + if (FFlag::DebugLuauAssertOnForcedConstraint) + LUAU_ASSERT(!force); + bool success = tryDispatch(c, force); progress |= success; @@ -702,22 +739,62 @@ void ConstraintSolver::run() if (success) { unblock(c); - unsolvedConstraints.erase(unsolvedConstraints.begin() + i); + unsolvedConstraints.erase(unsolvedConstraints.begin() + ptrdiff_t(i)); - // decrement the referenced free types for this constraint if we dispatched successfully! - for (auto ty : c->getMaybeMutatedFreeTypes()) + if (FFlag::LuauPrecalculateMutatedFreeTypes2) { - size_t& refCount = unresolvedConstraints[ty]; - if (refCount > 0) - refCount -= 1; + const auto maybeMutated = maybeMutatedFreeTypes.find(c); + if (maybeMutated != maybeMutatedFreeTypes.end()) + { + DenseHashSet seen{nullptr}; + for (auto ty : maybeMutated->second) + { + // There is a high chance that this type has been rebound + // across blocked types, rebound free types, pending + // expansion types, etc, so we need to follow it. + ty = follow(ty); - // We have two constraints that are designed to wait for the - // refCount on a free type to be equal to 1: the - // PrimitiveTypeConstraint and ReduceConstraint. We - // therefore wake any constraint waiting for a free type's - // refcount to be 1 or 0. - if (refCount <= 1) - unblock(ty, Location{}); + if (FFlag::DebugLuauGreedyGeneralization) + { + if (seen.contains(ty)) + continue; + seen.insert(ty); + } + + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); + + if (FFlag::DebugLuauGreedyGeneralization && refCount == 0) + generalizeOneType(ty); + } + } + } + else + { + // decrement the referenced free types for this constraint if we dispatched successfully! + for (auto ty : c->getMaybeMutatedFreeTypes()) + { + size_t& refCount = unresolvedConstraints[ty]; + if (refCount > 0) + refCount -= 1; + + // We have two constraints that are designed to wait for the + // refCount on a free type to be equal to 1: the + // PrimitiveTypeConstraint and ReduceConstraint. We + // therefore wake any constraint waiting for a free type's + // refcount to be 1 or 0. + if (refCount <= 1) + unblock(ty, Location{}); + } } if (logger) @@ -809,21 +886,159 @@ void ConstraintSolver::finalizeTypeFunctions() } } -bool ConstraintSolver::isDone() +bool ConstraintSolver::isDone() const { return unsolvedConstraints.empty(); } -namespace +struct TypeSearcher : TypeVisitor { + enum struct Polarity: uint8_t + { + None = 0b00, + Positive = 0b01, + Negative = 0b10, + Mixed = 0b11, + }; -struct TypeAndLocation -{ - TypeId typeId; - Location location; + TypeId needle; + Polarity current = Polarity::Positive; + + Polarity result = Polarity::None; + + explicit TypeSearcher(TypeId needle) + : TypeSearcher(needle, Polarity::Positive) + {} + + explicit TypeSearcher(TypeId needle, Polarity initialPolarity) + : needle(needle) + , current(initialPolarity) + {} + + bool visit(TypeId ty) override + { + if (ty == needle) + result = Polarity(int(result) | int(current)); + + return true; + } + + void flip() + { + switch (current) + { + case Polarity::Positive: + current = Polarity::Negative; + break; + case Polarity::Negative: + current = Polarity::Positive; + break; + default: + break; + } + } + + bool visit(TypeId ty, const FunctionType& ft) override + { + flip(); + traverse(ft.argTypes); + + flip(); + traverse(ft.retTypes); + + return false; + } + + // bool visit(TypeId ty, const TableType& tt) override + // { + + // } + + bool visit(TypeId ty, const ClassType&) override + { + return false; + } }; -} // namespace +void ConstraintSolver::generalizeOneType(TypeId ty) +{ + ty = follow(ty); + const FreeType* freeTy = get(ty); + + std::string saveme = toString(ty, opts); + + // Some constraints (like prim) will also replace a free type with something + // concrete. If so, our work is already done. + if (!freeTy) + return; + + NotNull tyScope{freeTy->scope}; + + // TODO: If freeTy occurs within the enclosing function's type, we need to + // check to see whether this type should instead be generic. + + TypeId newBound = follow(freeTy->upperBound); + + TypeId* functionTyPtr = nullptr; + while (true) + { + functionTyPtr = scopeToFunction->find(tyScope); + if (functionTyPtr || !tyScope->parent) + break; + else if (tyScope->parent) + tyScope = NotNull{tyScope->parent.get()}; + else + break; + } + + if (ty == newBound) + ty = builtinTypes->unknownType; + + if (!functionTyPtr) + { + asMutable(ty)->reassign(Type{BoundType{follow(freeTy->upperBound)}}); + } + else + { + const TypeId functionTy = follow(*functionTyPtr); + FunctionType* const function = getMutable(functionTy); + LUAU_ASSERT(function); + + TypeSearcher ts{ty}; + ts.traverse(functionTy); + + const TypeId upperBound = follow(freeTy->upperBound); + const TypeId lowerBound = follow(freeTy->lowerBound); + + switch (ts.result) + { + case TypeSearcher::Polarity::None: + asMutable(ty)->reassign(Type{BoundType{upperBound}}); + break; + + case TypeSearcher::Polarity::Negative: + case TypeSearcher::Polarity::Mixed: + if (get(upperBound)) + { + asMutable(ty)->reassign(Type{GenericType{tyScope}}); + function->generics.emplace_back(ty); + } + else + asMutable(ty)->reassign(Type{BoundType{upperBound}}); + break; + + case TypeSearcher::Polarity::Positive: + if (get(lowerBound)) + { + asMutable(ty)->reassign(Type{GenericType{tyScope}}); + function->generics.emplace_back(ty); + } + else + asMutable(ty)->reassign(Type{BoundType{lowerBound}}); + break; + } + } +} void ConstraintSolver::bind(NotNull constraint, TypeId ty, TypeId boundTo) { @@ -883,11 +1098,11 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo bool success = false; if (auto sc = get(*constraint)) - success = tryDispatch(*sc, constraint, force); + success = tryDispatch(*sc, constraint); else if (auto psc = get(*constraint)) - success = tryDispatch(*psc, constraint, force); + success = tryDispatch(*psc, constraint); else if (auto gc = get(*constraint)) - success = tryDispatch(*gc, constraint, force); + success = tryDispatch(*gc, constraint); else if (auto ic = get(*constraint)) success = tryDispatch(*ic, constraint, force); else if (auto nc = get(*constraint)) @@ -898,6 +1113,8 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo success = tryDispatch(*fcc, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); + else if (auto tcc = get(*constraint)) + success = tryDispatch(*tcc, constraint); else if (auto fcc = get(*constraint)) success = tryDispatch(*fcc, constraint); else if (auto hpc = get(*constraint)) @@ -915,14 +1132,14 @@ bool ConstraintSolver::tryDispatch(NotNull constraint, bool fo else if (auto rpc = get(*constraint)) success = tryDispatch(*rpc, constraint, force); else if (auto eqc = get(*constraint)) - success = tryDispatch(*eqc, constraint, force); + success = tryDispatch(*eqc, constraint); else LUAU_ASSERT(false); return success; } -bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint) { if (isBlocked(c.subType)) return block(c.subType, constraint); @@ -934,7 +1151,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint) { if (isBlocked(c.subPack)) return block(c.subPack, constraint); @@ -946,7 +1163,7 @@ bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull constraint) { TypeId generalizedType = follow(c.generalizedType); @@ -982,8 +1199,20 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullerrorRecoveryType()); } - for (TypeId ty : c.interiorTypes) - generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + // We check if this member is initialized and then access it, but + // clang-tidy doesn't understand this is safe. + if (constraint->scope->interiorFreeTypes) + for (TypeId ty : *constraint->scope->interiorFreeTypes) // NOLINT(bugprone-unchecked-optional-access) + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); + } + else + { + for (TypeId ty : c.interiorTypes) + generalize(NotNull{arena}, builtinTypes, constraint->scope, generalizedTypes, ty, /* avoidSealingTables */ false); + } + return true; } @@ -1059,9 +1288,17 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNullscope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + trackInteriorFreeType(constraint->scope, keyTy); + trackInteriorFreeType(constraint->scope, valueTy); + } TypeId tableTy = arena->addType(TableType{TableType::Props{}, TableIndexer{keyTy, valueTy}, TypeLevel{}, constraint->scope, TableState::Free}); + if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope) + trackInteriorFreeType(constraint->scope, tableTy); + unify(constraint, nextTy, tableTy); auto it = begin(c.variables); @@ -1093,7 +1330,7 @@ bool ConstraintSolver::tryDispatch(const IterableConstraint& c, NotNull= 2) tableTy = iterator.head[1]; - return tryDispatchIterableFunction(nextTy, tableTy, c, constraint, force); + return tryDispatchIterableFunction(nextTy, tableTy, c, constraint); } else @@ -1178,9 +1415,10 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul auto bindResult = [this, &c, constraint](TypeId result) { - LUAU_ASSERT(get(c.target)); - shiftReferences(c.target, result); - bind(constraint, c.target, result); + auto cTarget = follow(c.target); + LUAU_ASSERT(get(cTarget)); + shiftReferences(cTarget, result); + bind(constraint, cTarget, result); }; std::optional tf = (petv->prefix) ? constraint->scope->lookupImportedType(petv->prefix->value, petv->name.value) @@ -1197,18 +1435,10 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul if (auto typeFn = get(follow(tf->type))) pushConstraint(NotNull(constraint->scope.get()), constraint->location, ReduceConstraint{tf->type}); - // If there are no parameters to the type function we can just use the type - // directly. - if (tf->typeParams.empty() && tf->typePackParams.empty()) - { - bindResult(tf->type); - return true; - } - // Due to how pending expansion types and TypeFun's are created // If this check passes, we have created a cyclic / corecursive type alias // of size 0 - TypeId lhs = c.target; + TypeId lhs = follow(c.target); TypeId rhs = tf->type; if (occursCheck(lhs, rhs)) { @@ -1217,6 +1447,13 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } + // If there are no parameters to the type function we can just use the type directly + if (tf->typeParams.empty() && tf->typePackParams.empty()) + { + bindResult(tf->type); + return true; + } + auto [typeArguments, packArguments] = saturateArguments(arena, builtinTypes, *tf, petv->typeArguments, petv->packArguments); bool sameTypes = std::equal( @@ -1362,9 +1599,12 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul target = follow(instantiated); } + // This is a new type - redefine the location. + ttv->definitionLocation = constraint->location; + ttv->definitionModuleName = currentModuleName; + ttv->instantiatedTypeParams = typeArguments; ttv->instantiatedTypePackParams = packArguments; - // TODO: Fill in definitionModuleName. } bindResult(target); @@ -1374,6 +1614,25 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul return true; } +void ConstraintSolver::fillInDiscriminantTypes(NotNull constraint, const std::vector>& discriminantTypes) +{ + for (std::optional ty : discriminantTypes) + { + if (!ty) + continue; + + // If the discriminant type has been transmuted, we need to unblock them. + if (!isBlocked(*ty)) + { + unblock(*ty, constraint->location); + continue; + } + + // We bind any unused discriminants to the `*no-refine*` type indicating that it can be safely ignored. + emplaceType(asMutable(follow(*ty)), builtinTypes->noRefineType); + } +} + bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull constraint) { TypeId fn = follow(c.fn); @@ -1389,6 +1648,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(asMutable(c.result), builtinTypes->anyTypePack); unblock(c.result, constraint->location); + fillInDiscriminantTypes(constraint, c.discriminantTypes); return true; } @@ -1396,12 +1656,14 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNull(fn)) { bind(constraint, c.result, builtinTypes->errorRecoveryTypePack()); + fillInDiscriminantTypes(constraint, c.discriminantTypes); return true; } if (get(fn)) { bind(constraint, c.result, builtinTypes->neverTypePack); + fillInDiscriminantTypes(constraint, c.discriminantTypes); return true; } @@ -1471,41 +1733,29 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulldcrMagicFunction) - usedMagic = ftv->dcrMagicFunction(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); - - if (ftv->dcrMagicRefinement) - ftv->dcrMagicRefinement(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + if (ftv->magic) + { + usedMagic = ftv->magic->infer(MagicFunctionCallContext{NotNull{this}, constraint, c.callSite, c.argsPack, result}); + ftv->magic->refine(MagicRefinementContext{constraint->scope, c.callSite, c.discriminantTypes}); + } } if (!usedMagic) emplace(constraint, c.result, constraint->scope); } - for (std::optional ty : c.discriminantTypes) - { - if (!ty) - continue; - - // If the discriminant type has been transmuted, we need to unblock them. - if (!isBlocked(*ty)) - { - unblock(*ty, constraint->location); - continue; - } - - // We use `any` here because the discriminant type may be pointed at by both branches, - // where the discriminant type is not negated, and the other where it is negated, i.e. - // `unknown ~ unknown` and `~unknown ~ never`, so `T & unknown ~ T` and `T & ~unknown ~ never` - // v.s. - // `any ~ any` and `~any ~ any`, so `T & any ~ T` and `T & ~any ~ T` - // - // In practice, users cannot negate `any`, so this is an implementation detail we can always change. - emplaceType(asMutable(follow(*ty)), builtinTypes->anyType); - } + fillInDiscriminantTypes(constraint, c.discriminantTypes); OverloadResolver resolver{ - builtinTypes, NotNull{arena}, normalizer, constraint->scope, NotNull{&iceReporter}, NotNull{&limits}, constraint->location + builtinTypes, + NotNull{arena}, + simplifier, + normalizer, + typeFunctionRuntime, + constraint->scope, + NotNull{&iceReporter}, + NotNull{&limits}, + constraint->location }; auto [status, overload] = resolver.selectOverload(fn, argsPack); TypeId overloadToUse = fn; @@ -1535,7 +1785,7 @@ bool ConstraintSolver::tryDispatch(const FunctionCallConstraint& c, NotNulllocation, addition)); + upperBoundContributors[expanded].emplace_back(constraint->location, addition); } if (occursCheckPassed && c.callSite) @@ -1571,6 +1821,19 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullargs.data[i]); + AstExpr* expr = unwrapGroup(c.callSite->args.data[i]); (*c.astExpectedTypes)[expr] = expectedArgTy; @@ -1655,7 +1918,8 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullis() || expr->is() || expr->is() || expr->is()) + else if (expr->is() || expr->is() || expr->is() || + expr->is()) { Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; u2.unify(actualArgTy, expectedArgTy); @@ -1665,16 +1929,35 @@ bool ConstraintSolver::tryDispatch(const FunctionCheckConstraint& c, NotNullscope, NotNull{&iceReporter}}; std::vector toBlock; (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, expectedArgTy, actualArgTy, expr, toBlock); - for (auto t : toBlock) - block(t, constraint); - if (!toBlock.empty()) - return false; + LUAU_ASSERT(toBlock.empty()); } } return true; } +bool ConstraintSolver::tryDispatch(const TableCheckConstraint& c, NotNull constraint) +{ + // This is expensive as we need to traverse a (potentially large) + // literal up front in order to determine if there are any blocked + // types, otherwise we may run `matchTypeLiteral` multiple times, + // which right now may fail due to being non-idempotent (it + // destructively updates the underlying literal type). + auto blockedTypes = findBlockedTypesIn(c.table, c.astTypes); + for (const auto ty : blockedTypes) + { + block(ty, constraint); + } + if (!blockedTypes.empty()) + return false; + + Unifier2 u2{arena, builtinTypes, constraint->scope, NotNull{&iceReporter}}; + std::vector toBlock; + (void)matchLiteralType(c.astTypes, c.astExpectedTypes, builtinTypes, arena, NotNull{&u2}, c.expectedType, c.exprType, c.table, toBlock); + LUAU_ASSERT(toBlock.empty()); + return true; +} + bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNull constraint) { std::optional expectedType = c.expectedType ? std::make_optional(follow(*c.expectedType)) : std::nullopt; @@ -1702,8 +1985,9 @@ bool ConstraintSolver::tryDispatch(const PrimitiveTypeConstraint& c, NotNulllowerBound; - shiftReferences(c.freeType, bindTo); - bind(constraint, c.freeType, bindTo); + auto ty = follow(c.freeType); + shiftReferences(ty, bindTo); + bind(constraint, ty, bindTo); return true; } @@ -1727,7 +2011,8 @@ bool ConstraintSolver::tryDispatch(const HasPropConstraint& c, NotNulladdType(BlockedType{}); - getMutable(r)->setOwner(const_cast(constraint.get())); + getMutable(r)->setOwner(constraint.get()); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); // If we've cut a recursive loop short, skip it. @@ -1874,7 +2159,7 @@ bool ConstraintSolver::tryDispatchHasIndexer( for (TypeId part : parts) { TypeId r = arena->addType(BlockedType{}); - getMutable(r)->setOwner(const_cast(constraint.get())); + getMutable(r)->setOwner(constraint.get()); bool ok = tryDispatchHasIndexer(recursionDepth, constraint, part, indexType, r, seen); // If we've cut a recursive loop short, skip it. @@ -1988,18 +2273,23 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNull(lhsType)) { - if (get(lhsFree->upperBound) || get(lhsFree->upperBound)) - lhsType = lhsFree->upperBound; + auto lhsFreeUpperBound = follow(lhsFree->upperBound); + if (get(lhsFreeUpperBound) || get(lhsFreeUpperBound)) + lhsType = lhsFreeUpperBound; else { TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, constraint->scope}); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope) + trackInteriorFreeType(constraint->scope, newUpperBound); + TableType* upperTable = getMutable(newUpperBound); LUAU_ASSERT(upperTable); upperTable->props[c.propName] = rhsType; // Food for thought: Could we block if simplification encounters a blocked type? - lhsFree->upperBound = simplifyIntersection(builtinTypes, arena, lhsFree->upperBound, newUpperBound).result; + lhsFree->upperBound = simplifyIntersection(constraint->scope, constraint->location, lhsFreeUpperBound, newUpperBound); bind(constraint, c.propType, rhsType); return true; @@ -2008,7 +2298,7 @@ bool ConstraintSolver::tryDispatch(const AssignPropConstraint& c, NotNulladdType(UnionType{{propTy, builtinTypes->nilType}}) : propTy + ); unify(constraint, rhsType, propTy); return true; } @@ -2113,7 +2407,11 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsTable->indexer->indexResultType); - bind(constraint, c.propType, lhsTable->indexer->indexResultType); + bind( + constraint, + c.propType, + arena->addType(UnionType{{lhsTable->indexer->indexResultType, builtinTypes->nilType}}) + ); return true; } @@ -2162,7 +2460,11 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullindexer->indexType); unify(constraint, rhsType, lhsClass->indexer->indexResultType); - bind(constraint, c.propType, lhsClass->indexer->indexResultType); + bind( + constraint, + c.propType, + arena->addType(UnionType{{lhsClass->indexer->indexResultType, builtinTypes->nilType}}) + ); return true; } @@ -2213,7 +2515,7 @@ bool ConstraintSolver::tryDispatch(const AssignIndexConstraint& c, NotNullscope, constraint->location, std::move(parts)); unify(constraint, rhsType, res); } @@ -2257,6 +2559,8 @@ bool ConstraintSolver::tryDispatch(const UnpackConstraint& c, NotNullscope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(constraint->scope, f); shiftReferences(resultTy, f); emplaceType(asMutable(resultTy), f); } @@ -2312,6 +2616,11 @@ bool ConstraintSolver::tryDispatch(const ReduceConstraint& c, NotNull constraint, bool force) +bool ConstraintSolver::tryDispatch(const EqualityConstraint& c, NotNull constraint) { unify(constraint, c.resultType, c.assignmentType); unify(constraint, c.assignmentType, c.resultType); @@ -2387,6 +2696,11 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl { TypeId keyTy = freshType(arena, builtinTypes, constraint->scope); TypeId valueTy = freshType(arena, builtinTypes, constraint->scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + { + trackInteriorFreeType(constraint->scope, keyTy); + trackInteriorFreeType(constraint->scope, valueTy); + } TypeId tableTy = arena->addType(TableType{TableState::Sealed, {}, constraint->scope}); getMutable(tableTy)->indexer = TableIndexer{keyTy, valueTy}; @@ -2531,13 +2845,7 @@ bool ConstraintSolver::tryDispatchIterableTable(TypeId iteratorTy, const Iterabl return true; } -bool ConstraintSolver::tryDispatchIterableFunction( - TypeId nextTy, - TypeId tableTy, - const IterableConstraint& c, - NotNull constraint, - bool force -) +bool ConstraintSolver::tryDispatchIterableFunction(TypeId nextTy, TypeId tableTy, const IterableConstraint& c, NotNull constraint) { const FunctionType* nextFn = get(nextTy); // If this does not hold, we should've never called `tryDispatchIterableFunction` in the first place. @@ -2591,7 +2899,7 @@ NotNull ConstraintSolver::unpackAndAssign( return c; } -std::pair, std::optional> ConstraintSolver::lookupTableProp( +TablePropLookupResult ConstraintSolver::lookupTableProp( NotNull constraint, TypeId subjectType, const std::string& propName, @@ -2604,7 +2912,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa return lookupTableProp(constraint, subjectType, propName, context, inConditional, suppressSimplification, seen); } -std::pair, std::optional> ConstraintSolver::lookupTableProp( +TablePropLookupResult ConstraintSolver::lookupTableProp( NotNull constraint, TypeId subjectType, const std::string& propName, @@ -2644,11 +2952,13 @@ std::pair, std::optional> ConstraintSolver::lookupTa } if (ttv->indexer && maybeString(ttv->indexer->indexType)) - return {{}, ttv->indexer->indexResultType}; + return {{}, ttv->indexer->indexResultType, /* isIndex = */ true}; if (ttv->state == TableState::Free) { TypeId result = freshType(arena, builtinTypes, ttv->scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(ttv->scope, result); switch (context) { case ValueContext::RValue: @@ -2686,9 +2996,9 @@ std::pair, std::optional> ConstraintSolver::lookupTa } else if (auto mt = get(subjectType); mt && context == ValueContext::RValue) { - auto [blocked, result] = lookupTableProp(constraint, mt->table, propName, context, inConditional, suppressSimplification, seen); - if (!blocked.empty() || result) - return {blocked, result}; + auto result = lookupTableProp(constraint, mt->table, propName, context, inConditional, suppressSimplification, seen); + if (!result.blockedTypes.empty() || result.propType) + return result; TypeId mtt = follow(mt->metatable); @@ -2698,7 +3008,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa { auto indexProp = metatable->props.find("__index"); if (indexProp == metatable->props.end()) - return {{}, result}; + return {{}, result.propType}; // TODO: __index can be an overloaded function. @@ -2728,7 +3038,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa return {{}, context == ValueContext::RValue ? p->readTy : p->writeTy}; if (ct->indexer) { - return {{}, ct->indexer->indexResultType}; + return {{}, ct->indexer->indexResultType, /* isIndex = */ true}; } } else if (auto pt = get(subjectType); pt && pt->metatable) @@ -2754,10 +3064,17 @@ std::pair, std::optional> ConstraintSolver::lookupTa NotNull scope{ft->scope}; const TypeId newUpperBound = arena->addType(TableType{TableState::Free, TypeLevel{}, scope}); + + if (FFlag::LuauTrackInteriorFreeTypesOnScope && FFlag::LuauTrackInteriorFreeTablesOnScope) + trackInteriorFreeType(constraint->scope, newUpperBound); + TableType* tt = getMutable(newUpperBound); LUAU_ASSERT(tt); TypeId propType = freshType(arena, builtinTypes, scope); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(scope, propType); + switch (context) { case ValueContext::RValue: @@ -2779,10 +3096,10 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : utv) { - auto [innerBlocked, innerResult] = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); - blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); - if (innerResult) - options.insert(*innerResult); + auto result = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); + blocked.insert(blocked.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + if (result.propType) + options.insert(*result.propType); } if (!blocked.empty()) @@ -2799,9 +3116,9 @@ std::pair, std::optional> ConstraintSolver::lookupTa // if we're in an lvalue context, we need the _common_ type here. if (context == ValueContext::LValue) - return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)}; - return {{}, simplifyUnion(builtinTypes, arena, one, two).result}; + return {{}, simplifyUnion(constraint->scope, constraint->location, one, two)}; } // if we're in an lvalue context, we need the _common_ type here. else if (context == ValueContext::LValue) @@ -2816,10 +3133,10 @@ std::pair, std::optional> ConstraintSolver::lookupTa for (TypeId ty : itv) { - auto [innerBlocked, innerResult] = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); - blocked.insert(blocked.end(), innerBlocked.begin(), innerBlocked.end()); - if (innerResult) - options.insert(*innerResult); + auto result = lookupTableProp(constraint, ty, propName, context, inConditional, suppressSimplification, seen); + blocked.insert(blocked.end(), result.blockedTypes.begin(), result.blockedTypes.end()); + if (result.propType) + options.insert(*result.propType); } if (!blocked.empty()) @@ -2833,7 +3150,7 @@ std::pair, std::optional> ConstraintSolver::lookupTa { TypeId one = *begin(options); TypeId two = *(++begin(options)); - return {{}, simplifyIntersection(builtinTypes, arena, one, two).result}; + return {{}, simplifyIntersection(constraint->scope, constraint->location, one, two)}; } else return {{}, arena->addType(IntersectionType{std::vector(begin(options), end(options))})}; @@ -2868,7 +3185,7 @@ bool ConstraintSolver::unify(NotNull constraint, TID subTy, TI for (const auto& [expanded, additions] : u2.expandedFreeTypes) { for (TypeId addition : additions) - upperBoundContributors[expanded].push_back(std::make_pair(constraint->location, addition)); + upperBoundContributors[expanded].emplace_back(constraint->location, addition); } } else @@ -2988,10 +3305,10 @@ bool ConstraintSolver::blockOnPendingTypes(TypeId target, NotNull constraint) +bool ConstraintSolver::blockOnPendingTypes(TypePackId targetPack, NotNull constraint) { Blocker blocker{NotNull{this}, constraint}; - blocker.traverse(pack); + blocker.traverse(targetPack); return !blocker.blocked; } @@ -3085,7 +3402,7 @@ void ConstraintSolver::reproduceConstraints(NotNull scope, const Location } } -bool ConstraintSolver::isBlocked(TypeId ty) +bool ConstraintSolver::isBlocked(TypeId ty) const { ty = follow(ty); @@ -3095,7 +3412,7 @@ bool ConstraintSolver::isBlocked(TypeId ty) return nullptr != get(ty) || nullptr != get(ty); } -bool ConstraintSolver::isBlocked(TypePackId tp) +bool ConstraintSolver::isBlocked(TypePackId tp) const { tp = follow(tp); @@ -3105,7 +3422,7 @@ bool ConstraintSolver::isBlocked(TypePackId tp) return nullptr != get(tp); } -bool ConstraintSolver::isBlocked(NotNull constraint) +bool ConstraintSolver::isBlocked(NotNull constraint) const { auto blockedIt = blockedConstraints.find(constraint); return blockedIt != blockedConstraints.end() && blockedIt->second > 0; @@ -3116,7 +3433,7 @@ NotNull ConstraintSolver::pushConstraint(NotNull scope, const std::unique_ptr c = std::make_unique(scope, location, std::move(cv)); NotNull borrow = NotNull(c.get()); solverConstraints.push_back(std::move(c)); - unsolvedConstraints.push_back(borrow); + unsolvedConstraints.emplace_back(borrow); return borrow; } @@ -3151,7 +3468,7 @@ TypeId ConstraintSolver::resolveModule(const ModuleInfo& info, const Location& l } TypePackId modulePack = module->returnType; - if (get(modulePack)) + if (get(modulePack)) return errorRecoveryType(); std::optional moduleType = first(modulePack); @@ -3194,6 +3511,24 @@ void ConstraintSolver::shiftReferences(TypeId source, TypeId target) auto [targetRefs, _] = unresolvedConstraints.try_insert(target, 0); targetRefs += count; + + // Any constraint that might have mutated source may now mutate target + + if (FFlag::DebugLuauGreedyGeneralization) + { + auto it = mutatedFreeTypeToConstraint.find(source); + if (it != mutatedFreeTypeToConstraint.end()) + { + auto [it2, fresh] = mutatedFreeTypeToConstraint.try_emplace(target, DenseHashSet{nullptr}); + for (const Constraint* constraint : it->second) + { + it2->second.insert(constraint); + + auto [it3, fresh2] = maybeMutatedFreeTypes.try_emplace(NotNull{constraint}, DenseHashSet{nullptr}); + it3->second.insert(target); + } + } + } } std::optional ConstraintSolver::generalizeFreeType(NotNull scope, TypeId type, bool avoidSealingTables) @@ -3222,6 +3557,63 @@ bool ConstraintSolver::hasUnresolvedConstraints(TypeId ty) return false; } +TypeId ConstraintSolver::simplifyIntersection(NotNull scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(IntersectionType{{left, right}}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyIntersection(builtinTypes, arena, left, right).result; +} + +TypeId ConstraintSolver::simplifyIntersection(NotNull scope, Location location, std::set parts) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(IntersectionType{std::vector(parts.begin(), parts.end())}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyIntersection(builtinTypes, arena, std::move(parts)).result; +} + +TypeId ConstraintSolver::simplifyUnion(NotNull scope, Location location, TypeId left, TypeId right) +{ + if (FFlag::DebugLuauEqSatSimplification) + { + TypeId ty = arena->addType(UnionType{{left, right}}); + + std::optional res = eqSatSimplify(simplifier, ty); + if (!res) + return ty; + + for (TypeId ty : res->newTypeFunctions) + pushConstraint(scope, location, ReduceConstraint{ty}); + + return res->result; + } + else + return ::Luau::simplifyUnion(builtinTypes, arena, left, right).result; +} + TypeId ConstraintSolver::errorRecoveryType() const { return builtinTypes->errorRecoveryType(); @@ -3262,12 +3654,12 @@ TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) return arena->addTypePack(resultTypes, resultTail); } -LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() +LUAU_NOINLINE void ConstraintSolver::throwTimeLimitError() const { throw TimeLimitError(currentModuleName); } -LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() +LUAU_NOINLINE void ConstraintSolver::throwUserCancelError() const { throw UserCancelError(currentModuleName); } diff --git a/Analysis/src/DataFlowGraph.cpp b/Analysis/src/DataFlowGraph.cpp index 9c42e4d8..46c87845 100644 --- a/Analysis/src/DataFlowGraph.cpp +++ b/Analysis/src/DataFlowGraph.cpp @@ -2,11 +2,13 @@ #include "Luau/DataFlowGraph.h" #include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Def.h" #include "Luau/Common.h" #include "Luau/Error.h" #include "Luau/TimeTrace.h" +#include #include LUAU_FASTFLAG(DebugLuauFreezeArena) @@ -17,6 +19,38 @@ namespace Luau bool doesCallError(const AstExprCall* call); // TypeInfer.cpp +struct ReferencedDefFinder : public AstVisitor +{ + bool visit(AstExprLocal* local) override + { + referencedLocalDefs.push_back(local->local); + return true; + } + // ast defs is just a mapping from expr -> def in general + // will get built up by the dfg builder + + // localDefs, we need to copy over + std::vector referencedLocalDefs; +}; + +struct PushScope +{ + ScopeStack& stack; + + PushScope(ScopeStack& stack, DfgScope* scope) + : stack(stack) + { + // `scope` should never be `nullptr` here. + LUAU_ASSERT(scope); + stack.push_back(scope); + } + + ~PushScope() + { + stack.pop_back(); + } +}; + const RefinementKey* RefinementKeyArena::leaf(DefId def) { return allocator.allocate(RefinementKey{nullptr, def, std::nullopt}); @@ -27,6 +61,12 @@ const RefinementKey* RefinementKeyArena::node(const RefinementKey* parent, DefId return allocator.allocate(RefinementKey{parent, def, propName}); } +DataFlowGraph::DataFlowGraph(NotNull defArena, NotNull keyArena) + : defArena{defArena} + , keyArena{keyArena} +{ +} + DefId DataFlowGraph::getDef(const AstExpr* expr) const { auto def = astDefs.find(expr); @@ -34,6 +74,14 @@ DefId DataFlowGraph::getDef(const AstExpr* expr) const return NotNull{*def}; } +std::optional DataFlowGraph::getDefOptional(const AstExpr* expr) const +{ + auto def = astDefs.find(expr); + if (!def) + return std::nullopt; + return NotNull{*def}; +} + std::optional DataFlowGraph::getRValueDefForCompoundAssign(const AstExpr* expr) const { auto def = compoundAssignDefs.find(expr); @@ -135,16 +183,27 @@ bool DfgScope::canUpdateDefinition(DefId def, const std::string& key) const return true; } -DataFlowGraph DataFlowGraphBuilder::build(AstStatBlock* block, NotNull handle) +DataFlowGraphBuilder::DataFlowGraphBuilder(NotNull defArena, NotNull keyArena) + : graph{defArena, keyArena} + , defArena{defArena} + , keyArena{keyArena} +{ +} + +DataFlowGraph DataFlowGraphBuilder::build( + AstStatBlock* block, + NotNull defArena, + NotNull keyArena, + NotNull handle +) { LUAU_TIMETRACE_SCOPE("DataFlowGraphBuilder::build", "Typechecking"); - LUAU_ASSERT(FFlag::LuauSolverV2); - - DataFlowGraphBuilder builder; + DataFlowGraphBuilder builder(defArena, keyArena); builder.handle = handle; - builder.moduleScope = builder.childScope(nullptr); // nullptr is the root DFG scope. - builder.visitBlockWithoutChildScope(builder.moduleScope, block); + DfgScope* moduleScope = builder.makeChildScope(); + PushScope ps{builder.scopeStack, moduleScope}; + builder.visitBlockWithoutChildScope(block); builder.resolveCaptures(); if (FFlag::DebugLuauFreezeArena) @@ -174,9 +233,16 @@ void DataFlowGraphBuilder::resolveCaptures() } } -DfgScope* DataFlowGraphBuilder::childScope(DfgScope* scope, DfgScope::ScopeType scopeType) +DfgScope* DataFlowGraphBuilder::currentScope() { - return scopes.emplace_back(new DfgScope{scope, scopeType}).get(); + if (scopeStack.empty()) + return nullptr; // nullptr is the root DFG scope. + return scopeStack.back(); +} + +DfgScope* DataFlowGraphBuilder::makeChildScope(DfgScope::ScopeType scopeType) +{ + return scopes.emplace_back(new DfgScope{currentScope(), scopeType}).get(); } void DataFlowGraphBuilder::join(DfgScope* p, DfgScope* a, DfgScope* b) @@ -251,8 +317,10 @@ void DataFlowGraphBuilder::joinProps(DfgScope* result, const DfgScope& a, const } } -DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol) +DefId DataFlowGraphBuilder::lookup(Symbol symbol) { + DfgScope* scope = currentScope(); + // true if any of the considered scopes are a loop. bool outsideLoopScope = false; for (DfgScope* current = scope; current; current = current->parent) @@ -282,8 +350,9 @@ DefId DataFlowGraphBuilder::lookup(DfgScope* scope, Symbol symbol) return result; } -DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string& key) +DefId DataFlowGraphBuilder::lookup(DefId def, const std::string& key) { + DfgScope* scope = currentScope(); for (DfgScope* current = scope; current; current = current->parent) { if (auto props = current->props.find(def)) @@ -303,7 +372,7 @@ DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string { std::vector defs; for (DefId operand : phi->operands) - defs.push_back(lookup(scope, operand, key)); + defs.push_back(lookup(operand, key)); DefId result = defArena->phi(defs); scope->props[def][key] = result; @@ -319,20 +388,26 @@ DefId DataFlowGraphBuilder::lookup(DfgScope* scope, DefId def, const std::string handle->ice("Inexhaustive lookup cases in DataFlowGraphBuilder::lookup"); } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visit(AstStatBlock* b) { - DfgScope* child = childScope(scope); - ControlFlow cf = visitBlockWithoutChildScope(child, b); - scope->inherit(child); + DfgScope* child = makeChildScope(); + + ControlFlow cf; + { + PushScope ps{scopeStack, child}; + cf = visitBlockWithoutChildScope(b); + } + + currentScope()->inherit(child); return cf; } -ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, AstStatBlock* b) +ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(AstStatBlock* b) { std::optional firstControlFlow; for (AstStat* stat : b->body) { - ControlFlow cf = visit(scope, stat); + ControlFlow cf = visit(stat); if (cf != ControlFlow::None && !firstControlFlow) firstControlFlow = cf; } @@ -340,66 +415,75 @@ ControlFlow DataFlowGraphBuilder::visitBlockWithoutChildScope(DfgScope* scope, A return firstControlFlow.value_or(ControlFlow::None); } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStat* s) +ControlFlow DataFlowGraphBuilder::visit(AstStat* s) { if (auto b = s->as()) - return visit(scope, b); + return visit(b); else if (auto i = s->as()) - return visit(scope, i); + return visit(i); else if (auto w = s->as()) - return visit(scope, w); + return visit(w); else if (auto r = s->as()) - return visit(scope, r); + return visit(r); else if (auto b = s->as()) - return visit(scope, b); + return visit(b); else if (auto c = s->as()) - return visit(scope, c); + return visit(c); else if (auto r = s->as()) - return visit(scope, r); + return visit(r); else if (auto e = s->as()) - return visit(scope, e); + return visit(e); else if (auto l = s->as()) - return visit(scope, l); + return visit(l); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto a = s->as()) - return visit(scope, a); + return visit(a); else if (auto c = s->as()) - return visit(scope, c); + return visit(c); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto l = s->as()) - return visit(scope, l); + return visit(l); else if (auto t = s->as()) - return visit(scope, t); + return visit(t); else if (auto f = s->as()) - return visit(scope, f); + return visit(f); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto d = s->as()) - return visit(scope, d); + return visit(d); else if (auto error = s->as()) - return visit(scope, error); + return visit(error); else handle->ice("Unknown AstStat in DataFlowGraphBuilder::visit"); } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) +ControlFlow DataFlowGraphBuilder::visit(AstStatIf* i) { - visitExpr(scope, i->condition); + visitExpr(i->condition); - DfgScope* thenScope = childScope(scope); - DfgScope* elseScope = childScope(scope); + DfgScope* thenScope = makeChildScope(); + DfgScope* elseScope = makeChildScope(); + + ControlFlow thencf; + { + PushScope ps{scopeStack, thenScope}; + thencf = visit(i->thenbody); + } - ControlFlow thencf = visit(thenScope, i->thenbody); ControlFlow elsecf = ControlFlow::None; if (i->elsebody) - elsecf = visit(elseScope, i->elsebody); + { + PushScope ps{scopeStack, elseScope}; + elsecf = visit(i->elsebody); + } + DfgScope* scope = currentScope(); if (thencf != ControlFlow::None && elsecf == ControlFlow::None) join(scope, scope, elseScope); else if (thencf == ControlFlow::None && elsecf != ControlFlow::None) @@ -415,70 +499,78 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatIf* i) return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatWhile* w) +ControlFlow DataFlowGraphBuilder::visit(AstStatWhile* w) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* whileScope = childScope(scope, DfgScope::Loop); - visitExpr(whileScope, w->condition); - visit(whileScope, w->body); + DfgScope* whileScope = makeChildScope(DfgScope::Loop); - scope->inherit(whileScope); + { + PushScope ps{scopeStack, whileScope}; + visitExpr(w->condition); + visit(w->body); + } + + currentScope()->inherit(whileScope); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatRepeat* r) +ControlFlow DataFlowGraphBuilder::visit(AstStatRepeat* r) { // TODO(controlflow): entry point has a back edge from exit point - DfgScope* repeatScope = childScope(scope, DfgScope::Loop); - visitBlockWithoutChildScope(repeatScope, r->body); - visitExpr(repeatScope, r->condition); + DfgScope* repeatScope = makeChildScope(DfgScope::Loop); - scope->inherit(repeatScope); + { + PushScope ps{scopeStack, repeatScope}; + visitBlockWithoutChildScope(r->body); + visitExpr(r->condition); + } + + currentScope()->inherit(repeatScope); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatBreak* b) +ControlFlow DataFlowGraphBuilder::visit(AstStatBreak* b) { return ControlFlow::Breaks; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatContinue* c) +ControlFlow DataFlowGraphBuilder::visit(AstStatContinue* c) { return ControlFlow::Continues; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatReturn* r) +ControlFlow DataFlowGraphBuilder::visit(AstStatReturn* r) { for (AstExpr* e : r->list) - visitExpr(scope, e); + visitExpr(e); return ControlFlow::Returns; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatExpr* e) +ControlFlow DataFlowGraphBuilder::visit(AstStatExpr* e) { - visitExpr(scope, e->expr); + visitExpr(e->expr); if (auto call = e->expr->as(); call && doesCallError(call)) return ControlFlow::Throws; else return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) +ControlFlow DataFlowGraphBuilder::visit(AstStatLocal* l) { // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) std::vector defs; defs.reserve(l->values.size); for (AstExpr* e : l->values) - defs.push_back(visitExpr(scope, e).def); + defs.push_back(visitExpr(e).def); for (size_t i = 0; i < l->vars.size; ++i) { AstLocal* local = l->vars.data[i]; if (local->annotation) - visitType(scope, local->annotation); + visitType(local->annotation); // We need to create a new def to intentionally avoid alias tracking, but we'd like to // make sure that the non-aliased defs are also marked as a subscript for refinements. @@ -493,90 +585,98 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocal* l) } } graph.localDefs[local] = def; - scope->bindings[local] = def; + currentScope()->bindings[local] = def; captures[local].allVersions.push_back(def); } return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFor* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatFor* f) { - DfgScope* forScope = childScope(scope, DfgScope::Loop); + DfgScope* forScope = makeChildScope(DfgScope::Loop); - visitExpr(scope, f->from); - visitExpr(scope, f->to); + visitExpr(f->from); + visitExpr(f->to); if (f->step) - visitExpr(scope, f->step); + visitExpr(f->step); - if (f->var->annotation) - visitType(forScope, f->var->annotation); - - DefId def = defArena->freshCell(); - graph.localDefs[f->var] = def; - scope->bindings[f->var] = def; - captures[f->var].allVersions.push_back(def); - - // TODO(controlflow): entry point has a back edge from exit point - visit(forScope, f->body); - - scope->inherit(forScope); - - return ControlFlow::None; -} - -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatForIn* f) -{ - DfgScope* forScope = childScope(scope, DfgScope::Loop); - - for (AstLocal* local : f->vars) { - if (local->annotation) - visitType(forScope, local->annotation); + PushScope ps{scopeStack, forScope}; + + if (f->var->annotation) + visitType(f->var->annotation); DefId def = defArena->freshCell(); - graph.localDefs[local] = def; - forScope->bindings[local] = def; - captures[local].allVersions.push_back(def); + graph.localDefs[f->var] = def; + currentScope()->bindings[f->var] = def; + captures[f->var].allVersions.push_back(def); + + // TODO(controlflow): entry point has a back edge from exit point + visit(f->body); } - // TODO(controlflow): entry point has a back edge from exit point - // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) - for (AstExpr* e : f->values) - visitExpr(forScope, e); - - visit(forScope, f->body); - - scope->inherit(forScope); + currentScope()->inherit(forScope); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatAssign* a) +ControlFlow DataFlowGraphBuilder::visit(AstStatForIn* f) +{ + DfgScope* forScope = makeChildScope(DfgScope::Loop); + + { + PushScope ps{scopeStack, forScope}; + + for (AstLocal* local : f->vars) + { + if (local->annotation) + visitType(local->annotation); + + DefId def = defArena->freshCell(); + graph.localDefs[local] = def; + currentScope()->bindings[local] = def; + captures[local].allVersions.push_back(def); + } + + // TODO(controlflow): entry point has a back edge from exit point + // We're gonna need a `visitExprList` and `visitVariadicExpr` (function calls and `...`) + for (AstExpr* e : f->values) + visitExpr(e); + + visit(f->body); + } + + currentScope()->inherit(forScope); + + return ControlFlow::None; +} + +ControlFlow DataFlowGraphBuilder::visit(AstStatAssign* a) { std::vector defs; defs.reserve(a->values.size); for (AstExpr* e : a->values) - defs.push_back(visitExpr(scope, e).def); + defs.push_back(visitExpr(e).def); for (size_t i = 0; i < a->vars.size; ++i) { AstExpr* v = a->vars.data[i]; - visitLValue(scope, v, i < defs.size() ? defs[i] : defArena->freshCell()); + visitLValue(v, i < defs.size() ? defs[i] : defArena->freshCell()); } return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatCompoundAssign* c) +ControlFlow DataFlowGraphBuilder::visit(AstStatCompoundAssign* c) { - (void) visitExpr(scope, c->value); - (void) visitExpr(scope, c->var); + (void)visitExpr(c->value); + (void)visitExpr(c->var); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatFunction* f) { // In the old solver, we assumed that the name of the function is always a function in the body // but this isn't true, e.g. the following example will print `5`, not a function address. @@ -588,8 +688,8 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) // // which is evidence that references to variables must be a phi node of all possible definitions, // but for bug compatibility, we'll assume the same thing here. - visitLValue(scope, f->name, defArena->freshCell()); - visitExpr(scope, f->func); + visitLValue(f->name, defArena->freshCell()); + visitExpr(f->func); if (auto local = f->name->as()) { @@ -606,87 +706,97 @@ ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatFunction* f) return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatLocalFunction* l) +ControlFlow DataFlowGraphBuilder::visit(AstStatLocalFunction* l) { DefId def = defArena->freshCell(); graph.localDefs[l->name] = def; - scope->bindings[l->name] = def; + currentScope()->bindings[l->name] = def; captures[l->name].allVersions.push_back(def); - visitExpr(scope, l->func); + visitExpr(l->func); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeAlias* t) +ControlFlow DataFlowGraphBuilder::visit(AstStatTypeAlias* t) { - DfgScope* unreachable = childScope(scope); - visitGenerics(unreachable, t->generics); - visitGenericPacks(unreachable, t->genericPacks); - visitType(unreachable, t->type); + DfgScope* unreachable = makeChildScope(); + PushScope ps{scopeStack, unreachable}; + + visitGenerics(t->generics); + visitGenericPacks(t->genericPacks); + visitType(t->type); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatTypeFunction* f) +ControlFlow DataFlowGraphBuilder::visit(AstStatTypeFunction* f) { - DfgScope* unreachable = childScope(scope); - visitExpr(unreachable, f->body); + DfgScope* unreachable = makeChildScope(); + PushScope ps{scopeStack, unreachable}; + + visitExpr(f->body); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareGlobal* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareGlobal* d) { DefId def = defArena->freshCell(); graph.declaredDefs[d] = def; - scope->bindings[d->name] = def; + currentScope()->bindings[d->name] = def; captures[d->name].allVersions.push_back(def); - visitType(scope, d->type); + visitType(d->type); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareFunction* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareFunction* d) { DefId def = defArena->freshCell(); graph.declaredDefs[d] = def; - scope->bindings[d->name] = def; + currentScope()->bindings[d->name] = def; captures[d->name].allVersions.push_back(def); - DfgScope* unreachable = childScope(scope); - visitGenerics(unreachable, d->generics); - visitGenericPacks(unreachable, d->genericPacks); - visitTypeList(unreachable, d->params); - visitTypeList(unreachable, d->retTypes); + DfgScope* unreachable = makeChildScope(); + PushScope ps{scopeStack, unreachable}; + + visitGenerics(d->generics); + visitGenericPacks(d->genericPacks); + visitTypeList(d->params); + visitTypeList(d->retTypes); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatDeclareClass* d) +ControlFlow DataFlowGraphBuilder::visit(AstStatDeclareClass* d) { // This declaration does not "introduce" any bindings in value namespace, // so there's no symbolic value to begin with. We'll traverse the properties // because their type annotations may depend on something in the value namespace. - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(); + PushScope ps{scopeStack, unreachable}; + for (AstDeclaredClassProp prop : d->props) - visitType(unreachable, prop.ty); + visitType(prop.ty); return ControlFlow::None; } -ControlFlow DataFlowGraphBuilder::visit(DfgScope* scope, AstStatError* error) +ControlFlow DataFlowGraphBuilder::visit(AstStatError* error) { - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(); + PushScope ps{scopeStack, unreachable}; + for (AstStat* s : error->statements) - visit(unreachable, s); + visit(s); for (AstExpr* e : error->expressions) - visitExpr(unreachable, e); + visitExpr(e); return ControlFlow::None; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExpr* e) { // Some subexpressions could be visited two times. If we've already seen it, just extract it. if (auto def = graph.astDefs.find(e)) @@ -698,7 +808,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) auto go = [&]() -> DataFlowResult { if (auto g = e->as()) - return visitExpr(scope, g); + return visitExpr(g); else if (auto c = e->as()) return {defArena->freshCell(), nullptr}; // ok else if (auto c = e->as()) @@ -708,33 +818,33 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) else if (auto c = e->as()) return {defArena->freshCell(), nullptr}; // ok else if (auto l = e->as()) - return visitExpr(scope, l); + return visitExpr(l); else if (auto g = e->as()) - return visitExpr(scope, g); + return visitExpr(g); else if (auto v = e->as()) return {defArena->freshCell(), nullptr}; // ok else if (auto c = e->as()) - return visitExpr(scope, c); + return visitExpr(c); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto f = e->as()) - return visitExpr(scope, f); + return visitExpr(f); else if (auto t = e->as()) - return visitExpr(scope, t); + return visitExpr(t); else if (auto u = e->as()) - return visitExpr(scope, u); + return visitExpr(u); else if (auto b = e->as()) - return visitExpr(scope, b); + return visitExpr(b); else if (auto t = e->as()) - return visitExpr(scope, t); + return visitExpr(t); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto i = e->as()) - return visitExpr(scope, i); + return visitExpr(i); else if (auto error = e->as()) - return visitExpr(scope, error); + return visitExpr(error); else handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitExpr"); }; @@ -746,64 +856,94 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExpr* e) return {def, key}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGroup* group) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGroup* group) { - return visitExpr(scope, group->expr); + return visitExpr(group->expr); } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprLocal* l) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprLocal* l) { - DefId def = lookup(scope, l->local); + DefId def = lookup(l->local); const RefinementKey* key = keyArena->leaf(def); return {def, key}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprGlobal* g) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprGlobal* g) { - DefId def = lookup(scope, g->name); + DefId def = lookup(g->name); return {def, keyArena->leaf(def)}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprCall* c) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprCall* c) { - visitExpr(scope, c->func); + visitExpr(c->func); + + if (shouldTypestateForFirstArgument(*c) && c->args.size > 1 && isLValue(*c->args.begin())) + { + AstExpr* firstArg = *c->args.begin(); + + // this logic has to handle the name-like subset of expressions. + std::optional result; + if (auto l = firstArg->as()) + result = visitExpr(l); + else if (auto g = firstArg->as()) + result = visitExpr(g); + else if (auto i = firstArg->as()) + result = visitExpr(i); + else if (auto i = firstArg->as()) + result = visitExpr(i); + else + LUAU_UNREACHABLE(); // This is unreachable because the whole thing is guarded by `isLValue`. + + LUAU_ASSERT(result); + + DfgScope* child = makeChildScope(); + scopeStack.push_back(child); + + auto [def, key] = *result; + graph.astDefs[firstArg] = def; + if (key) + graph.astRefinementKeys[firstArg] = key; + + visitLValue(firstArg, def); + } for (AstExpr* arg : c->args) - visitExpr(scope, arg); + visitExpr(arg); // calls should be treated as subscripted. return {defArena->freshCell(/* subscripted */ true), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexName* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexName* i) { - auto [parentDef, parentKey] = visitExpr(scope, i->expr); - + auto [parentDef, parentKey] = visitExpr(i->expr); std::string index = i->index.value; - DefId def = lookup(scope, parentDef, index); + DefId def = lookup(parentDef, index); return {def, keyArena->node(parentKey, def, index)}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIndexExpr* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIndexExpr* i) { - auto [parentDef, parentKey] = visitExpr(scope, i->expr); - visitExpr(scope, i->index); + auto [parentDef, parentKey] = visitExpr(i->expr); + visitExpr(i->index); if (auto string = i->index->as()) { std::string index{string->value.data, string->value.size}; - DefId def = lookup(scope, parentDef, index); + DefId def = lookup(parentDef, index); return {def, keyArena->node(parentKey, def, index)}; } return {defArena->freshCell(/* subscripted= */ true), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* f) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprFunction* f) { - DfgScope* signatureScope = childScope(scope, DfgScope::Function); + DfgScope* signatureScope = makeChildScope(DfgScope::Function); + PushScope ps{scopeStack, signatureScope}; if (AstLocal* self = f->self) { @@ -819,7 +959,7 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* for (AstLocal* param : f->args) { if (param->annotation) - visitType(signatureScope, param->annotation); + visitType(param->annotation); DefId def = defArena->freshCell(); graph.localDefs[param] = def; @@ -828,10 +968,10 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* } if (f->varargAnnotation) - visitTypePack(scope, f->varargAnnotation); + visitTypePack(f->varargAnnotation); if (f->returnAnnotation) - visitTypeList(signatureScope, *f->returnAnnotation); + visitTypeList(*f->returnAnnotation); // TODO: function body can be re-entrant, as in mutations that occurs at the end of the function can also be // visible to the beginning of the function, so statically speaking, the body of the function has an exit point @@ -841,92 +981,94 @@ DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprFunction* // local g = f // g() --> function: address // g() --> 5 - visit(signatureScope, f->body); + visit(f->body); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTable* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTable* t) { DefId tableCell = defArena->freshCell(); - scope->props[tableCell] = {}; + currentScope()->props[tableCell] = {}; for (AstExprTable::Item item : t->items) { - DataFlowResult result = visitExpr(scope, item.value); + DataFlowResult result = visitExpr(item.value); if (item.key) { - visitExpr(scope, item.key); + visitExpr(item.key); if (auto string = item.key->as()) - scope->props[tableCell][string->value.data] = result.def; + currentScope()->props[tableCell][string->value.data] = result.def; } } return {tableCell, nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprUnary* u) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprUnary* u) { - visitExpr(scope, u->expr); + visitExpr(u->expr); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprBinary* b) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprBinary* b) { - visitExpr(scope, b->left); - visitExpr(scope, b->right); + visitExpr(b->left); + visitExpr(b->right); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprTypeAssertion* t) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprTypeAssertion* t) { - auto [def, key] = visitExpr(scope, t->expr); - visitType(scope, t->annotation); + auto [def, key] = visitExpr(t->expr); + visitType(t->annotation); return {def, key}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprIfElse* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprIfElse* i) { - visitExpr(scope, i->condition); - visitExpr(scope, i->trueExpr); - visitExpr(scope, i->falseExpr); + visitExpr(i->condition); + visitExpr(i->trueExpr); + visitExpr(i->falseExpr); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprInterpString* i) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprInterpString* i) { for (AstExpr* e : i->expressions) - visitExpr(scope, e); + visitExpr(e); return {defArena->freshCell(), nullptr}; } -DataFlowResult DataFlowGraphBuilder::visitExpr(DfgScope* scope, AstExprError* error) +DataFlowResult DataFlowGraphBuilder::visitExpr(AstExprError* error) { - DfgScope* unreachable = childScope(scope); + DfgScope* unreachable = makeChildScope(); + PushScope ps{scopeStack, unreachable}; + for (AstExpr* e : error->expressions) - visitExpr(unreachable, e); + visitExpr(e); return {defArena->freshCell(), nullptr}; } -void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomingDef) +void DataFlowGraphBuilder::visitLValue(AstExpr* e, DefId incomingDef) { auto go = [&]() { if (auto l = e->as()) - return visitLValue(scope, l, incomingDef); + return visitLValue(l, incomingDef); else if (auto g = e->as()) - return visitLValue(scope, g, incomingDef); + return visitLValue(g, incomingDef); else if (auto i = e->as()) - return visitLValue(scope, i, incomingDef); + return visitLValue(i, incomingDef); else if (auto i = e->as()) - return visitLValue(scope, i, incomingDef); + return visitLValue(i, incomingDef); else if (auto error = e->as()) - return visitLValue(scope, error, incomingDef); + return visitLValue(error, incomingDef); else handle->ice("Unknown AstExpr in DataFlowGraphBuilder::visitLValue"); }; @@ -934,8 +1076,10 @@ void DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExpr* e, DefId incomi graph.astDefs[e] = go(); } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprLocal* l, DefId incomingDef) { + DfgScope* scope = currentScope(); + // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(l->local)) { @@ -945,11 +1089,13 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprLocal* l, DefId return updated; } else - return visitExpr(scope, static_cast(l)).def; + return visitExpr(static_cast(l)).def; } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprGlobal* g, DefId incomingDef) { + DfgScope* scope = currentScope(); + // In order to avoid alias tracking, we need to clip the reference to the parent def. if (scope->canUpdateDefinition(g->name)) { @@ -959,13 +1105,14 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprGlobal* g, DefId return updated; } else - return visitExpr(scope, static_cast(g)).def; + return visitExpr(static_cast(g)).def; } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprIndexName* i, DefId incomingDef) { - DefId parentDef = visitExpr(scope, i->expr).def; + DefId parentDef = visitExpr(i->expr).def; + DfgScope* scope = currentScope(); if (scope->canUpdateDefinition(parentDef, i->index.value)) { DefId updated = defArena->freshCell(containsSubscriptedDefinition(incomingDef)); @@ -973,14 +1120,15 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexName* i, De return updated; } else - return visitExpr(scope, static_cast(i)).def; + return visitExpr(static_cast(i)).def; } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprIndexExpr* i, DefId incomingDef) { - DefId parentDef = visitExpr(scope, i->expr).def; - visitExpr(scope, i->index); + DefId parentDef = visitExpr(i->expr).def; + visitExpr(i->index); + DfgScope* scope = currentScope(); if (auto string = i->index->as()) { if (scope->canUpdateDefinition(parentDef, string->value.data)) @@ -990,141 +1138,143 @@ DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprIndexExpr* i, De return updated; } else - return visitExpr(scope, static_cast(i)).def; + return visitExpr(static_cast(i)).def; } else return defArena->freshCell(/*subscripted=*/true); } -DefId DataFlowGraphBuilder::visitLValue(DfgScope* scope, AstExprError* error, DefId incomingDef) +DefId DataFlowGraphBuilder::visitLValue(AstExprError* error, DefId incomingDef) { - return visitExpr(scope, error).def; + return visitExpr(error).def; } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstType* t) +void DataFlowGraphBuilder::visitType(AstType* t) { if (auto r = t->as()) - return visitType(scope, r); + return visitType(r); else if (auto table = t->as()) - return visitType(scope, table); + return visitType(table); else if (auto f = t->as()) - return visitType(scope, f); + return visitType(f); else if (auto tyof = t->as()) - return visitType(scope, tyof); + return visitType(tyof); else if (auto u = t->as()) - return visitType(scope, u); + return visitType(u); else if (auto i = t->as()) - return visitType(scope, i); + return visitType(i); else if (auto e = t->as()) - return visitType(scope, e); + return visitType(e); else if (auto s = t->as()) return; // ok else if (auto s = t->as()) return; // ok + else if (auto g = t->as()) + return visitType(g->type); else handle->ice("Unknown AstType in DataFlowGraphBuilder::visitType"); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeReference* r) +void DataFlowGraphBuilder::visitType(AstTypeReference* r) { for (AstTypeOrPack param : r->parameters) { if (param.type) - visitType(scope, param.type); + visitType(param.type); else - visitTypePack(scope, param.typePack); + visitTypePack(param.typePack); } } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTable* t) +void DataFlowGraphBuilder::visitType(AstTypeTable* t) { for (AstTableProp p : t->props) - visitType(scope, p.type); + visitType(p.type); if (t->indexer) { - visitType(scope, t->indexer->indexType); - visitType(scope, t->indexer->resultType); + visitType(t->indexer->indexType); + visitType(t->indexer->resultType); } } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeFunction* f) +void DataFlowGraphBuilder::visitType(AstTypeFunction* f) { - visitGenerics(scope, f->generics); - visitGenericPacks(scope, f->genericPacks); - visitTypeList(scope, f->argTypes); - visitTypeList(scope, f->returnTypes); + visitGenerics(f->generics); + visitGenericPacks(f->genericPacks); + visitTypeList(f->argTypes); + visitTypeList(f->returnTypes); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeTypeof* t) +void DataFlowGraphBuilder::visitType(AstTypeTypeof* t) { - visitExpr(scope, t->expr); + visitExpr(t->expr); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeUnion* u) +void DataFlowGraphBuilder::visitType(AstTypeUnion* u) { for (AstType* t : u->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeIntersection* i) +void DataFlowGraphBuilder::visitType(AstTypeIntersection* i) { for (AstType* t : i->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitType(DfgScope* scope, AstTypeError* error) +void DataFlowGraphBuilder::visitType(AstTypeError* error) { for (AstType* t : error->types) - visitType(scope, t); + visitType(t); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePack* p) +void DataFlowGraphBuilder::visitTypePack(AstTypePack* p) { if (auto e = p->as()) - return visitTypePack(scope, e); + return visitTypePack(e); else if (auto v = p->as()) - return visitTypePack(scope, v); + return visitTypePack(v); else if (auto g = p->as()) return; // ok else handle->ice("Unknown AstTypePack in DataFlowGraphBuilder::visitTypePack"); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackExplicit* e) +void DataFlowGraphBuilder::visitTypePack(AstTypePackExplicit* e) { - visitTypeList(scope, e->typeList); + visitTypeList(e->typeList); } -void DataFlowGraphBuilder::visitTypePack(DfgScope* scope, AstTypePackVariadic* v) +void DataFlowGraphBuilder::visitTypePack(AstTypePackVariadic* v) { - visitType(scope, v->variadicType); + visitType(v->variadicType); } -void DataFlowGraphBuilder::visitTypeList(DfgScope* scope, AstTypeList l) +void DataFlowGraphBuilder::visitTypeList(AstTypeList l) { for (AstType* t : l.types) - visitType(scope, t); + visitType(t); if (l.tailType) - visitTypePack(scope, l.tailType); + visitTypePack(l.tailType); } -void DataFlowGraphBuilder::visitGenerics(DfgScope* scope, AstArray g) +void DataFlowGraphBuilder::visitGenerics(AstArray g) { - for (AstGenericType generic : g) + for (AstGenericType* generic : g) { - if (generic.defaultValue) - visitType(scope, generic.defaultValue); + if (generic->defaultValue) + visitType(generic->defaultValue); } } -void DataFlowGraphBuilder::visitGenericPacks(DfgScope* scope, AstArray g) +void DataFlowGraphBuilder::visitGenericPacks(AstArray g) { - for (AstGenericTypePack generic : g) + for (AstGenericTypePack* generic : g) { - if (generic.defaultValue) - visitTypePack(scope, generic.defaultValue); + if (generic->defaultValue) + visitTypePack(generic->defaultValue); } } diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 25687e11..e6222067 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -718,7 +718,7 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig env.popVisiting(); return diffRes; } - if (auto le = get(left)) + if (auto le = get(left)) { // TODO: return debug-friendly result state env.popVisiting(); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index e539661a..ff2f02c0 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,102 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAGVARIABLE(LuauDebugInfoDefn) + namespace Luau { -static const std::string kBuiltinDefinitionLuaSrcChecked = R"BUILTIN_SRC( - -declare bit32: { - band: @checked (...number) -> number, - bor: @checked (...number) -> number, - bxor: @checked (...number) -> number, - btest: @checked (number, ...number) -> boolean, - rrotate: @checked (x: number, disp: number) -> number, - lrotate: @checked (x: number, disp: number) -> number, - lshift: @checked (x: number, disp: number) -> number, - arshift: @checked (x: number, disp: number) -> number, - rshift: @checked (x: number, disp: number) -> number, - bnot: @checked (x: number) -> number, - extract: @checked (n: number, field: number, width: number?) -> number, - replace: @checked (n: number, v: number, field: number, width: number?) -> number, - countlz: @checked (n: number) -> number, - countrz: @checked (n: number) -> number, - byteswap: @checked (n: number) -> number, -} - -declare math: { - frexp: @checked (n: number) -> (number, number), - ldexp: @checked (s: number, e: number) -> number, - fmod: @checked (x: number, y: number) -> number, - modf: @checked (n: number) -> (number, number), - pow: @checked (x: number, y: number) -> number, - exp: @checked (n: number) -> number, - - ceil: @checked (n: number) -> number, - floor: @checked (n: number) -> number, - abs: @checked (n: number) -> number, - sqrt: @checked (n: number) -> number, - - log: @checked (n: number, base: number?) -> number, - log10: @checked (n: number) -> number, - - rad: @checked (n: number) -> number, - deg: @checked (n: number) -> number, - - sin: @checked (n: number) -> number, - cos: @checked (n: number) -> number, - tan: @checked (n: number) -> number, - sinh: @checked (n: number) -> number, - cosh: @checked (n: number) -> number, - tanh: @checked (n: number) -> number, - atan: @checked (n: number) -> number, - acos: @checked (n: number) -> number, - asin: @checked (n: number) -> number, - atan2: @checked (y: number, x: number) -> number, - - min: @checked (number, ...number) -> number, - max: @checked (number, ...number) -> number, - - pi: number, - huge: number, - - randomseed: @checked (seed: number) -> (), - random: @checked (number?, number?) -> number, - - sign: @checked (n: number) -> number, - clamp: @checked (n: number, min: number, max: number) -> number, - noise: @checked (x: number, y: number?, z: number?) -> number, - round: @checked (n: number) -> number, -} - -type DateTypeArg = { - year: number, - month: number, - day: number, - hour: number?, - min: number?, - sec: number?, - isdst: boolean?, -} - -type DateTypeResult = { - year: number, - month: number, - wday: number, - yday: number, - day: number, - hour: number, - min: number, - sec: number, - isdst: boolean, -} - -declare os: { - time: (time: DateTypeArg?) -> number, - date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), - difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, - clock: () -> number, -} +static const std::string kBuiltinDefinitionBaseSrc = R"BUILTIN_SRC( @checked declare function require(target: any): any @@ -144,6 +54,119 @@ declare function loadstring(src: string, chunkname: string?): (((A...) -> @checked declare function newproxy(mt: boolean?): any +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionBit32Src = R"BUILTIN_SRC( + +declare bit32: { + band: @checked (...number) -> number, + bor: @checked (...number) -> number, + bxor: @checked (...number) -> number, + btest: @checked (number, ...number) -> boolean, + rrotate: @checked (x: number, disp: number) -> number, + lrotate: @checked (x: number, disp: number) -> number, + lshift: @checked (x: number, disp: number) -> number, + arshift: @checked (x: number, disp: number) -> number, + rshift: @checked (x: number, disp: number) -> number, + bnot: @checked (x: number) -> number, + extract: @checked (n: number, field: number, width: number?) -> number, + replace: @checked (n: number, v: number, field: number, width: number?) -> number, + countlz: @checked (n: number) -> number, + countrz: @checked (n: number) -> number, + byteswap: @checked (n: number) -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionMathSrc = R"BUILTIN_SRC( + +declare math: { + frexp: @checked (n: number) -> (number, number), + ldexp: @checked (s: number, e: number) -> number, + fmod: @checked (x: number, y: number) -> number, + modf: @checked (n: number) -> (number, number), + pow: @checked (x: number, y: number) -> number, + exp: @checked (n: number) -> number, + + ceil: @checked (n: number) -> number, + floor: @checked (n: number) -> number, + abs: @checked (n: number) -> number, + sqrt: @checked (n: number) -> number, + + log: @checked (n: number, base: number?) -> number, + log10: @checked (n: number) -> number, + + rad: @checked (n: number) -> number, + deg: @checked (n: number) -> number, + + sin: @checked (n: number) -> number, + cos: @checked (n: number) -> number, + tan: @checked (n: number) -> number, + sinh: @checked (n: number) -> number, + cosh: @checked (n: number) -> number, + tanh: @checked (n: number) -> number, + atan: @checked (n: number) -> number, + acos: @checked (n: number) -> number, + asin: @checked (n: number) -> number, + atan2: @checked (y: number, x: number) -> number, + + min: @checked (number, ...number) -> number, + max: @checked (number, ...number) -> number, + + pi: number, + huge: number, + + randomseed: @checked (seed: number) -> (), + random: @checked (number?, number?) -> number, + + sign: @checked (n: number) -> number, + clamp: @checked (n: number, min: number, max: number) -> number, + noise: @checked (x: number, y: number?, z: number?) -> number, + round: @checked (n: number) -> number, + map: @checked (x: number, inmin: number, inmax: number, outmin: number, outmax: number) -> number, + lerp: @checked (a: number, b: number, t: number) -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionOsSrc = R"BUILTIN_SRC( + +type DateTypeArg = { + year: number, + month: number, + day: number, + hour: number?, + min: number?, + sec: number?, + isdst: boolean?, +} + +type DateTypeResult = { + year: number, + month: number, + wday: number, + yday: number, + day: number, + hour: number, + min: number, + sec: number, + isdst: boolean, +} + +declare os: { + time: (time: DateTypeArg?) -> number, + date: ((formatString: "*t" | "!*t", time: number?) -> DateTypeResult) & ((formatString: string?, time: number?) -> string), + difftime: (t2: DateTypeResult | number, t1: DateTypeResult | number) -> number, + clock: () -> number, +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionCoroutineSrc = R"BUILTIN_SRC( + declare coroutine: { create: (f: (A...) -> R...) -> thread, resume: (co: thread, A...) -> (boolean, R...), @@ -155,6 +178,10 @@ declare coroutine: { close: @checked (co: thread) -> (boolean, any) } +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionTableSrc = R"BUILTIN_SRC( + declare table: { concat: (t: {V}, sep: string?, i: number?, j: number?) -> string, insert: ((t: {V}, value: V) -> ()) & ((t: {V}, pos: number, value: V) -> ()), @@ -177,11 +204,28 @@ declare table: { isfrozen: (t: {[K]: V}) -> boolean, } +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionDebugSrc = R"BUILTIN_SRC( + +declare debug: { + info: ((thread: thread, level: number, options: string) -> ...any) & ((level: number, options: string) -> ...any) & ((func: (A...) -> R1..., options: string) -> ...any), + traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionDebugSrc_DEPRECATED = R"BUILTIN_SRC( + declare debug: { info: ((thread: thread, level: number, options: string) -> R...) & ((level: number, options: string) -> R...) & ((func: (A...) -> R1..., options: string) -> R2...), traceback: ((message: string?, level: number?) -> string) & ((thread: thread, message: string?, level: number?) -> string), } +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionUtf8Src = R"BUILTIN_SRC( + declare utf8: { char: @checked (...number) -> string, charpattern: string, @@ -191,10 +235,9 @@ declare utf8: { offset: @checked (s: string, n: number?, i: number?) -> number, } --- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. -declare function unpack(tab: {V}, i: number?, j: number?): ...V - +)BUILTIN_SRC"; +static const std::string kBuiltinDefinitionBufferSrc = R"BUILTIN_SRC( --- Buffer API declare buffer: { create: @checked (size: number) -> buffer, @@ -221,13 +264,56 @@ declare buffer: { writef64: @checked (b: buffer, offset: number, value: number) -> (), readstring: @checked (b: buffer, offset: number, count: number) -> string, writestring: @checked (b: buffer, offset: number, value: string, count: number?) -> (), + readbits: @checked (b: buffer, bitOffset: number, bitCount: number) -> number, + writebits: @checked (b: buffer, bitOffset: number, bitCount: number, value: number) -> (), +} + +)BUILTIN_SRC"; + +static const std::string kBuiltinDefinitionVectorSrc = R"BUILTIN_SRC( + +-- While vector would have been better represented as a built-in primitive type, type solver class handling covers most of the properties +declare class vector + x: number + y: number + z: number +end + +declare vector: { + create: @checked (x: number, y: number, z: number?) -> vector, + magnitude: @checked (vec: vector) -> number, + normalize: @checked (vec: vector) -> vector, + cross: @checked (vec1: vector, vec2: vector) -> vector, + dot: @checked (vec1: vector, vec2: vector) -> number, + angle: @checked (vec1: vector, vec2: vector, axis: vector?) -> number, + floor: @checked (vec: vector) -> vector, + ceil: @checked (vec: vector) -> vector, + abs: @checked (vec: vector) -> vector, + sign: @checked (vec: vector) -> vector, + clamp: @checked (vec: vector, min: vector, max: vector) -> vector, + max: @checked (vector, ...vector) -> vector, + min: @checked (vector, ...vector) -> vector, + + zero: vector, + one: vector, } )BUILTIN_SRC"; std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrcChecked; + std::string result = kBuiltinDefinitionBaseSrc; + + result += kBuiltinDefinitionBit32Src; + result += kBuiltinDefinitionMathSrc; + result += kBuiltinDefinitionOsSrc; + result += kBuiltinDefinitionCoroutineSrc; + result += kBuiltinDefinitionTableSrc; + result += FFlag::LuauDebugInfoDefn ? kBuiltinDefinitionDebugSrc : kBuiltinDefinitionDebugSrc_DEPRECATED; + result += kBuiltinDefinitionUtf8Src; + result += kBuiltinDefinitionBufferSrc; + result += kBuiltinDefinitionVectorSrc; + return result; } diff --git a/Analysis/src/EqSatSimplification.cpp b/Analysis/src/EqSatSimplification.cpp new file mode 100644 index 00000000..edcc42fb --- /dev/null +++ b/Analysis/src/EqSatSimplification.cpp @@ -0,0 +1,2698 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/EqSatSimplification.h" +#include "Luau/EqSatSimplificationImpl.h" + +#include "Luau/EGraph.h" +#include "Luau/Id.h" +#include "Luau/Language.h" + +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFunction.h" +#include "Luau/VisitType.h" + +#include +#include +#include +#include +#include +#include +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplification) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSimplificationToDot) +LUAU_FASTFLAGVARIABLE(DebugLuauExtraEqSatSanityChecks) + +namespace Luau::EqSatSimplification +{ +using Id = Luau::EqSat::Id; + +using EGraph = Luau::EqSat::EGraph; +using Luau::EqSat::Slice; + +TTable::TTable(Id basis) +{ + storage.push_back(basis); +} + +// I suspect that this is going to become a performance hotspot. It would be +// nice to avoid allocating propTypes_ +TTable::TTable(Id basis, std::vector propNames_, std::vector propTypes_) + : propNames(std::move(propNames_)) +{ + storage.reserve(propTypes_.size() + 1); + storage.push_back(basis); + storage.insert(storage.end(), propTypes_.begin(), propTypes_.end()); + + LUAU_ASSERT(storage.size() == 1 + propTypes_.size()); +} + +Id TTable::getBasis() const +{ + LUAU_ASSERT(!storage.empty()); + return storage[0]; +} + +Slice TTable::propTypes() const +{ + LUAU_ASSERT(propNames.size() + 1 == storage.size()); + + return Slice{storage.data() + 1, propNames.size()}; +} + +Slice TTable::mutableOperands() +{ + return Slice{storage.data(), storage.size()}; +} + +Slice TTable::operands() const +{ + return Slice{storage.data(), storage.size()}; +} + +bool TTable::operator==(const TTable& rhs) const +{ + return storage == rhs.storage && propNames == rhs.propNames; +} + +size_t TTable::Hash::operator()(const TTable& value) const +{ + size_t hash = 0; + + // We're using pointers here, which does mean platform divergence. I think + // it's okay? (famous last words, I know) + for (StringId s : value.propNames) + EqSat::hashCombine(hash, EqSat::languageHash(s)); + + EqSat::hashCombine(hash, EqSat::languageHash(value.storage)); + + return hash; +} + +StringId StringCache::add(std::string_view s) +{ + /* Important subtlety: This use of DenseHashMap + * is okay because std::hash works solely on the bytes + * referred by the string_view. + * + * In other words, two string views which contain the same bytes will have + * the same hash whether or not their addresses are the same. + */ + if (StringId* it = strings.find(s)) + return *it; + + char* storage = static_cast(allocator.allocate(s.size())); + memcpy(storage, s.data(), s.size()); + + StringId result = StringId(views.size()); + views.emplace_back(storage, s.size()); + strings[s] = result; + return result; +} + +std::string_view StringCache::asStringView(StringId id) const +{ + LUAU_ASSERT(id < views.size()); + return views[id]; +} + +std::string StringCache::asString(StringId id) const +{ + return std::string{asStringView(id)}; +} + +template +Simplify::Data Simplify::make(const EGraph&, const T&) const +{ + return true; +} + +void Simplify::join(Data& left, const Data& right) const +{ + left = left || right; +} + +using EClass = Luau::EqSat::EClass; + +// A terminal type is a type that does not contain any other types. +// Examples: any, unknown, number, string, boolean, nil, table, class, thread, function +// +// All class types are also terminal. +static bool isTerminal(const EType& node) +{ + return node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get() || node.get() || node.get() || node.get() || node.get() || + node.get() || node.get(); +} + +static bool areTerminalAndDefinitelyDisjoint(const EType& lhs, const EType& rhs) +{ + // If either node is non-terminal, then we early exit: we're not going to + // do a state space search for whether something like: + // (A | B | C | D) & (E | F | G | H) + // ... is a disjoint intersection. + if (!isTerminal(lhs) || !isTerminal(rhs)) + return false; + + // Special case some types that aren't strict, disjoint subsets. + if (lhs.get() || lhs.get()) + return !(rhs.get() || rhs.get()); + + // Handling strings / booleans: these are the types for which we + // expect something like: + // + // "foo" & ~"bar" + // + // ... to simplify to "foo". + if (lhs.get()) + return !(rhs.get() || rhs.get()); + + if (lhs.get()) + return !(rhs.get() || rhs.get()); + + if (auto lhsSString = lhs.get()) + { + auto rhsSString = rhs.get(); + if (!rhsSString) + return !rhs.get(); + return lhsSString->value() != rhsSString->value(); + } + + if (auto lhsSBoolean = lhs.get()) + { + auto rhsSBoolean = rhs.get(); + if (!rhsSBoolean) + return !rhs.get(); + return lhsSBoolean->value() != rhsSBoolean->value(); + } + + // At this point: + // - We know both nodes are terminal + // - We know that the LHS is not any boolean, string, or class + // At this point, we have two classes of checks left: + // - Whether the two enodes are exactly the same set (now that the static + // sets have been covered). + // - Whether one of the enodes is a large semantic set such as TAny, + // TUnknown, or TError. + return !( + lhs.index() == rhs.index() || lhs.get() || rhs.get() || lhs.get() || rhs.get() || lhs.get() || + rhs.get() || lhs.get() || rhs.get() || lhs.get() || rhs.get() + ); +} + +static bool isTerminal(const EGraph& egraph, Id eclass) +{ + const auto& nodes = egraph[eclass].nodes; + return std::any_of( + nodes.begin(), + nodes.end(), + [](auto& a) + { + return isTerminal(a.node); + } + ); +} + +Id mkUnion(EGraph& egraph, std::vector parts) +{ + if (parts.size() == 0) + return egraph.add(TNever{}); + else if (parts.size() == 1) + return parts[0]; + else + return egraph.add(Union{std::move(parts)}); +} + +Id mkIntersection(EGraph& egraph, std::vector parts) +{ + if (parts.size() == 0) + return egraph.add(TUnknown{}); + else if (parts.size() == 1) + return parts[0]; + else + return egraph.add(Intersection{std::move(parts)}); +} + +struct ListRemover +{ + std::unordered_map>& mappings2; + TypeId ty; + + ~ListRemover() + { + mappings2.erase(ty); + } +}; + +/* + * Crucial subtlety: It is very extremely important that enodes and eclasses are + * immutable. Mutating an enode would mean that it is no longer equivalent to + * other nodes in the same eclass. + * + * At the same time, many TypeIds are NOT immutable! + * + * The thing that makes this navigable is that it is okay if the same TypeId is + * imported as a different Id at different times as type inference runs. For + * example, if we at one point import a BlockedType as a TOpaque, and later + * import that same TypeId as some other enode type, this is all completely + * okay. + * + * The main thing we have to be very cautious about, I think, is unsealed + * tables. Unsealed table types have properties imperatively inserted into them + * as type inference runs. If we were to encode that TypeId as part of an + * enode, we could run into a situation where the egraph makes incorrect + * assumptions about the table. + * + * The solution is pretty simple: Never use the contents of a mutable TypeId in + * any reduction rule. TOpaque is always okay because we never actually poke + * around inside the TypeId to do anything. + */ +Id toId( + EGraph& egraph, + NotNull builtinTypes, + std::unordered_map& mappingIdToClass, + std::unordered_map>& typeToMappingId, // (TypeId: (MappingId, count)) + std::unordered_set& boundNodes, + StringCache& strings, + TypeId ty +) +{ + ty = follow(ty); + + // First, handle types which do not contain other types. They obviously + // cannot participate in cycles, so we don't have to check for that. + + if (auto freeTy = get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (auto prim = get(ty)) + { + switch (prim->type) + { + case Luau::PrimitiveType::NilType: + return egraph.add(TNil{}); + case Luau::PrimitiveType::Boolean: + return egraph.add(TBoolean{}); + case Luau::PrimitiveType::Number: + return egraph.add(TNumber{}); + case Luau::PrimitiveType::String: + return egraph.add(TString{}); + case Luau::PrimitiveType::Thread: + return egraph.add(TThread{}); + case Luau::PrimitiveType::Function: + return egraph.add(TTopFunction{}); + case Luau::PrimitiveType::Table: + return egraph.add(TTopTable{}); + case Luau::PrimitiveType::Buffer: + return egraph.add(TBuffer{}); + default: + LUAU_ASSERT(!"Unimplemented"); + return egraph.add(Invalid{}); + } + } + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + return egraph.add(SBoolean{bs->value}); + else if (auto ss = get(s)) + return egraph.add(SString{strings.add(ss->value)}); + else + LUAU_ASSERT(!"Unexpected"); + } + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (get(ty)) + return egraph.add(TFunction{ty}); + else if (ty == builtinTypes->classType) + return egraph.add(TTopClass{}); + else if (get(ty)) + return egraph.add(TClass{ty}); + else if (get(ty)) + return egraph.add(TAny{}); + else if (get(ty)) + return egraph.add(TError{}); + else if (get(ty)) + return egraph.add(TUnknown{}); + else if (get(ty)) + return egraph.add(TNever{}); + + // Now handle composite types. + + if (auto it = typeToMappingId.find(ty); it != typeToMappingId.end()) + { + auto& [mappingId, count] = it->second; + ++count; + Id res = egraph.add(TBound{mappingId}); + boundNodes.insert(res); + return res; + } + + typeToMappingId.emplace(ty, std::pair{mappingIdToClass.size(), 0}); + ListRemover lr{typeToMappingId, ty}; + + auto cache = [&](Id res) + { + const auto& [mappingId, count] = typeToMappingId.at(ty); + if (count > 0) + mappingIdToClass.emplace(mappingId, res); + return res; + }; + + if (auto tt = get(ty)) + return egraph.add(TImportedTable{ty}); + else if (get(ty)) + return egraph.add(TOpaque{ty}); + else if (auto ut = get(ty)) + { + std::vector parts; + for (TypeId part : ut) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + return cache(mkUnion(egraph, std::move(parts))); + } + else if (auto it = get(ty)) + { + std::vector parts; + for (TypeId part : it) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + LUAU_ASSERT(parts.size() > 1); + + return cache(mkIntersection(egraph, std::move(parts))); + } + else if (auto negation = get(ty)) + { + Id part = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, negation->ty); + return cache(egraph.add(Negation{std::array{part}})); + } + else if (auto tfun = get(ty)) + { + LUAU_ASSERT(tfun->packArguments.empty()); + + if (tfun->userFuncName) + { + // TODO: User defined type functions are pseudo-effectful: error + // reporting is done via the `print` statement, so running a + // UDTF multiple times may end up double erroring. egraphs + // currently may induce type functions to be reduced multiple + // times. We should probably opt _not_ to process user defined + // type functions at all. + return egraph.add(TOpaque{ty}); + } + + std::vector parts; + parts.reserve(tfun->typeArguments.size()); + for (TypeId part : tfun->typeArguments) + parts.push_back(toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, part)); + + // This looks sily, but we're making a copy of the specific + // `TypeFunctionInstanceType` outside of the provided arena so that + // we can access the members without fear of the specific TFIT being + // overwritten with a bound type. + return cache(egraph.add(TTypeFun{ + std::make_shared( + tfun->function, tfun->typeArguments, tfun->packArguments, tfun->userFuncName, tfun->userFuncData + ), + std::move(parts) + })); + } + else if (get(ty)) + return egraph.add(TNoRefine{}); + else + { + LUAU_ASSERT(!"Unhandled Type"); + return cache(egraph.add(Invalid{})); + } +} + +Id toId(EGraph& egraph, NotNull builtinTypes, std::unordered_map& mappingIdToClass, StringCache& strings, TypeId ty) +{ + std::unordered_map> typeToMappingId; + std::unordered_set boundNodes; + Id id = toId(egraph, builtinTypes, mappingIdToClass, typeToMappingId, boundNodes, strings, ty); + + for (Id id : boundNodes) + { + for (const auto [tb, _index] : Query(&egraph, id)) + { + Id bindee = mappingIdToClass.at(tb->value()); + egraph.merge(id, bindee); + } + } + + egraph.rebuild(); + + return egraph.find(id); +} + +// We apply a penalty to cyclic types to guide the system away from them where +// possible. +static const int CYCLE_PENALTY = 5000; + +// Composite types have cost equal to the sum of the costs of their parts plus a +// constant factor. +static const int SET_TYPE_PENALTY = 1; +static const int TABLE_TYPE_PENALTY = 2; +static const int NEGATION_PENALTY = 2; +static const int TFUN_PENALTY = 2; + +// FIXME. We don't have an accurate way to score a TImportedTable table against +// a TTable. +static const int IMPORTED_TABLE_PENALTY = 50; + +// TBound shouldn't ever be selected as the best node of a class unless we are +// debugging eqsat itself and need to stringify eclasses. We thus penalize it +// so heavily that we'll use any other alternative. +static const int BOUND_PENALTY = 999999999; + +// TODO iteration count limit +// TODO also: accept an argument which is the maximum cost to consider before +// abandoning the count. +// TODO: the egraph should be the first parameter. +static size_t computeCost(std::unordered_map& bestNodes, const EGraph& egraph, std::unordered_map& costs, Id id) +{ + if (auto it = costs.find(id); it != costs.end()) + return it->second; + + const std::vector>& nodes = egraph[id].nodes; + + size_t minCost = std::numeric_limits::max(); + size_t bestNode = std::numeric_limits::max(); + + const auto updateCost = [&](size_t cost, size_t node) + { + if (cost < minCost) + { + minCost = cost; + bestNode = node; + } + }; + + // First, quickly scan for a terminal type. If we can find one, it is obviously the best. + for (size_t index = 0; index < nodes.size(); ++index) + { + if (isTerminal(nodes[index].node)) + { + minCost = 1; + bestNode = index; + + costs[id] = 1; + const auto [iter, isFresh] = bestNodes.insert({id, index}); + + // If we are forcing the cost function to select a specific node, + // then we still need to traverse into that node, even if this + // particular node is the obvious choice under normal circumstances. + if (isFresh || iter->second == index) + return 1; + } + } + + // If we recur into this type before this call frame completes, it is + // because this type participates in a cycle. + costs[id] = CYCLE_PENALTY; + + auto computeChildren = [&](Slice parts, size_t maxCost) -> std::optional + { + size_t cost = 0; + for (Id part : parts) + { + cost += computeCost(bestNodes, egraph, costs, part); + + // Abandon this node if it is too costly + if (cost > maxCost) + return std::nullopt; + } + return cost; + }; + + size_t startIndex = 0; + size_t endIndex = nodes.size(); + + // FFlag::DebugLuauLogSimplification will sometimes stringify an Id and pass + // in a prepopulated bestNodes map. If that mapping already has an index + // for this Id, don't look at the other nodes of this class. + if (auto it = bestNodes.find(id); it != bestNodes.end()) + { + LUAU_ASSERT(it->second < nodes.size()); + + startIndex = it->second; + endIndex = startIndex + 1; + } + + for (size_t index = startIndex; index < endIndex; ++index) + { + const auto& node = nodes[index]; + + if (node.node.get()) + updateCost(BOUND_PENALTY, index); // TODO: This could probably be an assert now that we don't need rewrite rules to handle TBound. + else if (node.node.get()) + { + minCost = 1; + bestNode = index; + } + else if (auto tbl = node.node.get()) + { + // TODO: We could make the penalty a parameter to computeChildren. + std::optional maybeCost = computeChildren(tbl->operands(), minCost); + if (maybeCost) + updateCost(TABLE_TYPE_PENALTY + *maybeCost, index); + } + else if (node.node.get()) + { + minCost = IMPORTED_TABLE_PENALTY; + bestNode = index; + } + else if (auto u = node.node.get()) + { + std::optional maybeCost = computeChildren(u->operands(), minCost); + if (maybeCost) + updateCost(SET_TYPE_PENALTY + *maybeCost, index); + } + else if (auto i = node.node.get()) + { + std::optional maybeCost = computeChildren(i->operands(), minCost); + if (maybeCost) + updateCost(SET_TYPE_PENALTY + *maybeCost, index); + } + else if (auto negation = node.node.get()) + { + std::optional maybeCost = computeChildren(negation->operands(), minCost); + if (maybeCost) + updateCost(NEGATION_PENALTY + *maybeCost, index); + } + else if (auto tfun = node.node.get()) + { + std::optional maybeCost = computeChildren(tfun->operands(), minCost); + if (maybeCost) + updateCost(TFUN_PENALTY + *maybeCost, index); + } + } + + LUAU_ASSERT(bestNode < nodes.size()); + + costs[id] = minCost; + bestNodes.insert({id, bestNode}); + return minCost; +} + +static std::unordered_map computeBestResult(const EGraph& egraph, Id id, const std::unordered_map& forceNodes) +{ + std::unordered_map costs; + std::unordered_map bestNodes = forceNodes; + computeCost(bestNodes, egraph, costs, id); + return bestNodes; +} + +static std::unordered_map computeBestResult(const EGraph& egraph, Id id) +{ + std::unordered_map costs; + std::unordered_map bestNodes; + computeCost(bestNodes, egraph, costs, id); + return bestNodes; +} + +TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +); + +TypeId flattenTableNode( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +) +{ + std::vector stack; + std::unordered_set seenIds; + + Id id = rootId; + const TImportedTable* importedTable = nullptr; + while (true) + { + size_t index = bestNodes.at(id); + const auto& eclass = egraph[id]; + + const auto [_iter, isFresh] = seenIds.insert(id); + if (!isFresh) + { + // If a TTable is its own basis, it must be the case that some other + // node on this eclass is a TImportedTable. Let's use that. + + bool found = false; + + for (size_t i = 0; i < eclass.nodes.size(); ++i) + { + if (eclass.nodes[i].node.get()) + { + found = true; + index = i; + break; + } + } + + if (!found) + { + // If we couldn't find one, we don't know what to do. Use ErrorType. + LUAU_ASSERT(0); + return builtinTypes->errorType; + } + } + + const auto& node = eclass.nodes[index]; + if (const TTable* ttable = node.node.get()) + { + stack.push_back(ttable); + id = ttable->getBasis(); + continue; + } + else if (const TImportedTable* ti = node.node.get()) + { + importedTable = ti; + break; + } + else + LUAU_ASSERT(0); + } + + TableType resultTable; + if (importedTable) + { + const TableType* t = Luau::get(importedTable->value()); + LUAU_ASSERT(t); + resultTable = *t; // Intentional shallow clone here + } + + while (!stack.empty()) + { + const TTable* t = stack.back(); + stack.pop_back(); + + for (size_t i = 0; i < t->propNames.size(); ++i) + { + StringId propName = t->propNames[i]; + const Id propType = t->propTypes()[i]; + + resultTable.props[strings.asString(propName)] = + Property{fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, propType)}; + } + } + + return arena->addType(std::move(resultTable)); +} + +TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& bestNodes, + std::unordered_map& seen, + std::vector& newTypeFunctions, + Id rootId +) +{ + if (auto it = seen.find(rootId); it != seen.end()) + return it->second; + + size_t index = bestNodes.at(rootId); + LUAU_ASSERT(index <= egraph[rootId].nodes.size()); + + const EType& node = egraph[rootId].nodes[index].node; + + if (node.get()) + return builtinTypes->nilType; + else if (node.get()) + return builtinTypes->booleanType; + else if (node.get()) + return builtinTypes->numberType; + else if (node.get()) + return builtinTypes->stringType; + else if (node.get()) + return builtinTypes->threadType; + else if (node.get()) + return builtinTypes->functionType; + else if (node.get()) + return builtinTypes->tableType; + else if (node.get()) + return builtinTypes->classType; + else if (node.get()) + return builtinTypes->bufferType; + else if (auto opaque = node.get()) + return opaque->value(); + else if (auto b = node.get()) + return b->value() ? builtinTypes->trueType : builtinTypes->falseType; + else if (auto s = node.get()) + return arena->addType(SingletonType{StringSingleton{strings.asString(s->value())}}); + else if (auto fun = node.get()) + return fun->value(); + else if (auto tbl = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + TypeId flattened = flattenTableNode(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); + + asMutable(res)->ty.emplace(flattened); + return flattened; + } + else if (auto tbl = node.get()) + return tbl->value(); + else if (auto cls = node.get()) + return cls->value(); + else if (node.get()) + return builtinTypes->anyType; + else if (node.get()) + return builtinTypes->errorType; + else if (node.get()) + return builtinTypes->unknownType; + else if (node.get()) + return builtinTypes->neverType; + else if (auto u = node.get()) + { + Slice parts = u->operands(); + + if (parts.empty()) + return builtinTypes->neverType; + else if (parts.size() == 1) + { + TypeId placeholder = arena->addType(BlockedType{}); + seen[rootId] = placeholder; + auto result = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + if (follow(result) == placeholder) + { + emplaceType(asMutable(placeholder), "EGRAPH-SINGLETON-CYCLE"); + } + else + { + emplaceType(asMutable(placeholder), result); + } + return result; + } + else + { + TypeId res = arena->addType(BlockedType{}); + + seen[rootId] = res; + + std::vector partTypes; + partTypes.reserve(parts.size()); + + for (Id part : parts) + partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(std::move(partTypes)); + + return res; + } + } + else if (auto i = node.get()) + { + Slice parts = i->operands(); + + if (parts.empty()) + return builtinTypes->neverType; + else if (parts.size() == 1) + { + LUAU_ASSERT(parts[0] != rootId); + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, parts[0]); + } + else + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + std::vector partTypes; + partTypes.reserve(parts.size()); + + for (Id part : parts) + partTypes.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + asMutable(res)->ty.emplace(std::move(partTypes)); + + return res; + } + } + else if (auto negation = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + TypeId ty = fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, negation->operands()[0]); + + asMutable(res)->ty.emplace(ty); + + return res; + } + else if (auto tfun = node.get()) + { + TypeId res = arena->addType(BlockedType{}); + seen[rootId] = res; + + std::vector args; + for (Id part : tfun->operands()) + args.push_back(fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, part)); + + auto oldInstance = tfun->value(); + + asMutable(res)->ty.emplace( + oldInstance->function, std::move(args), std::vector(), oldInstance->userFuncName, oldInstance->userFuncData + ); + + newTypeFunctions.push_back(res); + + return res; + } + else if (node.get()) + return builtinTypes->errorType; + else if (node.get()) + return builtinTypes->noRefineType; + else + { + LUAU_ASSERT(!"Unimplemented"); + return nullptr; + } +} + +static TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + const std::unordered_map& forceNodes, + std::vector& newTypeFunctions, + Id rootId +) +{ + const std::unordered_map bestNodes = computeBestResult(egraph, rootId, forceNodes); + std::unordered_map seen; + + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); +} + +static TypeId fromId( + EGraph& egraph, + const StringCache& strings, + NotNull builtinTypes, + NotNull arena, + std::vector& newTypeFunctions, + Id rootId +) +{ + const std::unordered_map bestNodes = computeBestResult(egraph, rootId); + std::unordered_map seen; + + return fromId(egraph, strings, builtinTypes, arena, bestNodes, seen, newTypeFunctions, rootId); +} + +Subst::Subst(Id eclass, Id newClass, std::string desc) + : eclass(std::move(eclass)) + , newClass(std::move(newClass)) + , desc(std::move(desc)) +{ +} + +std::string mkDesc( + EGraph& egraph, + const StringCache& strings, + NotNull arena, + NotNull builtinTypes, + Id from, + Id to, + const std::unordered_map& forceNodes, + const std::string& rule +) +{ + if (!FFlag::DebugLuauLogSimplification) + return ""; + + std::vector newTypeFunctions; + + TypeId fromTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, from); + TypeId toTy = fromId(egraph, strings, builtinTypes, arena, forceNodes, newTypeFunctions, to); + + ToStringOptions opts; + opts.useQuestionMarks = false; + + const int RULE_PADDING = 35; + const std::string rulePadding(std::max(0, RULE_PADDING - rule.size()), ' '); + const std::string fromIdStr = ""; // "(" + std::to_string(uint32_t(from)) + ") "; + const std::string toIdStr = ""; // "(" + std::to_string(uint32_t(to)) + ") "; + + return rule + ":" + rulePadding + fromIdStr + toString(fromTy, opts) + " <=> " + toIdStr + toString(toTy, opts); +} + +std::string mkDesc( + EGraph& egraph, + const StringCache& strings, + NotNull arena, + NotNull builtinTypes, + Id from, + Id to, + const std::string& rule +) +{ + if (!FFlag::DebugLuauLogSimplification) + return ""; + + return mkDesc(egraph, strings, arena, builtinTypes, from, to, {}, rule); +} + +static std::string getNodeName(const StringCache& strings, const EType& node) +{ + if (node.get()) + return "nil"; + else if (node.get()) + return "boolean"; + else if (node.get()) + return "number"; + else if (node.get()) + return "string"; + else if (node.get()) + return "thread"; + else if (node.get()) + return "function"; + else if (node.get()) + return "table"; + else if (node.get()) + return "class"; + else if (node.get()) + return "buffer"; + else if (node.get()) + return "opaque"; + else if (auto b = node.get()) + return b->value() ? "true" : "false"; + else if (auto s = node.get()) + return "\"" + strings.asString(s->value()) + "\""; + else if (node.get()) + return "\xe2\x88\xaa"; + else if (node.get()) + return "\xe2\x88\xa9"; + else if (auto cls = node.get()) + { + const ClassType* ct = get(cls->value()); + LUAU_ASSERT(ct); + return ct->name; + } + else if (node.get()) + return "any"; + else if (node.get()) + return "error"; + else if (node.get()) + return "unknown"; + else if (node.get()) + return "never"; + else if (auto tfun = node.get()) + return "tfun " + tfun->value()->function->name; + else if (node.get()) + return "~"; + else if (node.get()) + return "invalid?"; + else if (node.get()) + return "bound"; + + return "???"; +} + +std::string toDot(const StringCache& strings, const EGraph& egraph) +{ + std::stringstream ss; + ss << "digraph G {" << '\n'; + ss << " graph [fontsize=10 fontname=\"Verdana\" compound=true];" << '\n'; + ss << " node [shape=record fontsize=10 fontname=\"Verdana\"];" << '\n'; + + std::set populated; + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + for (const auto& n : eclass.nodes) + { + const EType& node = n.node; + if (!node.operands().empty()) + populated.insert(id); + for (Id op : node.operands()) + populated.insert(op); + } + } + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + if (!populated.count(id)) + continue; + + const std::string className = "cluster_" + std::to_string(uint32_t(id)); + ss << " subgraph " << className << " {" << '\n'; + ss << " node [style=\"rounded,filled\"];" << '\n'; + ss << " label = \"" << uint32_t(id) << "\";" << '\n'; + ss << " color = blue;" << '\n'; + + for (size_t index = 0; index < eclass.nodes.size(); ++index) + { + const auto& node = eclass.nodes[index].node; + + const std::string label = getNodeName(strings, node); + const std::string nodeName = "n" + std::to_string(uint32_t(id)) + "_" + std::to_string(index); + + ss << " " << nodeName << " [label=\"" << label << "\"];" << '\n'; + } + + ss << " }" << '\n'; + } + + for (const auto& [id, eclass] : egraph.getAllClasses()) + { + for (size_t index = 0; index < eclass.nodes.size(); ++index) + { + const auto& node = eclass.nodes[index].node; + + const std::string label = getNodeName(strings, node); + const std::string nodeName = "n" + std::to_string(uint32_t(egraph.find(id))) + "_" + std::to_string(index); + + for (Id op : node.operands()) + { + op = egraph.find(op); + const std::string destNodeName = "n" + std::to_string(uint32_t(op)) + "_0"; + ss << " " << nodeName << " -> " << destNodeName << " [lhead=cluster_" << uint32_t(op) << "];" << '\n'; + } + } + } + + ss << "}" << '\n'; + + return ss.str(); +} + +template +static Tag const* isTag(const EType& node) +{ + return node.get(); +} + +/// Important: Only use this to test for leaf node types like TUnknown and +/// TNumber. Things that we know cannot be simplified any further and are safe +/// to short-circuit on. +/// +/// It does a linear scan and exits early, so if a particular eclass has +/// multiple "interesting" representations, this function can surprise you. +template +static Tag const* isTag(const EGraph& egraph, Id id) +{ + for (const auto& node : egraph[id].nodes) + { + if (auto n = isTag(node.node)) + return n; + } + return nullptr; +} + +struct RewriteRule +{ + explicit RewriteRule(EGraph* egraph) + : egraph(egraph) + { + } + + virtual void read(std::vector& substs, Id eclass, const EType* enode) = 0; + +protected: + const EqSat::EClass& get(Id id) + { + return (*egraph)[id]; + } + + Id find(Id id) + { + return egraph->find(id); + } + + Id add(EType enode) + { + return egraph->add(std::move(enode)); + } + + template + const Tag* isTag(Id id) + { + for (const auto& node : (*egraph)[id].nodes) + { + if (auto n = node.node.get()) + return n; + } + return nullptr; + } + + template + bool isTag(const EType& enode) + { + return enode.get(); + } + +public: + EGraph* egraph; +}; + +enum SubclassRelationship +{ + LeftSuper, + RightSuper, + Unrelated +}; + +static SubclassRelationship relateClasses(const TClass* leftClass, const TClass* rightClass) +{ + const ClassType* leftClassType = Luau::get(leftClass->value()); + const ClassType* rightClassType = Luau::get(rightClass->value()); + + if (isSubclass(leftClassType, rightClassType)) + return RightSuper; + else if (isSubclass(rightClassType, leftClassType)) + return LeftSuper; + else + return Unrelated; +} + +// Entirely analogous to NormalizedType except that it operates on eclasses instead of TypeIds. +struct CanonicalizedType +{ + std::optional nilPart; + std::optional truePart; + std::optional falsePart; + std::optional numberPart; + std::optional stringPart; + std::vector stringSingletons; + std::optional threadPart; + std::optional functionPart; + std::optional tablePart; + std::vector classParts; + std::optional bufferPart; + std::optional errorPart; + + // Functions that have been union'd into the type + std::unordered_set functionParts; + + // Anything that isn't canonical: Intersections, unions, free types, and so on. + std::unordered_set otherParts; + + bool isUnknown() const + { + return nilPart && truePart && falsePart && numberPart && stringPart && threadPart && functionPart && tablePart && bufferPart; + } +}; + +void unionUnknown(EGraph& egraph, CanonicalizedType& ct) +{ + ct.nilPart = egraph.add(TNil{}); + ct.truePart = egraph.add(SBoolean{true}); + ct.falsePart = egraph.add(SBoolean{false}); + ct.numberPart = egraph.add(TNumber{}); + ct.stringPart = egraph.add(TString{}); + ct.threadPart = egraph.add(TThread{}); + ct.functionPart = egraph.add(TTopFunction{}); + ct.tablePart = egraph.add(TTopTable{}); + ct.bufferPart = egraph.add(TBuffer{}); + + ct.functionParts.clear(); + ct.otherParts.clear(); +} + +void unionAny(EGraph& egraph, CanonicalizedType& ct) +{ + unionUnknown(egraph, ct); + ct.errorPart = egraph.add(TError{}); +} + +void unionClasses(EGraph& egraph, std::vector& hereParts, Id there) +{ + if (1 == hereParts.size() && isTag(egraph, hereParts[0])) + return; + + const auto thereClass = isTag(egraph, there); + if (!thereClass) + return; + + for (size_t index = 0; index < hereParts.size(); ++index) + { + const Id herePart = hereParts[index]; + + if (auto partClass = isTag(egraph, herePart)) + { + switch (relateClasses(partClass, thereClass)) + { + case LeftSuper: + return; + case RightSuper: + hereParts[index] = there; + std::sort(hereParts.begin(), hereParts.end()); + return; + case Unrelated: + continue; + } + } + } + + hereParts.push_back(there); + std::sort(hereParts.begin(), hereParts.end()); +} + +void unionWithType(EGraph& egraph, CanonicalizedType& ct, Id part) +{ + if (isTag(egraph, part)) + ct.nilPart = part; + else if (isTag(egraph, part)) + ct.truePart = ct.falsePart = part; + else if (auto b = isTag(egraph, part)) + { + if (b->value()) + ct.truePart = part; + else + ct.falsePart = part; + } + else if (isTag(egraph, part)) + ct.numberPart = part; + else if (isTag(egraph, part)) + ct.stringPart = part; + else if (isTag(egraph, part)) + ct.stringSingletons.push_back(part); + else if (isTag(egraph, part)) + ct.threadPart = part; + else if (isTag(egraph, part)) + { + ct.functionPart = part; + ct.functionParts.clear(); + } + else if (isTag(egraph, part)) + ct.tablePart = part; + else if (isTag(egraph, part)) + ct.classParts = {part}; + else if (isTag(egraph, part)) + ct.bufferPart = part; + else if (isTag(egraph, part)) + { + if (!ct.functionPart) + ct.functionParts.insert(part); + } + else if (auto tclass = isTag(egraph, part)) + unionClasses(egraph, ct.classParts, part); + else if (isTag(egraph, part)) + { + unionAny(egraph, ct); + return; + } + else if (isTag(egraph, part)) + ct.errorPart = part; + else if (isTag(egraph, part)) + unionUnknown(egraph, ct); + else if (isTag(egraph, part)) + { + // Nothing + } + else + ct.otherParts.insert(part); +} + +// Find an enode under the given eclass which is simple enough that it could be +// subtracted from a CanonicalizedType easily. +// +// A union is "simple enough" if it is acyclic and is only comprised of terminal +// types and unions that are themselves subtractable +const EType* findSubtractableClass(const EGraph& egraph, std::unordered_set& seen, Id id) +{ + if (seen.count(id)) + return nullptr; + + const EType* bestUnion = nullptr; + std::optional unionSize; + + for (const auto& n : egraph[id].nodes) + { + const EType& node = n.node; + + if (isTerminal(node)) + return &node; + + if (const auto u = node.get()) + { + seen.insert(id); + + for (Id part : u->operands()) + { + if (!findSubtractableClass(egraph, seen, part)) + return nullptr; + } + + // If multiple unions in this class are all simple enough, prefer + // the shortest one. + if (!unionSize || u->operands().size() < unionSize) + { + unionSize = u->operands().size(); + bestUnion = &node; + } + } + } + + return bestUnion; +} + +const EType* findSubtractableClass(const EGraph& egraph, Id id) +{ + std::unordered_set seen; + + return findSubtractableClass(egraph, seen, id); +} + +// Subtract the type 'part' from 'ct' +// Returns true if the subtraction succeeded. This function will fail if 'part` is too complicated. +bool subtract(EGraph& egraph, CanonicalizedType& ct, Id part) +{ + const EType* etype = findSubtractableClass(egraph, part); + if (!etype) + return false; + + if (etype->get()) + ct.nilPart.reset(); + else if (etype->get()) + { + ct.truePart.reset(); + ct.falsePart.reset(); + } + else if (auto b = etype->get()) + { + if (b->value()) + ct.truePart.reset(); + else + ct.falsePart.reset(); + } + else if (etype->get()) + ct.numberPart.reset(); + else if (etype->get()) + ct.stringPart.reset(); + else if (etype->get()) + return false; + else if (etype->get()) + ct.threadPart.reset(); + else if (etype->get()) + ct.functionPart.reset(); + else if (etype->get()) + ct.tablePart.reset(); + else if (etype->get()) + ct.classParts.clear(); + else if (auto tclass = etype->get()) + { + auto it = std::find(ct.classParts.begin(), ct.classParts.end(), part); + if (it != ct.classParts.end()) + ct.classParts.erase(it); + else + return false; + } + else if (etype->get()) + ct.bufferPart.reset(); + else if (etype->get()) + ct = {}; + else if (etype->get()) + ct.errorPart.reset(); + else if (etype->get()) + { + std::optional errorPart = ct.errorPart; + ct = {}; + ct.errorPart = errorPart; + } + else if (etype->get()) + { + // Nothing + } + else if (auto u = etype->get()) + { + // TODO cycles + // TODO this is super promlematic because 'part' represents a whole group of equivalent enodes. + for (Id unionPart : u->operands()) + { + // TODO: This recursive call will require that we re-traverse this + // eclass to find the subtractible enode. It would be nice to do the + // work just once and reuse it. + bool ok = subtract(egraph, ct, unionPart); + if (!ok) + return false; + } + } + else if (etype->get()) + return false; + else + return false; + + return true; +} + +static std::pair fromCanonicalized(EGraph& egraph, CanonicalizedType& ct) +{ + if (ct.isUnknown()) + { + if (ct.errorPart) + return {egraph.add(TAny{}), 1}; + else + return {egraph.add(TUnknown{}), 1}; + } + + std::vector parts; + + if (ct.nilPart) + parts.push_back(*ct.nilPart); + + if (ct.truePart && ct.falsePart) + parts.push_back(egraph.add(TBoolean{})); + else if (ct.truePart) + parts.push_back(*ct.truePart); + else if (ct.falsePart) + parts.push_back(*ct.falsePart); + + if (ct.numberPart) + parts.push_back(*ct.numberPart); + + if (ct.stringPart) + parts.push_back(*ct.stringPart); + else if (!ct.stringSingletons.empty()) + parts.insert(parts.end(), ct.stringSingletons.begin(), ct.stringSingletons.end()); + + if (ct.threadPart) + parts.push_back(*ct.threadPart); + if (ct.functionPart) + parts.push_back(*ct.functionPart); + if (ct.tablePart) + parts.push_back(*ct.tablePart); + parts.insert(parts.end(), ct.classParts.begin(), ct.classParts.end()); + if (ct.bufferPart) + parts.push_back(*ct.bufferPart); + if (ct.errorPart) + parts.push_back(*ct.errorPart); + + parts.insert(parts.end(), ct.functionParts.begin(), ct.functionParts.end()); + parts.insert(parts.end(), ct.otherParts.begin(), ct.otherParts.end()); + + std::sort(parts.begin(), parts.end()); + auto it = std::unique(parts.begin(), parts.end()); + parts.erase(it, parts.end()); + + const size_t size = parts.size(); + return {mkUnion(egraph, std::move(parts)), size}; +} + +void addChildren(const EGraph& egraph, const EType* enode, VecDeque& worklist) +{ + for (Id id : enode->operands()) + worklist.push_back(id); +} + +static bool occurs(EGraph& egraph, Id outerId, Slice operands) +{ + for (const Id i : operands) + { + if (egraph.find(i) == outerId) + return true; + } + return false; +} + +Simplifier::Simplifier(NotNull arena, NotNull builtinTypes) + : arena(arena) + , builtinTypes(builtinTypes) + , egraph(Simplify{}) +{ +} + +const EqSat::EClass& Simplifier::get(Id id) const +{ + return egraph[id]; +} + +Id Simplifier::find(Id id) const +{ + return egraph.find(id); +} + +Id Simplifier::add(EType enode) +{ + return egraph.add(std::move(enode)); +} + +template +const Tag* Simplifier::isTag(Id id) const +{ + for (const auto& node : get(id).nodes) + { + if (const Tag* ty = node.node.get()) + return ty; + } + + return nullptr; +} + +template +const Tag* Simplifier::isTag(const EType& enode) const +{ + return enode.get(); +} + +void Simplifier::subst(Id from, Id to) +{ + substs.emplace_back(from, to, " - "); +} + +void Simplifier::subst(Id from, Id to, const std::string& ruleName) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, std::move(ruleName)); + substs.emplace_back(from, to, desc); +} + +void Simplifier::subst(Id from, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); + substs.emplace_back(from, to, desc); +} + +void Simplifier::subst(Id from, size_t boringIndex, Id to, const std::string& ruleName, const std::unordered_map& forceNodes) +{ + std::string desc; + if (FFlag::DebugLuauLogSimplification) + desc = mkDesc(egraph, stringCache, arena, builtinTypes, from, to, forceNodes, ruleName); + + egraph.markBoring(from, boringIndex); + substs.emplace_back(from, to, desc); +} + +void Simplifier::unionClasses(std::vector& hereParts, Id there) +{ + if (1 == hereParts.size() && isTag(hereParts[0])) + return; + + const auto thereClass = isTag(there); + if (!thereClass) + return; + + for (size_t index = 0; index < hereParts.size(); ++index) + { + const Id herePart = hereParts[index]; + + if (auto partClass = isTag(herePart)) + { + switch (relateClasses(partClass, thereClass)) + { + case LeftSuper: + return; + case RightSuper: + hereParts[index] = there; + std::sort(hereParts.begin(), hereParts.end()); + return; + case Unrelated: + continue; + } + } + } + + hereParts.push_back(there); + std::sort(hereParts.begin(), hereParts.end()); +} + +void Simplifier::simplifyUnion(Id id) +{ + id = find(id); + + for (const auto [u, unionIndex] : Query(&egraph, id)) + { + std::vector newParts; + std::unordered_set seen; + + CanonicalizedType canonicalized; + + if (occurs(egraph, id, u->operands())) + continue; + + for (Id part : u->operands()) + unionWithType(egraph, canonicalized, find(part)); + + const auto [resultId, newSize] = fromCanonicalized(egraph, canonicalized); + + if (newSize < u->operands().size()) + subst(id, unionIndex, resultId, "simplifyUnion", {{id, unionIndex}}); + else + subst(id, resultId, "simplifyUnion", {{id, unionIndex}}); + } +} + +// If one of the nodes matches the given Tag, succeed and return the id and node for the other half. +// If neither matches, return nullopt. +template +static std::optional> matchOne(Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) +{ + if (hereNode->get()) + return std::pair{thereId, thereNode}; + else if (thereNode->get()) + return std::pair{hereId, hereNode}; + else + return std::nullopt; +} + +// If the two nodes can be intersected into a "simple" type, return that, else return nullopt. +std::optional intersectOne(EGraph& egraph, Id hereId, const EType* hereNode, Id thereId, const EType* thereNode) +{ + hereId = egraph.find(hereId); + thereId = egraph.find(thereId); + + if (hereId == thereId) + return *hereNode; + + if (hereNode->get() || thereNode->get()) + return TNever{}; + + if (hereNode->get() || hereNode->get() || hereNode->get() || thereNode->get() || + thereNode->get() || thereNode->get() || hereNode->get() || thereNode->get()) + return std::nullopt; + + if (hereNode->get()) + return *thereNode; + if (thereNode->get()) + return *hereNode; + + if (hereNode->get() || thereNode->get()) + return std::nullopt; + + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get() || otherNode->get()) + return *otherNode; + else + return TNever{}; + } + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get() || otherNode->get()) + return *otherNode; + } + if (auto res = matchOne(hereId, hereNode, thereId, thereNode)) + { + const auto [otherId, otherNode] = *res; + + if (otherNode->get()) + return std::nullopt; // TODO + else + return TNever{}; + } + if (auto hereClass = hereNode->get()) + { + if (auto thereClass = thereNode->get()) + { + switch (relateClasses(hereClass, thereClass)) + { + case LeftSuper: + return *thereNode; + case RightSuper: + return *hereNode; + case Unrelated: + return TNever{}; + } + } + else + return TNever{}; + } + if (auto hereBool = hereNode->get()) + { + if (auto thereBool = thereNode->get()) + { + if (hereBool->value() == thereBool->value()) + return *hereNode; + else + return TNever{}; + } + else if (thereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (auto thereBool = thereNode->get()) + { + if (auto hereBool = hereNode->get()) + { + if (thereBool->value() == hereBool->value()) + return *thereNode; + else + return TNever{}; + } + else if (hereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get()) + return TBoolean{}; + else if (thereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get()) + return TBoolean{}; + else if (hereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (hereNode->get()) + { + if (thereNode->get() || thereNode->get()) + return *thereNode; + else + return TNever{}; + } + if (thereNode->get()) + { + if (hereNode->get() || hereNode->get()) + return *hereNode; + else + return TNever{}; + } + if (hereNode->get() && thereNode->get()) + return std::nullopt; + if (hereNode->get() && isTerminal(*thereNode)) + return TNever{}; + if (thereNode->get() && isTerminal(*hereNode)) + return TNever{}; + if (isTerminal(*hereNode) && isTerminal(*thereNode)) + { + // We already know that 'here' and 'there' are different classes. + return TNever{}; + } + + return std::nullopt; +} + +void Simplifier::uninhabitedIntersection(Id id) +{ + for (const auto [intersection, index] : Query(&egraph, id)) + { + Slice parts = intersection->operands(); + + if (parts.empty()) + { + Id never = egraph.add(TNever{}); + subst(id, never, "uninhabitedIntersection"); + return; + } + else if (1 == parts.size()) + { + subst(id, parts[0], "uninhabitedIntersection"); + return; + } + + Id accumulator = egraph.add(TUnknown{}); + EType accumulatorNode = TUnknown{}; + + std::vector unsimplified; + + if (occurs(egraph, id, parts)) + continue; + + for (Id partId : parts) + { + if (isTag(partId)) + return; + + bool found = false; + + const auto& partNodes = egraph[partId].nodes; + for (size_t partIndex = 0; partIndex < partNodes.size(); ++partIndex) + { + const EType& N = partNodes[partIndex].node; + if (std::optional intersection = intersectOne(egraph, accumulator, &accumulatorNode, partId, &N)) + { + if (isTag(*intersection)) + { + subst(id, egraph.add(TNever{}), "uninhabitedIntersection", {{id, index}, {partId, partIndex}}); + return; + } + + accumulator = egraph.add(*intersection); + accumulatorNode = *intersection; + found = true; + break; + } + } + + if (!found) + unsimplified.push_back(partId); + } + + if ((unsimplified.empty() || !isTag(accumulator)) && find(accumulator) != id) + unsimplified.push_back(accumulator); + + const bool isSmaller = unsimplified.size() < parts.size(); + + const Id result = mkIntersection(egraph, std::move(unsimplified)); + + if (isSmaller) + subst(id, index, result, "uninhabitedIntersection", {{id, index}}); + else + subst(id, result, "uninhabitedIntersection", {{id, index}}); + } +} + +void Simplifier::intersectWithNegatedClass(Id id) +{ + for (const auto pair : Query(&egraph, id)) + { + const Intersection* intersection = pair.first; + const size_t intersectionIndex = pair.second; + + auto trySubst = [&](size_t i, size_t j) + { + Id iId = intersection->operands()[i]; + Id jId = intersection->operands()[j]; + + for (const auto [negation, negationIndex] : Query(&egraph, jId)) + { + const Id negated = negation->operands()[0]; + + if (iId == negated) + { + subst(id, egraph.add(TNever{}), "intersectClassWithNegatedClass", {{id, intersectionIndex}, {jId, negationIndex}}); + return; + } + + for (const auto [negatedClass, negatedClassIndex] : Query(&egraph, negated)) + { + const auto& iNodes = egraph[iId].nodes; + for (size_t iIndex = 0; iIndex < iNodes.size(); ++iIndex) + { + const EType& iNode = iNodes[iIndex].node; + if (isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode) || + isTag(iNode) || + // isTag(iNode) || // I'm not sure about this one. + isTag(iNode) || isTag(iNode) || isTag(iNode) || isTag(iNode)) + { + // eg string & ~SomeClass + subst( + id, + iId, + "intersectClassWithNegatedClass", + {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}} + ); + return; + } + + if (const TClass* class_ = iNode.get()) + { + switch (relateClasses(class_, negatedClass)) + { + case LeftSuper: + // eg Instance & ~Part + // This cannot be meaningfully reduced. + continue; + case RightSuper: + subst( + id, + egraph.add(TNever{}), + "intersectClassWithNegatedClass", + {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}} + ); + return; + case Unrelated: + // Part & ~Folder == Part + { + std::vector newParts; + newParts.reserve(intersection->operands().size() - 1); + for (Id part : intersection->operands()) + { + if (part != jId) + newParts.push_back(part); + } + + Id substId = mkIntersection(egraph, newParts); + subst( + id, + substId, + "intersectClassWithNegatedClass", + {{id, intersectionIndex}, {iId, iIndex}, {jId, negationIndex}, {negated, negatedClassIndex}} + ); + } + } + } + } + } + } + }; + + if (2 != intersection->operands().size()) + continue; + + trySubst(0, 1); + trySubst(1, 0); + } +} + +void Simplifier::intersectWithNegatedAtom(Id id) +{ + // Let I and ~J be two arbitrary distinct operands of an intersection where + // I and J are terminal but are not type variables. (free, generic, or + // otherwise opaque) + // + // If I and J are equal, then the whole intersection is equivalent to never. + // + // If I and J are inequal, then J & ~I == J + + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + const Slice& intersectionOperands = intersection->operands(); + for (size_t i = 0; i < intersectionOperands.size(); ++i) + { + for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[i])) + { + for (size_t negationOperandIndex = 0; negationOperandIndex < egraph[negation->operands()[0]].nodes.size(); ++negationOperandIndex) + { + const EType* negationOperand = &egraph[negation->operands()[0]].nodes[negationOperandIndex].node; + if (!isTerminal(*negationOperand) || negationOperand->get()) + continue; + + for (size_t j = 0; j < intersectionOperands.size(); ++j) + { + if (j == i) + continue; + + for (size_t jNodeIndex = 0; jNodeIndex < egraph[intersectionOperands[j]].nodes.size(); ++jNodeIndex) + { + const EType* jNode = &egraph[intersectionOperands[j]].nodes[jNodeIndex].node; + if (!isTerminal(*jNode) || jNode->get()) + continue; + + if (*negationOperand == *jNode) + { + // eg "Hello" & ~"Hello" + // or boolean & ~boolean + subst( + id, + egraph.add(TNever{}), + "intersectWithNegatedAtom", + {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} + ); + return; + } + else if (areTerminalAndDefinitelyDisjoint(*jNode, *negationOperand)) + { + // eg "Hello" & ~"World" + // or boolean & ~string + std::vector newOperands(intersectionOperands.begin(), intersectionOperands.end()); + newOperands.erase(newOperands.begin() + std::vector::difference_type(i)); + + subst( + id, + mkIntersection(egraph, std::move(newOperands)), + "intersectWithNegatedAtom", + {{id, intersectionIndex}, {intersectionOperands[i], negationIndex}, {intersectionOperands[j], jNodeIndex}} + ); + } + } + } + } + } + } + } +} + +void Simplifier::intersectWithNoRefine(Id id) +{ + for (const auto pair : Query(&egraph, id)) + { + const Intersection* intersection = pair.first; + const size_t intersectionIndex = pair.second; + + const Slice intersectionOperands = intersection->operands(); + + for (size_t index = 0; index < intersectionOperands.size(); ++index) + { + const auto replace = [&]() + { + std::vector newOperands{intersectionOperands.begin(), intersectionOperands.end()}; + newOperands.erase(newOperands.begin() + index); + + Id substId = egraph.add(Intersection{std::move(newOperands)}); + + subst(id, substId, "intersectWithNoRefine", {{id, intersectionIndex}}); + }; + + if (isTag(intersectionOperands[index])) + replace(); + else + { + for (const auto [negation, negationIndex] : Query(&egraph, intersectionOperands[index])) + { + if (isTag(negation->operands()[0])) + { + replace(); + break; + } + } + } + } + } +} + +/* + * Replace x where x = A & (B | x) with A + * + * Important subtlety: The egraph is routinely going to create cyclic unions and + * intersections. We can't arbitrarily remove things from a union just because + * it can be referred to in a cyclic way. We must only do this for things that + * can only be expressed in a cyclic way. + * + * As an example, we will bind the following type to true: + * + * (true | buffer | class | function | number | string | table | thread) & + * boolean + * + * The egraph represented by this type will indeed be cyclic as the 'true' class + * includes both 'true' itself and the above type, but removing true from the + * union will result is an incorrect judgment! + * + * The solution (for now) is only to consider a type to be cyclic if it was + * cyclic on its original import. + * + * FIXME: I still don't think this is quite right, but I don't know how to + * articulate what the actual rule ought to be. + */ +void Simplifier::cyclicIntersectionOfUnion(Id id) +{ + // FIXME: This has pretty terrible runtime complexity. + + for (const auto [i, intersectionIndex] : Query(&egraph, id)) + { + Slice intersectionParts = i->operands(); + for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionParts.size(); ++intersectionOperandIndex) + { + const Id intersectionPart = find(intersectionParts[intersectionOperandIndex]); + + for (const auto [bound, _boundIndex] : Query(&egraph, intersectionPart)) + { + const Id pointee = find(mappingIdToClass.at(bound->value())); + + for (const auto [u, unionIndex] : Query(&egraph, pointee)) + { + const Slice& unionOperands = u->operands(); + for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) + { + Id unionOperand = find(unionOperands[unionOperandIndex]); + if (unionOperand == id) + { + std::vector newIntersectionParts(intersectionParts.begin(), intersectionParts.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); + + subst( + id, + mkIntersection(egraph, std::move(newIntersectionParts)), + "cyclicIntersectionOfUnion", + {{id, intersectionIndex}, {pointee, unionIndex}} + ); + } + } + } + } + } + } +} + +void Simplifier::cyclicUnionOfIntersection(Id id) +{ + // FIXME: This has pretty terrible runtime complexity. + + for (const auto [union_, unionIndex] : Query(&egraph, id)) + { + Slice unionOperands = union_->operands(); + for (size_t unionOperandIndex = 0; unionOperandIndex < unionOperands.size(); ++unionOperandIndex) + { + const Id unionPart = find(unionOperands[unionOperandIndex]); + + for (const auto [bound, _boundIndex] : Query(&egraph, unionPart)) + { + const Id pointee = find(mappingIdToClass.at(bound->value())); + + for (const auto [intersection, intersectionIndex] : Query(&egraph, pointee)) + { + Slice intersectionOperands = intersection->operands(); + for (size_t intersectionOperandIndex = 0; intersectionOperandIndex < intersectionOperands.size(); ++intersectionOperandIndex) + { + const Id intersectionPart = find(intersectionOperands[intersectionOperandIndex]); + if (intersectionPart == id) + { + std::vector newIntersectionParts(intersectionOperands.begin(), intersectionOperands.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + intersectionOperandIndex); + + if (!newIntersectionParts.empty()) + { + Id newIntersection = mkIntersection(egraph, std::move(newIntersectionParts)); + + std::vector newIntersectionParts(unionOperands.begin(), unionOperands.end()); + newIntersectionParts.erase(newIntersectionParts.begin() + unionOperandIndex); + newIntersectionParts.push_back(newIntersection); + + subst( + id, + mkUnion(egraph, std::move(newIntersectionParts)), + "cyclicUnionOfIntersection", + {{id, unionIndex}, {pointee, intersectionIndex}} + ); + } + } + } + } + } + } + } +} + +void Simplifier::expandNegation(Id id) +{ + for (const auto [negation, index] : Query{&egraph, id}) + { + if (isTag(negation->operands()[0])) + return; + + CanonicalizedType canonicalized; + unionUnknown(egraph, canonicalized); + + const bool ok = subtract(egraph, canonicalized, negation->operands()[0]); + if (!ok) + continue; + + subst(id, fromCanonicalized(egraph, canonicalized).first, "expandNegation", {{id, index}}); + } +} + +/** + * Let A be a class-node having the form B & C1 & ... & Cn + * And B be a class-node having the form (D | E) + * + * Create a class containing the node (C1 & ... & Cn & D) | (C1 & ... & Cn & E) + * + * This function does nothing and returns nullopt if A and B are cyclic. + */ +static std::optional distributeIntersectionOfUnion( + EGraph& egraph, + Id outerClass, + const Intersection* outerIntersection, + Id innerClass, + const Union* innerUnion +) +{ + Slice outerOperands = outerIntersection->operands(); + + std::vector newOperands; + newOperands.reserve(innerUnion->operands().size()); + for (Id innerOperand : innerUnion->operands()) + { + if (isTag(egraph, innerOperand)) + continue; + + if (innerOperand == outerClass) + { + // Skip cyclic intersections of unions. There's a separate + // rule to get rid of those. + return std::nullopt; + } + + std::vector intersectionParts; + intersectionParts.reserve(outerOperands.size()); + intersectionParts.push_back(innerOperand); + + for (const Id op : outerOperands) + { + if (isTag(egraph, op)) + { + break; + } + if (op != innerClass) + intersectionParts.push_back(op); + } + + newOperands.push_back(mkIntersection(egraph, intersectionParts)); + } + + return mkUnion(egraph, std::move(newOperands)); +} + +// A & (B | C) -> (A & B) | (A & C) +// +// A & B & (C | D) -> A & (B & (C | D)) +// -> A & ((B & C) | (B & D)) +// -> (A & B & C) | (A & B & D) +void Simplifier::intersectionOfUnion(Id id) +{ + id = find(id); + + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + // For each operand O + // For each node N + // If N is a union U + // Create a new union comprised of every operand except O intersected with every operand of U + const Slice operands = intersection->operands(); + + if (operands.size() < 2) + return; + + if (occurs(egraph, id, operands)) + continue; + + for (Id operand : operands) + { + operand = find(operand); + if (operand == id) + break; + // Optimization: Decline to distribute any unions on an eclass that + // also contains a terminal node. + if (isTerminal(egraph, operand)) + continue; + + for (const auto [operandUnion, unionIndex] : Query(&egraph, operand)) + { + if (occurs(egraph, id, operandUnion->operands())) + continue; + + std::optional distributed = distributeIntersectionOfUnion(egraph, id, intersection, operand, operandUnion); + + if (distributed) + subst(id, *distributed, "intersectionOfUnion", {{id, intersectionIndex}, {operand, unionIndex}}); + } + } + } +} + +// {"a": b} & {"a": c, ...} => {"a": b & c, ...} +void Simplifier::intersectTableProperty(Id id) +{ + for (const auto [intersection, intersectionIndex] : Query(&egraph, id)) + { + const Slice intersectionParts = intersection->operands(); + for (size_t i = 0; i < intersection->operands().size(); ++i) + { + const Id iId = intersection->operands()[i]; + + for (size_t j = 0; j < intersection->operands().size(); ++j) + { + if (i == j) + continue; + + const Id jId = intersection->operands()[j]; + + if (iId == jId) + continue; + + for (const auto [table1, table1Index] : Query(&egraph, iId)) + { + const TableType* table1Ty = Luau::get(table1->value()); + LUAU_ASSERT(table1Ty); + + if (table1Ty->props.size() != 1) + continue; + + for (const auto [table2, table2Index] : Query(&egraph, jId)) + { + const TableType* table2Ty = Luau::get(table2->value()); + LUAU_ASSERT(table2Ty); + + auto it = table2Ty->props.find(table1Ty->props.begin()->first); + if (it != table2Ty->props.end()) + { + std::vector newIntersectionParts; + newIntersectionParts.reserve(intersectionParts.size() - 1); + + for (size_t index = 0; index < intersectionParts.size(); ++index) + { + if (index != i && index != j) + newIntersectionParts.push_back(intersectionParts[index]); + } + + Id newTableProp = egraph.add(Intersection{ + toId(egraph, builtinTypes, mappingIdToClass, stringCache, it->second.type()), + toId(egraph, builtinTypes, mappingIdToClass, stringCache, table1Ty->props.begin()->second.type()) + }); + + newIntersectionParts.push_back(egraph.add(TTable{jId, {stringCache.add(it->first)}, {newTableProp}})); + + subst( + id, + mkIntersection(egraph, std::move(newIntersectionParts)), + "intersectTableProperty", + {{id, intersectionIndex}, {iId, table1Index}, {jId, table2Index}} + ); + } + } + } + } + } + } +} + +// { prop: never } == never +void Simplifier::uninhabitedTable(Id id) +{ + for (const auto [table, tableIndex] : Query(&egraph, id)) + { + const TableType* tt = Luau::get(table->value()); + LUAU_ASSERT(tt); + + for (const auto& [propName, prop] : tt->props) + { + if (prop.readTy && Luau::get(follow(*prop.readTy))) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + + if (prop.writeTy && Luau::get(follow(*prop.writeTy))) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + } + } + + for (const auto [table, tableIndex] : Query(&egraph, id)) + { + for (Id propType : table->propTypes()) + { + if (isTag(propType)) + { + subst(id, egraph.add(TNever{}), "uninhabitedTable", {{id, tableIndex}}); + return; + } + } + } +} + +void Simplifier::unneededTableModification(Id id) +{ + for (const auto [tbl, tblIndex] : Query(&egraph, id)) + { + const Id basis = tbl->getBasis(); + for (const auto [importedTbl, importedTblIndex] : Query(&egraph, basis)) + { + const TableType* tt = Luau::get(importedTbl->value()); + LUAU_ASSERT(tt); + + bool skip = false; + + for (size_t i = 0; i < tbl->propNames.size(); ++i) + { + StringId propName = tbl->propNames[i]; + const Id propType = tbl->propTypes()[i]; + + Id importedProp = toId(egraph, builtinTypes, mappingIdToClass, stringCache, tt->props.at(stringCache.asString(propName)).type()); + + if (find(importedProp) != find(propType)) + { + skip = true; + break; + } + } + + if (!skip) + subst(id, basis, "unneededTableModification", {{id, tblIndex}, {basis, importedTblIndex}}); + } + } +} + +void Simplifier::builtinTypeFunctions(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + if (args.size() != 2) + continue; + + const std::string& name = tfun->value()->function->name; + if (name == "add" || name == "sub" || name == "mul" || name == "div" || name == "idiv" || name == "pow" || name == "mod") + { + if (isTag(args[0]) && isTag(args[1])) + { + subst(id, add(TNumber{}), "builtinTypeFunctions", {{id, index}}); + } + } + } +} + +// Replace union<>, intersect<>, and refine<> with unions or intersections. +// These type functions exist primarily to cause simplification to defer until +// particular points in execution, so it is safe to get rid of them here. +// +// It's not clear that these type functions should exist at all. +void Simplifier::iffyTypeFunctions(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + const std::string& name = tfun->value()->function->name; + + if (name == "union") + subst(id, add(Union{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); + else if (name == "intersect") + subst(id, add(Intersection{std::vector(args.begin(), args.end())}), "iffyTypeFunctions", {{id, index}}); + } +} + +// Replace instances of `lt` and `le` when either X or Y is `number` +// or `string` with `boolean`. Lua semantics are that if we see the expression: +// +// x < y +// +// ... we error if `x` and `y` don't have the same type. We know that for +// `string` and `number`, comparisons will always return a boolean. So if either +// of the arguments to `lt<>` are equivalent to `number` or `string`, then the +// type is effectively `boolean`: either the other type is equivalent, in which +// case we eval to `boolean`, or we diverge (raise an error). +void Simplifier::strictMetamethods(Id id) +{ + for (const auto [tfun, index] : Query(&egraph, id)) + { + const Slice& args = tfun->operands(); + + const std::string& name = tfun->value()->function->name; + + if (!(name == "lt" || name == "le") || args.size() != 2) + continue; + + if (isTag(args[0]) || isTag(args[0]) || isTag(args[1]) || isTag(args[1])) + { + subst(id, add(TBoolean{}), __FUNCTION__, {{id, index}}); + } + } +} + +static void deleteSimplifier(Simplifier* s) +{ + delete s; +} + +SimplifierPtr newSimplifier(NotNull arena, NotNull builtinTypes) +{ + return SimplifierPtr{new Simplifier(arena, builtinTypes), &deleteSimplifier}; +} + +} // namespace Luau::EqSatSimplification + +namespace Luau +{ + +std::optional eqSatSimplify(NotNull simplifier, TypeId ty) +{ + using namespace Luau::EqSatSimplification; + + std::unordered_map newMappings; + Id rootId = toId(simplifier->egraph, simplifier->builtinTypes, newMappings, simplifier->stringCache, ty); + simplifier->mappingIdToClass.insert(newMappings.begin(), newMappings.end()); + + Simplifier::RewriteRuleFn rules[] = { + &Simplifier::simplifyUnion, + &Simplifier::uninhabitedIntersection, + &Simplifier::intersectWithNegatedClass, + &Simplifier::intersectWithNegatedAtom, + &Simplifier::intersectWithNoRefine, + &Simplifier::cyclicIntersectionOfUnion, + &Simplifier::cyclicUnionOfIntersection, + &Simplifier::expandNegation, + &Simplifier::intersectionOfUnion, + &Simplifier::intersectTableProperty, + &Simplifier::uninhabitedTable, + &Simplifier::unneededTableModification, + &Simplifier::builtinTypeFunctions, + &Simplifier::iffyTypeFunctions, + &Simplifier::strictMetamethods, + }; + + std::unordered_set seen; + VecDeque worklist; + + bool progressed = true; + + int count = 0; + const int MAX_COUNT = 1000; + + if (FFlag::DebugLuauLogSimplificationToDot) + std::ofstream("begin.dot") << toDot(simplifier->stringCache, simplifier->egraph); + + auto& egraph = simplifier->egraph; + const auto& builtinTypes = simplifier->builtinTypes; + auto& arena = simplifier->arena; + + if (FFlag::DebugLuauLogSimplification) + printf(">> simplify %s\n", toString(ty).c_str()); + + while (progressed && count < MAX_COUNT) + { + progressed = false; + worklist.clear(); + seen.clear(); + + rootId = egraph.find(rootId); + + worklist.push_back(rootId); + + if (FFlag::DebugLuauLogSimplification) + { + std::vector newTypeFunctions; + const TypeId t = fromId(egraph, simplifier->stringCache, builtinTypes, arena, newTypeFunctions, rootId); + + std::cout << "Begin (" << uint32_t(egraph.find(rootId)) << ")\t" << toString(t) << '\n'; + } + + while (!worklist.empty() && count < MAX_COUNT) + { + Id id = egraph.find(worklist.front()); + worklist.pop_front(); + + const bool isFresh = seen.insert(id).second; + if (!isFresh) + continue; + + simplifier->substs.clear(); + + // Optimization: If this class alraedy has a terminal node, don't + // try to run any rules on it. + bool shouldAbort = false; + + for (const auto& enode : egraph[id].nodes) + { + if (isTerminal(enode.node)) + { + shouldAbort = true; + break; + } + } + + if (shouldAbort) + continue; + + for (const auto& enode : egraph[id].nodes) + addChildren(egraph, &enode.node, worklist); + + for (Simplifier::RewriteRuleFn rule : rules) + (simplifier.get()->*rule)(id); + + if (simplifier->substs.empty()) + continue; + + for (const Subst& subst : simplifier->substs) + { + if (subst.newClass == subst.eclass) + continue; + + if (FFlag::DebugLuauExtraEqSatSanityChecks) + { + const Id never = egraph.find(egraph.add(TNever{})); + const Id str = egraph.find(egraph.add(TString{})); + const Id unk = egraph.find(egraph.add(TUnknown{})); + LUAU_ASSERT(never != str); + LUAU_ASSERT(never != unk); + } + + const bool isFresh = egraph.merge(subst.newClass, subst.eclass); + + ++count; + + if (FFlag::DebugLuauLogSimplification && isFresh) + std::cout << "count=" << std::setw(3) << count << "\t" << subst.desc << '\n'; + + if (FFlag::DebugLuauLogSimplificationToDot) + { + std::string filename = format("step%03d.dot", count); + std::ofstream(filename) << toDot(simplifier->stringCache, egraph); + } + + if (FFlag::DebugLuauExtraEqSatSanityChecks) + { + const Id never = egraph.find(egraph.add(TNever{})); + const Id str = egraph.find(egraph.add(TString{})); + const Id unk = egraph.find(egraph.add(TUnknown{})); + const Id trueId = egraph.find(egraph.add(SBoolean{true})); + + LUAU_ASSERT(never != str); + LUAU_ASSERT(never != unk); + LUAU_ASSERT(never != trueId); + } + + progressed |= isFresh; + } + + egraph.rebuild(); + } + } + + EqSatSimplificationResult result; + result.result = fromId(egraph, simplifier->stringCache, builtinTypes, arena, result.newTypeFunctions, rootId); + + if (FFlag::DebugLuauLogSimplification) + printf("<< simplify %s\n", toString(result.result).c_str()); + + return result; +} + +} // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 60058d99..66b61d6b 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -18,8 +18,6 @@ LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauImproveNonFunctionCallError, false) - static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, @@ -408,35 +406,30 @@ struct ErrorConverter std::string operator()(const Luau::CannotCallNonFunction& e) const { - if (DFFlag::LuauImproveNonFunctionCallError) + if (auto unionTy = get(follow(e.ty))) { - if (auto unionTy = get(follow(e.ty))) + std::string err = "Cannot call a value of the union type:"; + + for (auto option : unionTy) { - std::string err = "Cannot call a value of the union type:"; + option = follow(option); - for (auto option : unionTy) + if (get(option) || findCallMetamethod(option)) { - option = follow(option); - - if (get(option) || findCallMetamethod(option)) - { - err += "\n | " + toString(option); - continue; - } - - // early-exit if we find something that isn't callable in the union. - return "Cannot call a value of type " + toString(option) + " in union:\n " + toString(e.ty); + err += "\n | " + toString(option); + continue; } - err += "\nWe are unable to determine the appropriate result type for such a call."; - - return err; + // early-exit if we find something that isn't callable in the union. + return "Cannot call a value of type " + toString(option) + " in union:\n " + toString(e.ty); } - return "Cannot call a value of type " + toString(e.ty); + err += "\nWe are unable to determine the appropriate result type for such a call."; + + return err; } - return "Cannot call non-function " + toString(e.ty); + return "Cannot call a value of type " + toString(e.ty); } std::string operator()(const Luau::ExtraInformation& e) const { @@ -793,6 +786,11 @@ struct ErrorConverter return "Encountered an unexpected type pack in subtyping: " + toString(e.tp); } + std::string operator()(const UserDefinedTypeFunctionError& e) const + { + return e.message; + } + std::string operator()(const CannotAssignToNever& e) const { std::string result = "Cannot assign a value of type " + toString(e.rhsType) + " to a field of type never"; @@ -1175,6 +1173,11 @@ bool UnexpectedTypePackInSubtyping::operator==(const UnexpectedTypePackInSubtypi return tp == rhs.tp; } +bool UserDefinedTypeFunctionError::operator==(const UserDefinedTypeFunctionError& rhs) const +{ + return message == rhs.message; +} + bool CannotAssignToNever::operator==(const CannotAssignToNever& rhs) const { if (cause.size() != rhs.cause.size()) @@ -1384,6 +1387,9 @@ void copyError(T& e, TypeArena& destArena, CloneState& cloneState) e.ty = clone(e.ty); else if constexpr (std::is_same_v) e.tp = clone(e.tp); + else if constexpr (std::is_same_v) + { + } else if constexpr (std::is_same_v) { e.rhsType = clone(e.rhsType); diff --git a/Analysis/src/FragmentAutocomplete.cpp b/Analysis/src/FragmentAutocomplete.cpp new file mode 100644 index 00000000..47c0c1a1 --- /dev/null +++ b/Analysis/src/FragmentAutocomplete.cpp @@ -0,0 +1,708 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/FragmentAutocomplete.h" + +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Autocomplete.h" +#include "Luau/Common.h" +#include "Luau/EqSatSimplification.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/Module.h" +#include "Luau/TimeTrace.h" +#include "Luau/UnifierSharedState.h" +#include "Luau/TypeFunction.h" +#include "Luau/DataFlowGraph.h" +#include "Luau/ConstraintGenerator.h" +#include "Luau/ConstraintSolver.h" +#include "Luau/Frontend.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" +#include "Luau/Module.h" +#include "Luau/Clone.h" +#include "AutocompleteCore.h" + +LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTINT(LuauTypeInferIterationLimit); +LUAU_FASTINT(LuauTarjanChildLimit) +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) + +LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteBugfixes) +LUAU_FASTFLAGVARIABLE(LuauMixedModeDefFinderTraversesTypeOf) +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) +LUAU_FASTFLAGVARIABLE(LuauCloneIncrementalModule) +LUAU_FASTFLAGVARIABLE(LogFragmentsFromAutocomplete) +namespace +{ +template +void copyModuleVec(std::vector& result, const std::vector& input) +{ + result.insert(result.end(), input.begin(), input.end()); +} + +template +void copyModuleMap(Luau::DenseHashMap& result, const Luau::DenseHashMap& input) +{ + for (auto [k, v] : input) + result[k] = v; +} + +} // namespace + +namespace Luau +{ + +template +void cloneModuleMap(TypeArena& destArena, CloneState& cloneState, const Luau::DenseHashMap& source, Luau::DenseHashMap& dest) +{ + for (auto [k, v] : source) + { + dest[k] = Luau::clone(v, destArena, cloneState); + } +} + +struct MixedModeIncrementalTCDefFinder : public AstVisitor +{ + bool visit(AstExprLocal* local) override + { + referencedLocalDefs.emplace_back(local->local, local); + return true; + } + + bool visit(AstTypeTypeof* node) override + { + // We need to traverse typeof expressions because they may refer to locals that we need + // to populate the local environment for fragment typechecking. For example, `typeof(m)` + // requires that we find the local/global `m` and place it in the environment. + // The default behaviour here is to return false, and have individual visitors override + // the specific behaviour they need. + return FFlag::LuauMixedModeDefFinderTraversesTypeOf; + } + + // ast defs is just a mapping from expr -> def in general + // will get built up by the dfg builder + + // localDefs, we need to copy over + std::vector> referencedLocalDefs; +}; + +void cloneAndSquashScopes( + CloneState& cloneState, + const Scope* staleScope, + const ModulePtr& staleModule, + NotNull destArena, + NotNull dfg, + AstStatBlock* program, + Scope* destScope +) +{ + LUAU_TIMETRACE_SCOPE("Luau::cloneAndSquashScopes", "FragmentAutocomplete"); + std::vector scopes; + for (const Scope* current = staleScope; current; current = current->parent.get()) + { + scopes.emplace_back(current); + } + + // in reverse order (we need to clone the parents and override defs as we go down the list) + for (auto it = scopes.rbegin(); it != scopes.rend(); ++it) + { + const Scope* curr = *it; + // Clone the lvalue types + for (const auto& [def, ty] : curr->lvalueTypes) + destScope->lvalueTypes[def] = Luau::clone(ty, *destArena, cloneState); + // Clone the rvalueRefinements + for (const auto& [def, ty] : curr->rvalueRefinements) + destScope->rvalueRefinements[def] = Luau::clone(ty, *destArena, cloneState); + for (const auto& [n, m] : curr->importedTypeBindings) + { + std::unordered_map importedBindingTypes; + for (const auto& [v, tf] : m) + importedBindingTypes[v] = Luau::clone(tf, *destArena, cloneState); + destScope->importedTypeBindings[n] = m; + } + + // Finally, clone up the bindings + for (const auto& [s, b] : curr->bindings) + { + destScope->bindings[s] = Luau::clone(b, *destArena, cloneState); + } + } + + // The above code associates defs with TypeId's in the scope + // so that lookup to locals will succeed. + MixedModeIncrementalTCDefFinder finder; + program->visit(&finder); + std::vector> locals = std::move(finder.referencedLocalDefs); + for (auto [loc, expr] : locals) + { + if (std::optional binding = staleScope->linearSearchForBinding(loc->name.value, true)) + { + destScope->lvalueTypes[dfg->getDef(expr)] = Luau::clone(binding->typeId, *destArena, cloneState); + } + } + return; +} + +static FrontendModuleResolver& getModuleResolver(Frontend& frontend, std::optional options) +{ + if (FFlag::LuauSolverV2 || !options) + return frontend.moduleResolver; + + return options->forAutocomplete ? frontend.moduleResolverForAutocomplete : frontend.moduleResolver; +} + +FragmentAutocompleteAncestryResult findAncestryForFragmentParse(AstStatBlock* root, const Position& cursorPos) +{ + std::vector ancestry = findAncestryAtPositionForAutocomplete(root, cursorPos); + // Should always contain the root AstStat + LUAU_ASSERT(ancestry.size() >= 1); + DenseHashMap localMap{AstName()}; + std::vector localStack; + AstStat* nearestStatement = nullptr; + for (AstNode* node : ancestry) + { + if (auto block = node->as()) + { + for (auto stat : block->body) + { + if (stat->location.begin <= cursorPos) + nearestStatement = stat; + if (stat->location.begin < cursorPos && stat->location.begin.line < cursorPos.line) + { + // This statement precedes the current one + if (auto loc = stat->as()) + { + for (auto v : loc->vars) + { + localStack.push_back(v); + localMap[v->name] = v; + } + } + else if (auto locFun = stat->as()) + { + localStack.push_back(locFun->name); + localMap[locFun->name->name] = locFun->name; + if (locFun->location.contains(cursorPos)) + { + for (AstLocal* loc : locFun->func->args) + { + localStack.push_back(loc); + localMap[loc->name] = loc; + } + } + } + else if (auto globFun = stat->as()) + { + if (globFun->location.contains(cursorPos)) + { + for (AstLocal* loc : globFun->func->args) + { + localStack.push_back(loc); + localMap[loc->name] = loc; + } + } + } + } + } + } + } + + if (!nearestStatement) + nearestStatement = ancestry[0]->asStat(); + LUAU_ASSERT(nearestStatement); + return {std::move(localMap), std::move(localStack), std::move(ancestry), std::move(nearestStatement)}; +} + +/** + * Get document offsets is a function that takes a source text document as well as a start position and end position(line, column) in that + * document and attempts to get the concrete text between those points. It returns a pair of: + * - start offset that represents an index in the source `char*` corresponding to startPos + * - length, that represents how many more bytes to read to get to endPos. + * Example - your document is "foo bar baz" and getDocumentOffsets is passed (0, 4), (0, 8). This function returns the pair {3, 5} + * which corresponds to the string " bar " + */ +std::pair getDocumentOffsets(const std::string_view& src, const Position& startPos, const Position& endPos) +{ + size_t lineCount = 0; + size_t colCount = 0; + + size_t docOffset = 0; + size_t startOffset = 0; + size_t endOffset = 0; + bool foundStart = false; + bool foundEnd = false; + + for (char c : src) + { + if (foundStart && foundEnd) + break; + + if (startPos.line == lineCount && startPos.column == colCount) + { + foundStart = true; + startOffset = docOffset; + } + + if (endPos.line == lineCount && endPos.column == colCount) + { + endOffset = docOffset; + while (endOffset < src.size() && src[endOffset] != '\n') + endOffset++; + foundEnd = true; + } + + // We put a cursor position that extends beyond the extents of the current line + if (foundStart && !foundEnd && (lineCount > endPos.line)) + { + foundEnd = true; + endOffset = docOffset - 1; + } + + if (c == '\n') + { + lineCount++; + colCount = 0; + } + else + { + colCount++; + } + docOffset++; + } + + if (foundStart && !foundEnd) + endOffset = src.length(); + + size_t min = std::min(startOffset, endOffset); + size_t len = std::max(startOffset, endOffset) - min; + return {min, len}; +} + +ScopePtr findClosestScope(const ModulePtr& module, const AstStat* nearestStatement) +{ + LUAU_ASSERT(module->hasModuleScope()); + + ScopePtr closest = module->getModuleScope(); + + // find the scope the nearest statement belonged to. + for (auto [loc, sc] : module->scopes) + { + if (loc.encloses(nearestStatement->location) && closest->location.begin <= loc.begin) + closest = sc; + } + + return closest; +} + +std::optional parseFragment( + const SourceModule& srcModule, + std::string_view src, + const Position& cursorPos, + std::optional fragmentEndPosition +) +{ + FragmentAutocompleteAncestryResult result = findAncestryForFragmentParse(srcModule.root, cursorPos); + AstStat* nearestStatement = result.nearestStatement; + + const Location& rootSpan = srcModule.root->location; + // Did we append vs did we insert inline + bool appended = cursorPos >= rootSpan.end; + // statement spans multiple lines + bool multiline = nearestStatement->location.begin.line != nearestStatement->location.end.line; + + const Position endPos = fragmentEndPosition.value_or(cursorPos); + + // We start by re-parsing everything (we'll refine this as we go) + Position startPos = srcModule.root->location.begin; + + // If we added to the end of the sourceModule, use the end of the nearest location + if (appended && multiline) + startPos = nearestStatement->location.end; + // Statement spans one line && cursorPos is either on the same line or after + else if (!multiline && cursorPos.line >= nearestStatement->location.end.line) + startPos = nearestStatement->location.begin; + else if (multiline && nearestStatement->location.end.line < cursorPos.line) + startPos = nearestStatement->location.end; + else + startPos = nearestStatement->location.begin; + + auto [offsetStart, parseLength] = getDocumentOffsets(src, startPos, endPos); + const char* srcStart = src.data() + offsetStart; + std::string_view dbg = src.substr(offsetStart, parseLength); + const std::shared_ptr& nameTbl = srcModule.names; + FragmentParseResult fragmentResult; + fragmentResult.fragmentToParse = std::string(dbg.data(), parseLength); + // For the duration of the incremental parse, we want to allow the name table to re-use duplicate names + if (FFlag::LogFragmentsFromAutocomplete) + logLuau(dbg); + + ParseOptions opts; + opts.allowDeclarationSyntax = false; + opts.captureComments = true; + opts.parseFragment = FragmentParseResumeSettings{std::move(result.localMap), std::move(result.localStack), startPos}; + ParseResult p = Luau::Parser::parse(srcStart, parseLength, *nameTbl, *fragmentResult.alloc.get(), opts); + // This means we threw a ParseError and we should decline to offer autocomplete here. + if (p.root == nullptr) + return std::nullopt; + + std::vector fabricatedAncestry = std::move(result.ancestry); + + // Get the ancestry for the fragment at the offset cursor position. + // Consumers have the option to request with fragment end position, so we cannot just use the end position of our parse result as the + // cursor position. Instead, use the cursor position calculated as an offset from our start position. + std::vector fragmentAncestry = findAncestryAtPositionForAutocomplete(p.root, cursorPos); + fabricatedAncestry.insert(fabricatedAncestry.end(), fragmentAncestry.begin(), fragmentAncestry.end()); + if (nearestStatement == nullptr) + nearestStatement = p.root; + fragmentResult.root = std::move(p.root); + fragmentResult.ancestry = std::move(fabricatedAncestry); + fragmentResult.nearestStatement = nearestStatement; + fragmentResult.commentLocations = std::move(p.commentLocations); + return fragmentResult; +} + +ModulePtr cloneModule(CloneState& cloneState, const ModulePtr& source, std::unique_ptr alloc) +{ + LUAU_TIMETRACE_SCOPE("Luau::cloneModule", "FragmentAutocomplete"); + freeze(source->internalTypes); + freeze(source->interfaceTypes); + ModulePtr incremental = std::make_shared(); + incremental->name = source->name; + incremental->humanReadableName = source->humanReadableName; + incremental->allocator = std::move(alloc); + // Clone types + cloneModuleMap(incremental->internalTypes, cloneState, source->astTypes, incremental->astTypes); + cloneModuleMap(incremental->internalTypes, cloneState, source->astTypePacks, incremental->astTypePacks); + cloneModuleMap(incremental->internalTypes, cloneState, source->astExpectedTypes, incremental->astExpectedTypes); + + cloneModuleMap(incremental->internalTypes, cloneState, source->astOverloadResolvedTypes, incremental->astOverloadResolvedTypes); + + cloneModuleMap(incremental->internalTypes, cloneState, source->astForInNextTypes, incremental->astForInNextTypes); + + copyModuleMap(incremental->astScopes, source->astScopes); + + return incremental; +} + +ModulePtr copyModule(const ModulePtr& result, std::unique_ptr alloc) +{ + ModulePtr incrementalModule = std::make_shared(); + incrementalModule->name = result->name; + incrementalModule->humanReadableName = "Incremental$" + result->humanReadableName; + incrementalModule->internalTypes.owningModule = incrementalModule.get(); + incrementalModule->interfaceTypes.owningModule = incrementalModule.get(); + incrementalModule->allocator = std::move(alloc); + // Don't need to keep this alive (it's already on the source module) + copyModuleVec(incrementalModule->scopes, result->scopes); + copyModuleMap(incrementalModule->astTypes, result->astTypes); + copyModuleMap(incrementalModule->astTypePacks, result->astTypePacks); + copyModuleMap(incrementalModule->astExpectedTypes, result->astExpectedTypes); + // Don't need to clone astOriginalCallTypes + copyModuleMap(incrementalModule->astOverloadResolvedTypes, result->astOverloadResolvedTypes); + // Don't need to clone astForInNextTypes + copyModuleMap(incrementalModule->astForInNextTypes, result->astForInNextTypes); + // Don't need to clone astResolvedTypes + // Don't need to clone astResolvedTypePacks + // Don't need to clone upperBoundContributors + copyModuleMap(incrementalModule->astScopes, result->astScopes); + // Don't need to clone declared Globals; + return incrementalModule; +} + +void mixedModeCompatibility( + const ScopePtr& bottomScopeStale, + const ScopePtr& myFakeScope, + const ModulePtr& stale, + NotNull dfg, + AstStatBlock* program +) +{ + // This code does the following + // traverse program + // look for ast refs for locals + // ask for the corresponding defId from dfg + // given that defId, and that expression, in the incremental module, map lvalue types from defID to + + MixedModeIncrementalTCDefFinder finder; + program->visit(&finder); + std::vector> locals = std::move(finder.referencedLocalDefs); + for (auto [loc, expr] : locals) + { + if (std::optional binding = bottomScopeStale->linearSearchForBinding(loc->name.value, true)) + { + myFakeScope->lvalueTypes[dfg->getDef(expr)] = binding->typeId; + } + } +} + +FragmentTypeCheckResult typecheckFragment_( + Frontend& frontend, + AstStatBlock* root, + const ModulePtr& stale, + const ScopePtr& closestScope, + const Position& cursorPos, + std::unique_ptr astAllocator, + const FrontendOptions& opts +) +{ + LUAU_TIMETRACE_SCOPE("Luau::typecheckFragment_", "FragmentAutocomplete"); + + freeze(stale->internalTypes); + freeze(stale->interfaceTypes); + CloneState cloneState{frontend.builtinTypes}; + ModulePtr incrementalModule = + FFlag::LuauCloneIncrementalModule ? cloneModule(cloneState, stale, std::move(astAllocator)) : copyModule(stale, std::move(astAllocator)); + incrementalModule->checkedInNewSolver = true; + unfreeze(incrementalModule->internalTypes); + unfreeze(incrementalModule->interfaceTypes); + + /// Setup typecheck limits + TypeCheckLimits limits; + if (opts.moduleTimeLimitSec) + limits.finishTime = TimeTrace::getClock() + *opts.moduleTimeLimitSec; + else + limits.finishTime = std::nullopt; + limits.cancellationToken = opts.cancellationToken; + + /// Icehandler + NotNull iceHandler{&frontend.iceHandler}; + /// Make the shared state for the unifier (recursion + iteration limits) + UnifierSharedState unifierState{iceHandler}; + unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; + unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); + + /// Initialize the normalizer + Normalizer normalizer{&incrementalModule->internalTypes, frontend.builtinTypes, NotNull{&unifierState}}; + + /// User defined type functions runtime + TypeFunctionRuntime typeFunctionRuntime(iceHandler, NotNull{&limits}); + + /// Create a DataFlowGraph just for the surrounding context + DataFlowGraph dfg = DataFlowGraphBuilder::build(root, NotNull{&incrementalModule->defArena}, NotNull{&incrementalModule->keyArena}, iceHandler); + + SimplifierPtr simplifier = newSimplifier(NotNull{&incrementalModule->internalTypes}, frontend.builtinTypes); + + FrontendModuleResolver& resolver = getModuleResolver(frontend, opts); + + /// Contraint Generator + ConstraintGenerator cg{ + incrementalModule, + NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{&resolver}, + frontend.builtinTypes, + iceHandler, + stale->getModuleScope(), + nullptr, + nullptr, + NotNull{&dfg}, + {} + }; + std::shared_ptr freshChildOfNearestScope = nullptr; + if (FFlag::LuauCloneIncrementalModule) + { + freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); + cg.rootScope = freshChildOfNearestScope.get(); + + cloneAndSquashScopes( + cloneState, closestScope.get(), stale, NotNull{&incrementalModule->internalTypes}, NotNull{&dfg}, root, freshChildOfNearestScope.get() + ); + cg.visitFragmentRoot(freshChildOfNearestScope, root); + } + else + { + // Any additions to the scope must occur in a fresh scope + cg.rootScope = stale->getModuleScope().get(); + freshChildOfNearestScope = std::make_shared(closestScope); + incrementalModule->scopes.emplace_back(root->location, freshChildOfNearestScope); + mixedModeCompatibility(closestScope, freshChildOfNearestScope, stale, NotNull{&dfg}, root); + // closest Scope -> children = { ...., freshChildOfNearestScope} + // We need to trim nearestChild from the scope hierarcy + closestScope->children.emplace_back(freshChildOfNearestScope.get()); + cg.visitFragmentRoot(freshChildOfNearestScope, root); + // Trim nearestChild from the closestScope + Scope* back = closestScope->children.back().get(); + LUAU_ASSERT(back == freshChildOfNearestScope.get()); + closestScope->children.pop_back(); + } + + /// Initialize the constraint solver and run it + ConstraintSolver cs{ + NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull(cg.rootScope), + borrowConstraints(cg.constraints), + NotNull{&cg.scopeToFunction}, + incrementalModule->name, + NotNull{&resolver}, + {}, + nullptr, + NotNull{&dfg}, + limits + }; + + try + { + cs.run(); + } + catch (const TimeLimitError&) + { + stale->timeout = true; + } + catch (const UserCancelError&) + { + stale->cancelled = true; + } + + // In frontend we would forbid internal types + // because this is just for autocomplete, we don't actually care + // We also don't even need to typecheck - just synthesize types as best as we can + + freeze(incrementalModule->internalTypes); + freeze(incrementalModule->interfaceTypes); + return {std::move(incrementalModule), std::move(freshChildOfNearestScope)}; +} + + +std::pair typecheckFragment( + Frontend& frontend, + const ModuleName& moduleName, + const Position& cursorPos, + std::optional opts, + std::string_view src, + std::optional fragmentEndPosition +) +{ + LUAU_TIMETRACE_SCOPE("Luau::typecheckFragment", "FragmentAutocomplete"); + LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str()); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + if (!frontend.allModuleDependenciesValid(moduleName, opts && opts->forAutocomplete)) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + } + + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + { + LUAU_ASSERT(!"Expected Source Module for fragment typecheck"); + return {}; + } + + FrontendModuleResolver& resolver = getModuleResolver(frontend, opts); + ModulePtr module = resolver.getModule(moduleName); + if (!module) + { + LUAU_ASSERT(!"Expected Module for fragment typecheck"); + return {}; + } + + if (FFlag::LuauIncrementalAutocompleteBugfixes) + { + if (sourceModule->allocator.get() != module->allocator.get()) + { + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + } + } + + auto tryParse = parseFragment(*sourceModule, src, cursorPos, fragmentEndPosition); + + if (!tryParse) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + + FragmentParseResult& parseResult = *tryParse; + + if (isWithinComment(parseResult.commentLocations, fragmentEndPosition.value_or(cursorPos))) + return {FragmentTypeCheckStatus::SkipAutocomplete, {}}; + + FrontendOptions frontendOptions = opts.value_or(frontend.options); + const ScopePtr& closestScope = findClosestScope(module, parseResult.nearestStatement); + FragmentTypeCheckResult result = + typecheckFragment_(frontend, parseResult.root, module, closestScope, cursorPos, std::move(parseResult.alloc), frontendOptions); + result.ancestry = std::move(parseResult.ancestry); + return {FragmentTypeCheckStatus::Success, result}; +} + +FragmentAutocompleteStatusResult tryFragmentAutocomplete( + Frontend& frontend, + const ModuleName& moduleName, + Position cursorPosition, + FragmentContext context, + StringCompletionCallback stringCompletionCB +) +{ + // TODO: we should calculate fragmentEnd position here, by using context.newAstRoot and cursorPosition + try + { + Luau::FragmentAutocompleteResult fragmentAutocomplete = Luau::fragmentAutocomplete( + frontend, + context.newSrc, + moduleName, + cursorPosition, + context.opts, + std::move(stringCompletionCB), + context.DEPRECATED_fragmentEndPosition + ); + return {FragmentAutocompleteStatus::Success, std::move(fragmentAutocomplete)}; + } + catch (const Luau::InternalCompilerError& e) + { + if (FFlag::LogFragmentsFromAutocomplete) + logLuau(e.what()); + return {FragmentAutocompleteStatus::InternalIce, std::nullopt}; + } +} + +FragmentAutocompleteResult fragmentAutocomplete( + Frontend& frontend, + std::string_view src, + const ModuleName& moduleName, + Position cursorPosition, + std::optional opts, + StringCompletionCallback callback, + std::optional fragmentEndPosition +) +{ + LUAU_ASSERT(FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); + LUAU_TIMETRACE_SCOPE("Luau::fragmentAutocomplete", "FragmentAutocomplete"); + LUAU_TIMETRACE_ARGUMENT("name", moduleName.c_str()); + + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + { + LUAU_ASSERT(!"Expected Source Module for fragment typecheck"); + return {}; + } + + // If the cursor is within a comment in the stale source module we should avoid providing a recommendation + if (isWithinComment(*sourceModule, fragmentEndPosition.value_or(cursorPosition))) + return {}; + + auto [tcStatus, tcResult] = typecheckFragment(frontend, moduleName, cursorPosition, opts, src, fragmentEndPosition); + if (tcStatus == FragmentTypeCheckStatus::SkipAutocomplete) + return {}; + + auto globalScope = (opts && opts->forAutocomplete) ? frontend.globalsForAutocomplete.globalScope.get() : frontend.globals.globalScope.get(); + if (FFlag::LogFragmentsFromAutocomplete) + logLuau(src); + TypeArena arenaForFragmentAutocomplete; + auto result = Luau::autocomplete_( + tcResult.incrementalModule, + frontend.builtinTypes, + &arenaForFragmentAutocomplete, + tcResult.ancestry, + globalScope, + tcResult.freshScope, + cursorPosition, + frontend.fileResolver, + callback + ); + + return {std::move(tcResult.incrementalModule), tcResult.freshScope.get(), std::move(arenaForFragmentAutocomplete), std::move(result)}; +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8c439181..4bb801ae 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -10,8 +10,10 @@ #include "Luau/ConstraintSolver.h" #include "Luau/DataFlowGraph.h" #include "Luau/DcrLogger.h" +#include "Luau/EqSatSimplification.h" #include "Luau/FileResolver.h" #include "Luau/NonStrictTypeChecker.h" +#include "Luau/NotNull.h" #include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" @@ -36,18 +38,21 @@ LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) -LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauStoreCommentsForDefinitionFiles, false) +LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) -LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile, false) -LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes, false) -LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode, false) -LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode, false) -LUAU_FASTFLAGVARIABLE(LuauSourceModuleUpdatedWithSelectedMode, false) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson) +LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJsonFile) +LUAU_FASTFLAGVARIABLE(DebugLuauForbidInternalTypes) +LUAU_FASTFLAGVARIABLE(DebugLuauForceStrictMode) +LUAU_FASTFLAGVARIABLE(DebugLuauForceNonStrictMode) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauRunCustomModuleChecks, false) +LUAU_FASTFLAGVARIABLE(LuauBetterReverseDependencyTracking) + LUAU_FASTFLAG(StudioReportLuauAny2) +LUAU_FASTFLAGVARIABLE(LuauStoreSolverTypeOnModule) + +LUAU_FASTFLAGVARIABLE(LuauSelectivelyRetainDFGArena) namespace Luau { @@ -134,7 +139,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod sourceModule.root = parseResult.root; sourceModule.mode = Mode::Definition; - if (FFlag::LuauStoreCommentsForDefinitionFiles && options.captureComments) + if (options.captureComments) { sourceModule.hotcomments = parseResult.hotcomments; sourceModule.commentLocations = parseResult.commentLocations; @@ -205,72 +210,6 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile( return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -std::vector parsePathExpr(const AstExpr& pathExpr) -{ - const AstExprIndexName* indexName = pathExpr.as(); - if (!indexName) - return {}; - - std::vector segments{indexName->index.value}; - - while (true) - { - if (AstExprIndexName* in = indexName->expr->as()) - { - segments.push_back(in->index.value); - indexName = in; - continue; - } - else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as()) - { - segments.push_back(indexNameAsGlobal->name.value); - break; - } - else if (AstExprLocal* indexNameAsLocal = indexName->expr->as()) - { - segments.push_back(indexNameAsLocal->local->name.value); - break; - } - else - return {}; - } - - std::reverse(segments.begin(), segments.end()); - return segments; -} - -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments) -{ - if (segments.empty()) - return std::nullopt; - - std::vector result; - - auto it = segments.begin(); - - if (*it == "script" && !currentModuleName.empty()) - { - result = split(currentModuleName, '/'); - ++it; - } - - for (; it != segments.end(); ++it) - { - if (result.size() > 1 && *it == "Parent") - result.pop_back(); - else - result.push_back(*it); - } - - return join(result, "/"); -} - -std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr) -{ - std::vector segments = parsePathExpr(pathExpr); - return pathExprToModuleName(currentModuleName, segments); -} - namespace { @@ -351,8 +290,7 @@ static void filterLintOptions(LintOptions& lintOptions, const std::vector getRequireCycles( const FileResolver* resolver, const std::unordered_map>& sourceNodes, - const SourceNode* start, - bool stopAtFirst = false + const SourceNode* start ) { std::vector result; @@ -422,9 +360,6 @@ std::vector getRequireCycles( { result.push_back({depLocation, std::move(cycle)}); - if (stopAtFirst) - return result; - // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start // so it's safe to *only* clear seen vector when we find a cycle // if we don't do it, we will not have correct reporting for some cycles @@ -812,6 +747,32 @@ std::optional Frontend::getCheckResult(const ModuleName& name, bool return checkResult; } +std::vector Frontend::getRequiredScripts(const ModuleName& name) +{ + RequireTraceResult require = requireTrace[name]; + if (isDirty(name)) + { + std::optional source = fileResolver->readSource(name); + if (!source) + { + return {}; + } + const Config& config = configResolver->getConfig(name); + ParseOptions opts = config.parseOptions; + opts.captureComments = true; + SourceModule result = parse(name, source->source, opts); + result.type = source->type; + require = traceRequires(fileResolver, result.root, name); + } + std::vector requiredModuleNames; + requiredModuleNames.reserve(require.requireList.size()); + for (const auto& [moduleName, _] : require.requireList) + { + requiredModuleNames.push_back(moduleName); + } + return requiredModuleNames; +} + bool Frontend::parseGraph( std::vector& buildQueue, const ModuleName& root, @@ -860,6 +821,16 @@ bool Frontend::parseGraph( topseen = Permanent; buildQueue.push_back(top->name); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + // at this point we know all valid dependencies are processed into SourceNodes + for (const ModuleName& dep : top->requireSet) + { + if (auto it = sourceNodes.find(dep); it != sourceNodes.end()) + it->second->dependents.insert(top->name); + } + } } else { @@ -948,14 +919,11 @@ void Frontend::addBuildQueueItems( data.environmentScope = getModuleEnvironment(*sourceModule, data.config, frontendOptions.forAutocomplete); data.recordJsonLog = FFlag::DebugLuauLogSolverToJson; - Mode mode = sourceModule->mode.value_or(data.config.mode); - - // in NoCheck mode we only need to compute the value of .cyclic for typeck // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get(), mode == Mode::NoCheck); + data.requireCycles = getRequireCycles(fileResolver, sourceNodes, sourceNode.get()); data.options = frontendOptions; @@ -987,8 +955,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) else mode = sourceModule.mode.value_or(config.mode); - if (FFlag::LuauSourceModuleUpdatedWithSelectedMode) - item.sourceModule->mode = {mode}; + item.sourceModule->mode = {mode}; ScopePtr environmentScope = item.environmentScope; double timestamp = getTimestamp(); const std::vector& requireCycles = item.requireCycles; @@ -1093,6 +1060,11 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) freeze(module->interfaceTypes); module->internalTypes.clear(); + if (FFlag::LuauSelectivelyRetainDFGArena) + { + module->defArena.allocator.clear(); + module->keyArena.allocator.clear(); + } module->astTypes.clear(); module->astTypePacks.clear(); @@ -1146,15 +1118,49 @@ void Frontend::recordItemResult(const BuildQueueItem& item) if (item.exception) std::rethrow_exception(item.exception); - if (item.options.forAutocomplete) + if (FFlag::LuauBetterReverseDependencyTracking) { - moduleResolverForAutocomplete.setModule(item.name, item.module); - item.sourceNode->dirtyModuleForAutocomplete = false; + bool replacedModule = false; + if (item.options.forAutocomplete) + { + replacedModule = moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + replacedModule = moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } + + if (replacedModule) + { + LUAU_TIMETRACE_SCOPE("Frontend::invalidateDependentModules", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", item.name.c_str()); + traverseDependents( + item.name, + [forAutocomplete = item.options.forAutocomplete](SourceNode& sourceNode) + { + bool traverseSubtree = !sourceNode.hasInvalidModuleDependency(forAutocomplete); + sourceNode.setInvalidModuleDependency(true, forAutocomplete); + return traverseSubtree; + } + ); + } + + item.sourceNode->setInvalidModuleDependency(false, item.options.forAutocomplete); } else { - moduleResolver.setModule(item.name, item.module); - item.sourceNode->dirtyModule = false; + if (item.options.forAutocomplete) + { + moduleResolverForAutocomplete.setModule(item.name, item.module); + item.sourceNode->dirtyModuleForAutocomplete = false; + } + else + { + moduleResolver.setModule(item.name, item.module); + item.sourceNode->dirtyModule = false; + } } stats.timeCheck += item.stats.timeCheck; @@ -1191,6 +1197,13 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config return result; } +bool Frontend::allModuleDependenciesValid(const ModuleName& name, bool forAutocomplete) const +{ + LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking); + auto it = sourceNodes.find(name); + return it != sourceNodes.end() && !it->second->hasInvalidModuleDependency(forAutocomplete); +} + bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const { auto it = sourceNodes.find(name); @@ -1205,16 +1218,80 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { + LUAU_TIMETRACE_SCOPE("Frontend::markDirty", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + + if (FFlag::LuauBetterReverseDependencyTracking) + { + traverseDependents( + name, + [markedDirty](SourceNode& sourceNode) + { + if (markedDirty) + markedDirty->push_back(sourceNode.name); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + return false; + + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + + return true; + } + ); + } + else + { + if (sourceNodes.count(name) == 0) + return; + + std::unordered_map> reverseDeps; + for (const auto& module : sourceNodes) + { + for (const auto& dep : module.second->requireSet) + reverseDeps[dep].push_back(module.first); + } + + std::vector queue{name}; + + while (!queue.empty()) + { + ModuleName next = std::move(queue.back()); + queue.pop_back(); + + LUAU_ASSERT(sourceNodes.count(next) > 0); + SourceNode& sourceNode = *sourceNodes[next]; + + if (markedDirty) + markedDirty->push_back(next); + + if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + continue; + + sourceNode.dirtySourceModule = true; + sourceNode.dirtyModule = true; + sourceNode.dirtyModuleForAutocomplete = true; + + if (0 == reverseDeps.count(next)) + continue; + + sourceModules.erase(next); + + const std::vector& dependents = reverseDeps[next]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } + } +} + +void Frontend::traverseDependents(const ModuleName& name, std::function processSubtree) +{ + LUAU_ASSERT(FFlag::LuauBetterReverseDependencyTracking); + LUAU_TIMETRACE_SCOPE("Frontend::traverseDependents", "Frontend"); + if (sourceNodes.count(name) == 0) return; - std::unordered_map> reverseDeps; - for (const auto& module : sourceNodes) - { - for (const auto& dep : module.second->requireSet) - reverseDeps[dep].push_back(module.first); - } - std::vector queue{name}; while (!queue.empty()) @@ -1225,22 +1302,10 @@ void Frontend::markDirty(const ModuleName& name, std::vector* marked LUAU_ASSERT(sourceNodes.count(next) > 0); SourceNode& sourceNode = *sourceNodes[next]; - if (markedDirty) - markedDirty->push_back(next); - - if (sourceNode.dirtySourceModule && sourceNode.dirtyModule && sourceNode.dirtyModuleForAutocomplete) + if (!processSubtree(sourceNode)) continue; - sourceNode.dirtySourceModule = true; - sourceNode.dirtyModule = true; - sourceNode.dirtyModuleForAutocomplete = true; - - if (0 == reverseDeps.count(next)) - continue; - - sourceModules.erase(next); - - const std::vector& dependents = reverseDeps[next]; + const Set& dependents = sourceNode.dependents; queue.insert(queue.end(), dependents.begin(), dependents.end()); } } @@ -1357,11 +1422,15 @@ ModulePtr check( LUAU_TIMETRACE_ARGUMENT("name", sourceModule.humanReadableName.c_str()); ModulePtr result = std::make_shared(); + if (FFlag::LuauStoreSolverTypeOnModule) + result->checkedInNewSolver = true; result->name = sourceModule.name; result->humanReadableName = sourceModule.humanReadableName; result->mode = mode; result->internalTypes.owningModule = result.get(); result->interfaceTypes.owningModule = result.get(); + result->allocator = sourceModule.allocator; + result->names = sourceModule.names; iceHandler->moduleName = sourceModule.name; @@ -1376,17 +1445,23 @@ ModulePtr check( } } - DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, iceHandler); + DataFlowGraph dfg = DataFlowGraphBuilder::build(sourceModule.root, NotNull{&result->defArena}, NotNull{&result->keyArena}, iceHandler); UnifierSharedState unifierState{iceHandler}; unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = limits.unifierIterationLimit.value_or(FInt::LuauTypeInferIterationLimit); Normalizer normalizer{&result->internalTypes, builtinTypes, NotNull{&unifierState}}; + SimplifierPtr simplifier = newSimplifier(NotNull{&result->internalTypes}, builtinTypes); + TypeFunctionRuntime typeFunctionRuntime{iceHandler, NotNull{&limits}}; + + typeFunctionRuntime.allowEvaluation = sourceModule.parseErrors.empty(); ConstraintGenerator cg{ result, NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, moduleResolver, builtinTypes, iceHandler, @@ -1402,12 +1477,16 @@ ModulePtr check( ConstraintSolver cs{ NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, NotNull(cg.rootScope), borrowConstraints(cg.constraints), + NotNull{&cg.scopeToFunction}, result->name, moduleResolver, requireCycles, logger.get(), + NotNull{&dfg}, limits }; @@ -1461,12 +1540,31 @@ ModulePtr check( switch (mode) { case Mode::Nonstrict: - Luau::checkNonStrict(builtinTypes, iceHandler, NotNull{&unifierState}, NotNull{&dfg}, NotNull{&limits}, sourceModule, result.get()); + Luau::checkNonStrict( + builtinTypes, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + iceHandler, + NotNull{&unifierState}, + NotNull{&dfg}, + NotNull{&limits}, + sourceModule, + result.get() + ); break; case Mode::Definition: // fallthrough intentional case Mode::Strict: - Luau::check(builtinTypes, NotNull{&unifierState}, NotNull{&limits}, logger.get(), sourceModule, result.get()); + Luau::check( + builtinTypes, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{&unifierState}, + NotNull{&limits}, + logger.get(), + sourceModule, + result.get() + ); break; case Mode::NoCheck: break; @@ -1647,6 +1745,17 @@ std::pair Frontend::getSourceNode(const ModuleName& sourceNode->name = sourceModule->name; sourceNode->humanReadableName = sourceModule->humanReadableName; + + if (FFlag::LuauBetterReverseDependencyTracking) + { + // clear all prior dependents. we will re-add them after parsing the rest of the graph + for (const auto& [moduleName, _] : sourceNode->requireLocations) + { + if (auto depIt = sourceNodes.find(moduleName); depIt != sourceNodes.end()) + depIt->second->dependents.erase(sourceNode->name); + } + } + sourceNode->requireSet.clear(); sourceNode->requireLocations.clear(); sourceNode->dirtySourceModule = false; @@ -1768,11 +1877,21 @@ std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& return frontend->fileResolver->getHumanReadableModuleName(moduleName); } -void FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) +bool FrontendModuleResolver::setModule(const ModuleName& moduleName, ModulePtr module) { std::scoped_lock lock(moduleMutex); - modules[moduleName] = std::move(module); + if (FFlag::LuauBetterReverseDependencyTracking) + { + bool replaced = modules.count(moduleName) > 0; + modules[moduleName] = std::move(module); + return replaced; + } + else + { + modules[moduleName] = std::move(module); + return false; + } } void FrontendModuleResolver::clearModules() diff --git a/Analysis/src/Generalization.cpp b/Analysis/src/Generalization.cpp index d209cb81..054ad509 100644 --- a/Analysis/src/Generalization.cpp +++ b/Analysis/src/Generalization.cpp @@ -2,6 +2,8 @@ #include "Luau/Generalization.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" #include "Luau/Scope.h" #include "Luau/Type.h" #include "Luau/ToString.h" @@ -9,11 +11,15 @@ #include "Luau/TypePack.h" #include "Luau/VisitType.h" +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAGVARIABLE(LuauGeneralizationRemoveRecursiveUpperBound2) + namespace Luau { struct MutatingGeneralizer : TypeOnceVisitor { + NotNull arena; NotNull builtinTypes; NotNull scope; @@ -27,6 +33,7 @@ struct MutatingGeneralizer : TypeOnceVisitor bool avoidSealingTables = false; MutatingGeneralizer( + NotNull arena, NotNull builtinTypes, NotNull scope, NotNull> cachedTypes, @@ -35,6 +42,7 @@ struct MutatingGeneralizer : TypeOnceVisitor bool avoidSealingTables ) : TypeOnceVisitor(/* skipBoundTypes */ true) + , arena(arena) , builtinTypes(builtinTypes) , scope(scope) , cachedTypes(cachedTypes) @@ -44,7 +52,7 @@ struct MutatingGeneralizer : TypeOnceVisitor { } - static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) + void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) { haystack = follow(haystack); @@ -91,6 +99,10 @@ struct MutatingGeneralizer : TypeOnceVisitor LUAU_ASSERT(onlyType != haystack); emplaceType(asMutable(haystack), onlyType); } + else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && ut->options.empty()) + { + emplaceType(asMutable(haystack), builtinTypes->neverType); + } return; } @@ -133,6 +145,10 @@ struct MutatingGeneralizer : TypeOnceVisitor TypeId onlyType = it->parts[0]; LUAU_ASSERT(onlyType != needle); emplaceType(asMutable(needle), onlyType); + } + else if (FFlag::LuauGeneralizationRemoveRecursiveUpperBound2 && it->parts.empty()) + { + emplaceType(asMutable(needle), builtinTypes->unknownType); } return; @@ -445,7 +461,7 @@ struct FreeTypeSearcher : TypeVisitor traverse(*prop.readTy); else { - LUAU_ASSERT(prop.isShared()); + LUAU_ASSERT(prop.isShared() || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); Polarity p = polarity; polarity = Both; @@ -526,7 +542,7 @@ struct TypeCacher : TypeOnceVisitor DenseHashSet uncacheablePacks{nullptr}; explicit TypeCacher(NotNull> cachedTypes) - : TypeOnceVisitor(/* skipBoundTypes */ true) + : TypeOnceVisitor(/* skipBoundTypes */ false) , cachedTypes(cachedTypes) { } @@ -563,9 +579,19 @@ struct TypeCacher : TypeOnceVisitor bool visit(TypeId ty) override { - if (isUncacheable(ty) || isCached(ty)) - return false; - return true; + // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs, + // otherwise it's prone to marking types that cannot be cached as + // cacheable. + LUAU_ASSERT(false); + LUAU_UNREACHABLE(); + } + + bool visit(TypeId ty, const BoundType& btv) override + { + traverse(btv.boundTo); + if (isUncacheable(btv.boundTo)) + markUncacheable(ty); + return false; } bool visit(TypeId ty, const FreeType& ft) override @@ -590,6 +616,12 @@ struct TypeCacher : TypeOnceVisitor return false; } + bool visit(TypeId ty, const ErrorType&) override + { + cache(ty); + return false; + } + bool visit(TypeId ty, const PrimitiveType&) override { cache(ty); @@ -727,6 +759,17 @@ struct TypeCacher : TypeOnceVisitor return false; } + bool visit(TypeId ty, const MetatableType& mtv) override + { + traverse(mtv.table); + traverse(mtv.metatable); + if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable)) + markUncacheable(ty); + else + cache(ty); + return false; + } + bool visit(TypeId ty, const ClassType&) override { cache(ty); @@ -739,6 +782,12 @@ struct TypeCacher : TypeOnceVisitor return false; } + bool visit(TypeId ty, const NoRefineType&) override + { + cache(ty); + return false; + } + bool visit(TypeId ty, const UnionType& ut) override { if (isUncacheable(ty) || isCached(ty)) @@ -841,12 +890,31 @@ struct TypeCacher : TypeOnceVisitor return false; } + bool visit(TypePackId tp) override + { + // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs, + // otherwise it's prone to marking types that cannot be cached as + // cacheable, which will segfault down the line. + LUAU_ASSERT(false); + LUAU_UNREACHABLE(); + } + bool visit(TypePackId tp, const FreeTypePack&) override { markUncacheable(tp); return false; } + bool visit(TypePackId tp, const GenericTypePack& gtp) override + { + return true; + } + + bool visit(TypePackId tp, const ErrorTypePack& etp) override + { + return true; + } + bool visit(TypePackId tp, const VariadicTypePack& vtp) override { if (isUncacheable(tp)) @@ -871,6 +939,32 @@ struct TypeCacher : TypeOnceVisitor markUncacheable(tp); return false; } + + bool visit(TypePackId tp, const BoundTypePack& btp) override + { + traverse(btp.boundTo); + if (isUncacheable(btp.boundTo)) + markUncacheable(tp); + return false; + } + + bool visit(TypePackId tp, const TypePack& typ) override + { + bool uncacheable = false; + for (TypeId ty : typ.head) + { + traverse(ty); + uncacheable |= isUncacheable(ty); + } + if (typ.tail) + { + traverse(*typ.tail); + uncacheable |= isUncacheable(*typ.tail); + } + if (uncacheable) + markUncacheable(tp); + return false; + } }; std::optional generalize( @@ -890,7 +984,7 @@ std::optional generalize( FreeTypeSearcher fts{scope, cachedTypes}; fts.traverse(ty); - MutatingGeneralizer gen{builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; + MutatingGeneralizer gen{arena, builtinTypes, scope, cachedTypes, std::move(fts.positiveTypes), std::move(fts.negativeTypes), avoidSealingTables}; gen.traverse(ty); diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 4b6d1115..79b7f03e 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -61,9 +62,7 @@ TypeId Instantiation::clean(TypeId ty) LUAU_ASSERT(ftv); FunctionType clone = FunctionType{level, scope, ftv->argTypes, ftv->retTypes, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.dcrMagicFunction = ftv->dcrMagicFunction; - clone.dcrMagicRefinement = ftv->dcrMagicRefinement; + clone.magic = ftv->magic; clone.tags = ftv->tags; clone.argNames = ftv->argNames; TypeId result = addType(std::move(clone)); @@ -165,7 +164,7 @@ TypeId ReplaceGenerics::clean(TypeId ty) } else { - return addType(FreeType{scope, level}); + return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, scope, level) : addType(FreeType{scope, level}); } } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index a3d8b4e3..64e05993 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -227,6 +227,8 @@ static void errorToString(std::ostream& stream, const T& err) stream << "UnexpectedTypeInSubtyping { ty = '" + toString(err.ty) + "' }"; else if constexpr (std::is_same_v) stream << "UnexpectedTypePackInSubtyping { tp = '" + toString(err.tp) + "' }"; + else if constexpr (std::is_same_v) + stream << "UserDefinedTypeFunctionError { " << err.message << " }"; else if constexpr (std::is_same_v) { stream << "CannotAssignToNever { rvalueType = '" << toString(err.rhsType) << "', reason = '" << err.reason << "', cause = { "; diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index c4f46c84..a2bcb247 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -17,8 +17,7 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauAttribute) -LUAU_FASTFLAG(LuauNativeAttribute) -LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute, false) +LUAU_FASTFLAGVARIABLE(LintRedundantNativeAttribute) namespace Luau { @@ -3239,7 +3238,6 @@ static void lintComments(LintContext& context, const std::vector& ho static bool hasNativeCommentDirective(const std::vector& hotcomments) { - LUAU_ASSERT(FFlag::LuauNativeAttribute); LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); for (const HotComment& hc : hotcomments) @@ -3265,7 +3263,6 @@ struct LintRedundantNativeAttribute : AstVisitor public: LUAU_NOINLINE static void process(LintContext& context) { - LUAU_ASSERT(FFlag::LuauNativeAttribute); LUAU_ASSERT(FFlag::LintRedundantNativeAttribute); LintRedundantNativeAttribute pass; @@ -3389,7 +3386,7 @@ std::vector lint( if (context.warningEnabled(LintWarning::Code_ComparisonPrecedence)) LintComparisonPrecedence::process(context); - if (FFlag::LuauNativeAttribute && FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) + if (FFlag::LintRedundantNativeAttribute && context.warningEnabled(LintWarning::Code_RedundantNativeAttribute)) { if (hasNativeCommentDirective(hotcomments)) LintRedundantNativeAttribute::process(context); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 3a049216..1dbd6608 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,11 +15,32 @@ #include LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAGVARIABLE(LuauIncrementalAutocompleteCommentDetection) namespace Luau { -static bool contains(Position pos, Comment comment) +static void defaultLogLuau(std::string_view input) +{ + // The default is to do nothing because we don't want to mess with + // the xml parsing done by the dcr script. +} + +Luau::LogLuauProc logLuau = &defaultLogLuau; + +void setLogLuau(LogLuauProc ll) +{ + logLuau = ll; +} + +void resetLogLuauProc() +{ + logLuau = &defaultLogLuau; +} + + + +static bool contains_DEPRECATED(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; @@ -32,7 +53,22 @@ static bool contains(Position pos, Comment comment) return false; } -static bool isWithinComment(const std::vector& commentLocations, Position pos) +static bool contains(Position pos, Comment comment) +{ + if (comment.location.contains(pos)) + return true; + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't + // have an end + return true; + // comments actually span the whole line - in incremental mode, we could pass a cursor outside of the current parsed comment range span, but it + // would still be 'within' the comment So, the cursor must be on the same line and the comment itself must come strictly after the `begin` + else if (comment.type == Lexeme::Comment && comment.location.end.line == pos.line && comment.location.begin <= pos) + return true; + else + return false; +} + +bool isWithinComment(const std::vector& commentLocations, Position pos) { auto iter = std::lower_bound( commentLocations.begin(), @@ -40,6 +76,11 @@ static bool isWithinComment(const std::vector& commentLocations, Positi Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { + if (FFlag::LuauIncrementalAutocompleteCommentDetection) + { + if (a.type == Lexeme::Comment) + return a.location.end.line < b.location.end.line; + } return a.location.end < b.location.end; } ); @@ -47,7 +88,7 @@ static bool isWithinComment(const std::vector& commentLocations, Positi if (iter == commentLocations.end()) return false; - if (contains(pos, *iter)) + if (FFlag::LuauIncrementalAutocompleteCommentDetection ? contains(pos, *iter) : contains_DEPRECATED(pos, *iter)) return true; // Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends @@ -131,10 +172,32 @@ struct ClonePublicInterface : Substitution } ftv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2) + ftv->scope = nullptr; } else if (TableType* ttv = getMutable(result)) { ttv->level = TypeLevel{0, 0}; + if (FFlag::LuauSolverV2) + ttv->scope = nullptr; + } + + if (FFlag::LuauSolverV2) + { + if (auto freety = getMutable(result)) + { + module->errors.emplace_back( + freety->scope->location, + module->name, + InternalError{"Free type is escaping its module; please report this bug at " + "https://github.com/luau-lang/luau/issues"} + ); + result = builtinTypes->errorRecoveryType(); + } + else if (auto genericty = getMutable(result)) + { + genericty->scope = nullptr; + } } return result; @@ -142,7 +205,27 @@ struct ClonePublicInterface : Substitution TypePackId clean(TypePackId tp) override { - return clone(tp); + if (FFlag::LuauSolverV2) + { + auto clonedTp = clone(tp); + if (auto ftp = getMutable(clonedTp)) + { + module->errors.emplace_back( + ftp->scope->location, + module->name, + InternalError{"Free type pack is escaping its module; please report this bug at " + "https://github.com/luau-lang/luau/issues"} + ); + clonedTp = builtinTypes->errorRecoveryTypePack(); + } + else if (auto gtp = getMutable(clonedTp)) + gtp->scope = nullptr; + return clonedTp; + } + else + { + return clone(tp); + } } TypeId cloneType(TypeId ty) diff --git a/Analysis/src/NonStrictTypeChecker.cpp b/Analysis/src/NonStrictTypeChecker.cpp index 116cf5cb..93a02c3f 100644 --- a/Analysis/src/NonStrictTypeChecker.cpp +++ b/Analysis/src/NonStrictTypeChecker.cpp @@ -14,11 +14,15 @@ #include "Luau/TypeFunction.h" #include "Luau/Def.h" #include "Luau/ToString.h" -#include "Luau/TypeFwd.h" +#include "Luau/TypeUtils.h" #include #include +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) +LUAU_FASTFLAGVARIABLE(LuauNonStrictVisitorImprovements) +LUAU_FASTFLAGVARIABLE(LuauNewNonStrictWarnOnUnknownGlobals) + namespace Luau { @@ -154,8 +158,9 @@ private: struct NonStrictTypeChecker { - NotNull builtinTypes; + NotNull simplifier; + NotNull typeFunctionRuntime; const NotNull ice; NotNull arena; Module* module; @@ -171,6 +176,8 @@ struct NonStrictTypeChecker NonStrictTypeChecker( NotNull arena, NotNull builtinTypes, + NotNull simplifier, + NotNull typeFunctionRuntime, const NotNull ice, NotNull unifierState, NotNull dfg, @@ -178,11 +185,13 @@ struct NonStrictTypeChecker Module* module ) : builtinTypes(builtinTypes) + , simplifier(simplifier) + , typeFunctionRuntime(typeFunctionRuntime) , ice(ice) , arena(arena) , module(module) , normalizer{arena, builtinTypes, unifierState, /* cache inhabitance */ true} - , subtyping{builtinTypes, arena, NotNull(&normalizer), ice} + , subtyping{builtinTypes, arena, simplifier, NotNull(&normalizer), typeFunctionRuntime, ice} , dfg(dfg) , limits(limits) { @@ -204,7 +213,7 @@ struct NonStrictTypeChecker return *fst; else if (auto ftp = get(pack)) { - TypeId result = arena->addType(FreeType{ftp->scope}); + TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtinTypes, ftp->scope) : arena->addType(FreeType{ftp->scope}); TypePackId freeTail = arena->addTypePack(FreeTypePack{ftp->scope}); TypePack* resultPack = emplaceTypePack(asMutable(pack)); @@ -213,7 +222,7 @@ struct NonStrictTypeChecker return result; } - else if (get(pack)) + else if (get(pack)) return builtinTypes->errorRecoveryType(); else if (finite(pack) && size(pack) == 0) return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` @@ -228,7 +237,12 @@ struct NonStrictTypeChecker return instance; ErrorVec errors = - reduceTypeFunctions(instance, location, TypeFunctionContext{arena, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, true) + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{arena, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits}, + true + ) .errors; if (errors.empty()) @@ -329,8 +343,9 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatIf* ifStatement) { - NonStrictContext condB = visit(ifStatement->condition); + NonStrictContext condB = visit(ifStatement->condition, ValueContext::RValue); NonStrictContext branchContext; + // If there is no else branch, don't bother generating warnings for the then branch - we can't prove there is an error if (ifStatement->elsebody) { @@ -338,17 +353,32 @@ struct NonStrictTypeChecker NonStrictContext elseBody = visit(ifStatement->elsebody); branchContext = NonStrictContext::conjunction(builtinTypes, arena, thenBody, elseBody); } + return NonStrictContext::disjunction(builtinTypes, arena, condB, branchContext); } NonStrictContext visit(AstStatWhile* whileStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext condition = visit(whileStatement->condition, ValueContext::RValue); + NonStrictContext body = visit(whileStatement->body); + return NonStrictContext::disjunction(builtinTypes, arena, condition, body); + } + else + return {}; } NonStrictContext visit(AstStatRepeat* repeatStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext body = visit(repeatStatement->body); + NonStrictContext condition = visit(repeatStatement->condition, ValueContext::RValue); + return NonStrictContext::disjunction(builtinTypes, arena, body, condition); + } + else + return {}; } NonStrictContext visit(AstStatBreak* breakStatement) @@ -363,49 +393,94 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatReturn* returnStatement) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + // TODO: this is believing existing code, but i'm not sure if this makes sense + // for how the contexts are handled + for (AstExpr* expr : returnStatement->list) + visit(expr, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstStatExpr* expr) { - return visit(expr->expr); + return visit(expr->expr, ValueContext::RValue); } NonStrictContext visit(AstStatLocal* local) { for (AstExpr* rhs : local->values) - visit(rhs); + visit(rhs, ValueContext::RValue); return {}; } NonStrictContext visit(AstStatFor* forStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + // TODO: throwing out context based on same principle as existing code? + if (forStatement->from) + visit(forStatement->from, ValueContext::RValue); + if (forStatement->to) + visit(forStatement->to, ValueContext::RValue); + if (forStatement->step) + visit(forStatement->step, ValueContext::RValue); + return visit(forStatement->body); + } + else + { + return {}; + } } NonStrictContext visit(AstStatForIn* forInStatement) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* rhs : forInStatement->values) + visit(rhs, ValueContext::RValue); + return visit(forInStatement->body); + } + else + { + return {}; + } } NonStrictContext visit(AstStatAssign* assign) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* lhs : assign->vars) + visit(lhs, ValueContext::LValue); + for (AstExpr* rhs : assign->values) + visit(rhs, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstStatCompoundAssign* compoundAssign) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + visit(compoundAssign->var, ValueContext::LValue); + visit(compoundAssign->value, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstStatFunction* statFn) { - return visit(statFn->func); + return visit(statFn->func, ValueContext::RValue); } NonStrictContext visit(AstStatLocalFunction* localFn) { - return visit(localFn->func); + return visit(localFn->func, ValueContext::RValue); } NonStrictContext visit(AstStatTypeAlias* typeAlias) @@ -415,7 +490,6 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatTypeFunction* typeFunc) { - reportError(GenericError{"This syntax is not supported"}, typeFunc->location); return {}; } @@ -436,14 +510,22 @@ struct NonStrictTypeChecker NonStrictContext visit(AstStatError* error) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstStat* stat : error->statements) + visit(stat); + for (AstExpr* expr : error->expressions) + visit(expr, ValueContext::RValue); + } + return {}; } - NonStrictContext visit(AstExpr* expr) + NonStrictContext visit(AstExpr* expr, ValueContext context) { auto pusher = pushStack(expr); if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -453,17 +535,17 @@ struct NonStrictTypeChecker else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) - return visit(e); + return visit(e, context); else if (auto e = expr->as()) return visit(e); else if (auto e = expr->as()) @@ -487,9 +569,12 @@ struct NonStrictTypeChecker } } - NonStrictContext visit(AstExprGroup* group) + NonStrictContext visit(AstExprGroup* group, ValueContext context) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(group->expr, context); + else + return {}; } NonStrictContext visit(AstExprConstantNil* expr) @@ -512,22 +597,34 @@ struct NonStrictTypeChecker return {}; } - NonStrictContext visit(AstExprLocal* local) + NonStrictContext visit(AstExprLocal* local, ValueContext context) { return {}; } - NonStrictContext visit(AstExprGlobal* global) + NonStrictContext visit(AstExprGlobal* global, ValueContext context) { + if (FFlag::LuauNewNonStrictWarnOnUnknownGlobals) + { + // We don't file unknown symbols for LValues. + if (context == ValueContext::LValue) + return {}; + + NotNull scope = stack.back(); + if (!scope->lookup(global->name)) + { + reportError(UnknownSymbol{global->name.value, UnknownSymbol::Binding}, global->location); + } + } + return {}; } - NonStrictContext visit(AstExprVarargs* global) + NonStrictContext visit(AstExprVarargs* varargs) { return {}; } - NonStrictContext visit(AstExprCall* call) { NonStrictContext fresh{}; @@ -536,106 +633,126 @@ struct NonStrictTypeChecker return fresh; TypeId fnTy = *originalCallTy; - if (auto fn = get(follow(fnTy))) + if (auto fn = get(follow(fnTy)); fn && fn->isCheckedFunction) { - if (fn->isCheckedFunction) + // We know fn is a checked function, which means it looks like: + // (S1, ... SN) -> T & + // (~S1, unknown^N-1) -> error & + // (unknown, ~S2, unknown^N-2) -> error + // ... + // ... + // (unknown^N-1, ~S_N) -> error + + std::vector arguments; + arguments.reserve(call->args.size + (call->self ? 1 : 0)); + if (call->self) { - // We know fn is a checked function, which means it looks like: - // (S1, ... SN) -> T & - // (~S1, unknown^N-1) -> error & - // (unknown, ~S2, unknown^N-2) -> error - // ... - // ... - // (unknown^N-1, ~S_N) -> error - std::vector argTypes; - argTypes.reserve(call->args.size); - // Pad out the arg types array with the types you would expect to see - TypePackIterator curr = begin(fn->argTypes); - TypePackIterator fin = end(fn->argTypes); - while (curr != fin) + if (auto indexExpr = call->func->as()) + arguments.push_back(indexExpr->expr); + else + ice->ice("method call expression has no 'self'"); + } + arguments.insert(arguments.end(), call->args.begin(), call->args.end()); + + std::vector argTypes; + argTypes.reserve(arguments.size()); + + // Move all the types over from the argument typepack for `fn` + TypePackIterator curr = begin(fn->argTypes); + TypePackIterator fin = end(fn->argTypes); + for (; curr != fin; curr++) + argTypes.push_back(*curr); + + // Pad out the rest with the variadic as needed. + if (auto argTail = curr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*argTail))) { - argTypes.push_back(*curr); - ++curr; - } - if (auto argTail = curr.tail()) - { - if (const VariadicTypePack* vtp = get(follow(*argTail))) + while (argTypes.size() < arguments.size()) { - while (argTypes.size() < call->args.size) - { - argTypes.push_back(vtp->ty); - } + argTypes.push_back(vtp->ty); } } + } - std::string functionName = getFunctionNameAsString(*call->func).value_or(""); - if (call->args.size > argTypes.size()) + std::string functionName = getFunctionNameAsString(*call->func).value_or(""); + if (arguments.size() > argTypes.size()) + { + // We are passing more arguments than we expect, so we should error + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location); + return fresh; + } + + for (size_t i = 0; i < arguments.size(); i++) + { + // For example, if the arg is "hi" + // The actual arg type is string + // The expected arg type is number + // The type of the argument in the overload is ~number + // We will compare arg and ~number + AstExpr* arg = arguments[i]; + TypeId expectedArgType = argTypes[i]; + std::shared_ptr norm = normalizer.normalize(expectedArgType); + DefId def = dfg->getDef(arg); + TypeId runTimeErrorTy; + // If we're dealing with any, negating any will cause all subtype tests to fail + // However, when someone calls this function, they're going to want to be able to pass it anything, + // for that reason, we manually inject never into the context so that the runtime test will always pass. + if (!norm) + reportError(NormalizationTooComplex{}, arg->location); + + if (norm && get(norm->tops)) + runTimeErrorTy = builtinTypes->neverType; + else + runTimeErrorTy = getOrCreateNegation(expectedArgType); + fresh.addContext(def, runTimeErrorTy); + } + + // Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types + for (size_t i = 0; i < arguments.size(); i++) + { + AstExpr* arg = arguments[i]; + if (auto runTimeFailureType = willRunTimeError(arg, fresh)) + reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); + } + + if (arguments.size() < argTypes.size()) + { + // We are passing fewer arguments than we expect + // so we need to ensure that the rest of the args are optional. + bool remainingArgsOptional = true; + for (size_t i = arguments.size(); i < argTypes.size(); i++) + remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]); + + if (!remainingArgsOptional) { - // We are passing more arguments than we expect, so we should error - reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); + reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), arguments.size()}, call->location); return fresh; } - - for (size_t i = 0; i < call->args.size; i++) - { - // For example, if the arg is "hi" - // The actual arg type is string - // The expected arg type is number - // The type of the argument in the overload is ~number - // We will compare arg and ~number - AstExpr* arg = call->args.data[i]; - TypeId expectedArgType = argTypes[i]; - std::shared_ptr norm = normalizer.normalize(expectedArgType); - DefId def = dfg->getDef(arg); - TypeId runTimeErrorTy; - // If we're dealing with any, negating any will cause all subtype tests to fail, since ~any is any - // However, when someone calls this function, they're going to want to be able to pass it anything, - // for that reason, we manually inject never into the context so that the runtime test will always pass. - if (!norm) - reportError(NormalizationTooComplex{}, arg->location); - - if (norm && get(norm->tops)) - runTimeErrorTy = builtinTypes->neverType; - else - runTimeErrorTy = getOrCreateNegation(expectedArgType); - fresh.addContext(def, runTimeErrorTy); - } - - // Populate the context and now iterate through each of the arguments to the call to find out if we satisfy the types - for (size_t i = 0; i < call->args.size; i++) - { - AstExpr* arg = call->args.data[i]; - if (auto runTimeFailureType = willRunTimeError(arg, fresh)) - reportError(CheckedFunctionCallError{argTypes[i], *runTimeFailureType, functionName, i}, arg->location); - } - - if (call->args.size < argTypes.size()) - { - // We are passing fewer arguments than we expect - // so we need to ensure that the rest of the args are optional. - bool remainingArgsOptional = true; - for (size_t i = call->args.size; i < argTypes.size(); i++) - remainingArgsOptional = remainingArgsOptional && isOptional(argTypes[i]); - if (!remainingArgsOptional) - { - reportError(CheckedFunctionIncorrectArgs{functionName, argTypes.size(), call->args.size}, call->location); - return fresh; - } - } } } return fresh; } - NonStrictContext visit(AstExprIndexName* indexName) + NonStrictContext visit(AstExprIndexName* indexName, ValueContext context) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(indexName->expr, context); + else + return {}; } - NonStrictContext visit(AstExprIndexExpr* indexExpr) + NonStrictContext visit(AstExprIndexExpr* indexExpr, ValueContext context) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext expr = visit(indexExpr->expr, context); + NonStrictContext index = visit(indexExpr->index, ValueContext::RValue); + return NonStrictContext::disjunction(builtinTypes, arena, expr, index); + } + else + return {}; } NonStrictContext visit(AstExprFunction* exprFn) @@ -654,39 +771,74 @@ struct NonStrictTypeChecker NonStrictContext visit(AstExprTable* table) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (auto [_, key, value] : table->items) + { + if (key) + visit(key, ValueContext::RValue); + visit(value, ValueContext::RValue); + } + } + return {}; } NonStrictContext visit(AstExprUnary* unary) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(unary->expr, ValueContext::RValue); + else + return {}; } NonStrictContext visit(AstExprBinary* binary) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + { + NonStrictContext lhs = visit(binary->left, ValueContext::RValue); + NonStrictContext rhs = visit(binary->right, ValueContext::RValue); + return NonStrictContext::disjunction(builtinTypes, arena, lhs, rhs); + } + else + return {}; } NonStrictContext visit(AstExprTypeAssertion* typeAssertion) { - return {}; + if (FFlag::LuauNonStrictVisitorImprovements) + return visit(typeAssertion->expr, ValueContext::RValue); + else + return {}; } NonStrictContext visit(AstExprIfElse* ifElse) { - NonStrictContext condB = visit(ifElse->condition); - NonStrictContext thenB = visit(ifElse->trueExpr); - NonStrictContext elseB = visit(ifElse->falseExpr); + NonStrictContext condB = visit(ifElse->condition, ValueContext::RValue); + NonStrictContext thenB = visit(ifElse->trueExpr, ValueContext::RValue); + NonStrictContext elseB = visit(ifElse->falseExpr, ValueContext::RValue); return NonStrictContext::disjunction(builtinTypes, arena, condB, NonStrictContext::conjunction(builtinTypes, arena, thenB, elseB)); } NonStrictContext visit(AstExprInterpString* interpString) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* expr : interpString->expressions) + visit(expr, ValueContext::RValue); + } + return {}; } NonStrictContext visit(AstExprError* error) { + if (FFlag::LuauNonStrictVisitorImprovements) + { + for (AstExpr* expr : error->expressions) + visit(expr, ValueContext::RValue); + } + return {}; } @@ -754,6 +906,8 @@ private: void checkNonStrict( NotNull builtinTypes, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull ice, NotNull unifierState, NotNull dfg, @@ -764,7 +918,9 @@ void checkNonStrict( { LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking"); - NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, ice, unifierState, dfg, limits, module}; + NonStrictTypeChecker typeChecker{ + NotNull{&module->internalTypes}, builtinTypes, simplifier, typeFunctionRuntime, ice, unifierState, dfg, limits, module + }; typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); copyErrors(module->errors, module->interfaceTypes, builtinTypes); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index 2db2f40c..864c12a8 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -15,36 +15,17 @@ #include "Luau/TypeFwd.h" #include "Luau/Unifier.h" -LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant, false) -LUAU_FASTFLAGVARIABLE(LuauNormalizeAwayUninhabitableTables, false) -LUAU_FASTFLAGVARIABLE(LuauNormalizeNotUnknownIntersection, false); -LUAU_FASTFLAGVARIABLE(LuauFixReduceStackPressure, false); -LUAU_FASTFLAGVARIABLE(LuauFixCyclicTablesBlowingStack, false); +LUAU_FASTFLAGVARIABLE(DebugLuauCheckNormalizeInvariant) -// This could theoretically be 2000 on amd64, but x86 requires this. -LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200); -LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000); -LUAU_FASTFLAG(LuauSolverV2); - -static bool fixReduceStackPressure() -{ - return FFlag::LuauFixReduceStackPressure || FFlag::LuauSolverV2; -} - -static bool fixCyclicTablesBlowingStack() -{ - return FFlag::LuauFixCyclicTablesBlowingStack || FFlag::LuauSolverV2; -} +LUAU_FASTINTVARIABLE(LuauNormalizeCacheLimit, 100000) +LUAU_FASTINTVARIABLE(LuauNormalizeIntersectionLimit, 200) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauFixInfiniteRecursionInNormalization) +LUAU_FASTFLAGVARIABLE(LuauFixNormalizedIntersectionOfNegatedClass) namespace Luau { -// helper to make `FFlag::LuauNormalizeAwayUninhabitableTables` not explicitly required when DCR is enabled. -static bool normalizeAwayUninhabitableTables() -{ - return FFlag::LuauNormalizeAwayUninhabitableTables || FFlag::LuauSolverV2; -} - static bool shouldEarlyExit(NormalizationResult res) { // if res is hit limits, return control flow @@ -589,10 +570,11 @@ NormalizationResult Normalizer::isInhabited(TypeId ty, Set& seen) NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right) { Set seen{nullptr}; - return isIntersectionInhabited(left, right, seen); + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + return isIntersectionInhabited(left, right, seenTablePropPairs, seen); } -NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, Set& seenSet) +NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId right, SeenTablePropPairs& seenTablePropPairs, Set& seenSet) { left = follow(left); right = follow(right); @@ -605,7 +587,7 @@ NormalizationResult Normalizer::isIntersectionInhabited(TypeId left, TypeId righ } NormalizedType norm{builtinTypes}; - NormalizationResult res = normalizeIntersections({left, right}, norm, seenSet); + NormalizationResult res = normalizeIntersections({left, right}, norm, seenTablePropPairs, seenSet); if (res != NormalizationResult::True) { if (cacheInhabitance && res == NormalizationResult::False) @@ -956,7 +938,8 @@ std::shared_ptr Normalizer::normalize(TypeId ty) NormalizedType norm{builtinTypes}; Set seenSetTypes{nullptr}; - NormalizationResult res = unionNormalWithTy(norm, ty, seenSetTypes); + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + NormalizationResult res = unionNormalWithTy(norm, ty, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return nullptr; @@ -974,7 +957,12 @@ std::shared_ptr Normalizer::normalize(TypeId ty) return shared; } -NormalizationResult Normalizer::normalizeIntersections(const std::vector& intersections, NormalizedType& outType, Set& seenSet) +NormalizationResult Normalizer::normalizeIntersections( + const std::vector& intersections, + NormalizedType& outType, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSet +) { if (!arena) sharedState->iceHandler->ice("Normalizing types outside a module"); @@ -983,7 +971,7 @@ NormalizationResult Normalizer::normalizeIntersections(const std::vector // Now we need to intersect the two types for (auto ty : intersections) { - NormalizationResult res = intersectNormalWithTy(norm, ty, seenSet); + NormalizationResult res = intersectNormalWithTy(norm, ty, seenTablePropPairs, seenSet); if (res != NormalizationResult::True) return res; } @@ -1620,7 +1608,7 @@ void Normalizer::unionTablesWithTable(TypeIds& heres, TypeId there) // TODO: remove unions of tables where possible // we can always skip `never` - if (normalizeAwayUninhabitableTables() && get(there)) + if (get(there)) return; heres.insert(there); @@ -1747,7 +1735,13 @@ NormalizationResult Normalizer::intersectNormalWithNegationTy(TypeId toNegate, N } // See above for an explaination of `ignoreSmallerTyvars`. -NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes, int ignoreSmallerTyvars) +NormalizationResult Normalizer::unionNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes, + int ignoreSmallerTyvars +) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) @@ -1779,7 +1773,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) { - NormalizationResult res = unionNormalWithTy(here, *it, seenSetTypes); + NormalizationResult res = unionNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) { seenSetTypes.erase(there); @@ -1800,7 +1794,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t norm.tops = builtinTypes->anyType; for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) { - NormalizationResult res = intersectNormalWithTy(norm, *it, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) { seenSetTypes.erase(there); @@ -1814,7 +1808,8 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t } else if (get(here.tops)) return NormalizationResult::True; - else if (get(there) || get(there) || get(there) || get(there) || get(there)) + else if (get(there) || get(there) || get(there) || get(there) || + get(there)) { if (tyvarIndex(there) <= ignoreSmallerTyvars) return NormalizationResult::True; @@ -1891,7 +1886,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t if (res != NormalizationResult::True) return res; } - else if (get(there) || get(there)) + else if (get(there) || get(there) || get(there)) { // nothing } @@ -1900,7 +1895,7 @@ NormalizationResult Normalizer::unionNormalWithTy(NormalizedType& here, TypeId t for (auto& [tyvar, intersect] : here.tyvars) { - NormalizationResult res = unionNormalWithTy(*intersect, there, seenSetTypes, tyvarIndex(tyvar)); + NormalizationResult res = unionNormalWithTy(*intersect, there, seenTablePropPairs, seenSetTypes, tyvarIndex(tyvar)); if (res != NormalizationResult::True) return res; } @@ -2289,9 +2284,24 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th else if (isSubclass(there, hereTy)) { TypeIds negations = std::move(hereNegations); + bool emptyIntersectWithNegation = false; for (auto nIt = negations.begin(); nIt != negations.end();) { + if (FFlag::LuauFixNormalizedIntersectionOfNegatedClass && isSubclass(there, *nIt)) + { + // Hitting this block means that the incoming class is a + // subclass of this type, _and_ one of its negations is a + // superclass of this type, e.g.: + // + // Dog & ~Animal + // + // Clearly this intersects to never, so we mark this class as + // being removed from the normalized class type. + emptyIntersectWithNegation = true; + break; + } + if (!isSubclass(*nIt, there)) { nIt = negations.erase(nIt); @@ -2304,7 +2314,8 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th it = heres.ordering.erase(it); heres.classes.erase(hereTy); - heres.pushPair(there, std::move(negations)); + if (!emptyIntersectWithNegation) + heres.pushPair(there, std::move(negations)); break; } // If the incoming class is a superclass of the current class, we don't @@ -2510,7 +2521,7 @@ std::optional Normalizer::intersectionOfTypePacks(TypePackId here, T return arena->addTypePack({}); } -std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, Set& seenSet) +std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSet) { if (here == there) return here; @@ -2589,49 +2600,60 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there { if (tprop.readTy.has_value()) { - // if the intersection of the read types of a property is uninhabited, the whole table is `never`. - if (fixReduceStackPressure()) + if (FFlag::LuauFixInfiniteRecursionInNormalization) { - // We've seen these table prop elements before and we're about to ask if their intersection - // is inhabited - if (fixCyclicTablesBlowingStack()) - { - if (seenSet.contains(*hprop.readTy) && seenSet.contains(*tprop.readTy)) - { - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - return {builtinTypes->neverType}; - } - else - { - seenSet.insert(*hprop.readTy); - seenSet.insert(*tprop.readTy); - } - } + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; - NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy); - - // Cleanup - if (fixCyclicTablesBlowingStack()) - { - seenSet.erase(*hprop.readTy); - seenSet.erase(*tprop.readTy); - } - - if (normalizeAwayUninhabitableTables() && NormalizationResult::True != res) + // If any property is going to get mapped to `never`, we can just call the entire table `never`. + // Since this check is syntactic, we may sometimes miss simplifying tables with complex uninhabited properties. + // Prior versions of this code attempted to do this semantically using the normalization machinery, but this + // mistakenly causes infinite loops when giving more complex recursive table types. As it stands, this approach + // will continue to scale as simplification is improved, but we may wish to reintroduce the semantic approach + // once we have revisited the usage of seen sets systematically (and possibly with some additional guarding to recognize + // when types are infinitely-recursive with non-pointer identical instances of them, or some guard to prevent that + // construction altogether). See also: `gh1632_no_infinite_recursion_in_normalization` + if (get(ty)) return {builtinTypes->neverType}; + + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); } else { - if (normalizeAwayUninhabitableTables() && - NormalizationResult::False == isIntersectionInhabited(*hprop.readTy, *tprop.readTy)) - return {builtinTypes->neverType}; - } + // if the intersection of the read types of a property is uninhabited, the whole table is `never`. + // We've seen these table prop elements before and we're about to ask if their intersection + // is inhabited - TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; - prop.readTy = ty; - hereSubThere &= (ty == hprop.readTy); - thereSubHere &= (ty == tprop.readTy); + auto pair1 = std::pair{*hprop.readTy, *tprop.readTy}; + auto pair2 = std::pair{*tprop.readTy, *hprop.readTy}; + if (seenTablePropPairs.contains(pair1) || seenTablePropPairs.contains(pair2)) + { + seenTablePropPairs.erase(pair1); + seenTablePropPairs.erase(pair2); + return {builtinTypes->neverType}; + } + else + { + seenTablePropPairs.insert(pair1); + seenTablePropPairs.insert(pair2); + } + + // FIXME(ariel): this is being added in a flag removal, so not changing the semantics here, but worth noting that this + // fresh `seenSet` is definitely a bug. we already have `seenSet` from the parameter that _should_ have been used here. + Set seenSet{nullptr}; + NormalizationResult res = isIntersectionInhabited(*hprop.readTy, *tprop.readTy, seenTablePropPairs, seenSet); + + seenTablePropPairs.erase(pair1); + seenTablePropPairs.erase(pair2); + if (NormalizationResult::True != res) + return {builtinTypes->neverType}; + + TypeId ty = simplifyIntersection(builtinTypes, NotNull{arena}, *hprop.readTy, *tprop.readTy).result; + prop.readTy = ty; + hereSubThere &= (ty == hprop.readTy); + thereSubHere &= (ty == tprop.readTy); + } } else { @@ -2737,7 +2759,7 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there if (tmtable && hmtable) { // NOTE: this assumes metatables are ivariant - if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenSet)) + if (std::optional mtable = intersectionOfTables(hmtable, tmtable, seenTablePropPairs, seenSet)) { if (table == htable && *mtable == hmtable) return here; @@ -2767,12 +2789,12 @@ std::optional Normalizer::intersectionOfTables(TypeId here, TypeId there return table; } -void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, Set& seenSetTypes) +void Normalizer::intersectTablesWithTable(TypeIds& heres, TypeId there, SeenTablePropPairs& seenTablePropPairs, Set& seenSetTypes) { TypeIds tmp; for (TypeId here : heres) { - if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) + if (std::optional inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes)) tmp.insert(*inter); } heres.retain(tmp); @@ -2787,7 +2809,8 @@ void Normalizer::intersectTables(TypeIds& heres, const TypeIds& theres) for (TypeId there : theres) { Set seenSetTypes{nullptr}; - if (std::optional inter = intersectionOfTables(here, there, seenSetTypes)) + SeenTablePropPairs seenTablePropPairs{{nullptr, nullptr}}; + if (std::optional inter = intersectionOfTables(here, there, seenTablePropPairs, seenSetTypes)) tmp.insert(*inter); } } @@ -3005,12 +3028,17 @@ void Normalizer::intersectFunctions(NormalizedFunctionType& heres, const Normali } } -NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, TypeId there, Set& seenSetTypes) +NormalizationResult Normalizer::intersectTyvarsWithTy( + NormalizedTyvars& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes +) { for (auto it = here.begin(); it != here.end();) { NormalizedType& inter = *it->second; - NormalizationResult res = intersectNormalWithTy(inter, there, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(inter, there, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; if (isShallowInhabited(inter)) @@ -3024,6 +3052,10 @@ NormalizationResult Normalizer::intersectTyvarsWithTy(NormalizedTyvars& here, Ty // See above for an explaination of `ignoreSmallerTyvars`. NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const NormalizedType& there, int ignoreSmallerTyvars) { + RecursionCounter _rc(&sharedState->counters.recursionCount); + if (!withinResourceLimits()) + return NormalizationResult::HitLimits; + if (!get(there.tops)) { here.tops = intersectionOfTops(here.tops, there.tops); @@ -3035,6 +3067,11 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor return unionNormals(here, there, ignoreSmallerTyvars); } + // Limit based on worst-case expansion of the table intersection + // This restriction can be relaxed when table intersection simplification is improved + if (here.tables.size() * there.tables.size() >= size_t(FInt::LuauNormalizeIntersectionLimit)) + return NormalizationResult::HitLimits; + here.booleans = intersectionOfBools(here.booleans, there.booleans); intersectClasses(here.classes, there.classes); @@ -3088,7 +3125,12 @@ NormalizationResult Normalizer::intersectNormals(NormalizedType& here, const Nor return NormalizationResult::True; } -NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, TypeId there, Set& seenSetTypes) +NormalizationResult Normalizer::intersectNormalWithTy( + NormalizedType& here, + TypeId there, + SeenTablePropPairs& seenTablePropPairs, + Set& seenSetTypes +) { RecursionCounter _rc(&sharedState->counters.recursionCount); if (!withinResourceLimits()) @@ -3104,14 +3146,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type else if (!get(here.tops)) { clearNormal(here); - return unionNormalWithTy(here, there, seenSetTypes); + return unionNormalWithTy(here, there, seenTablePropPairs, seenSetTypes); } else if (const UnionType* utv = get(there)) { NormalizedType norm{builtinTypes}; for (UnionTypeIterator it = begin(utv); it != end(utv); ++it) { - NormalizationResult res = unionNormalWithTy(norm, *it, seenSetTypes); + NormalizationResult res = unionNormalWithTy(norm, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; } @@ -3121,13 +3163,14 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { for (IntersectionTypeIterator it = begin(itv); it != end(itv); ++it) { - NormalizationResult res = intersectNormalWithTy(here, *it, seenSetTypes); + NormalizationResult res = intersectNormalWithTy(here, *it, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; } return NormalizationResult::True; } - else if (get(there) || get(there) || get(there) || get(there) || get(there)) + else if (get(there) || get(there) || get(there) || get(there) || + get(there)) { NormalizedType thereNorm{builtinTypes}; NormalizedType topNorm{builtinTypes}; @@ -3150,7 +3193,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { TypeIds tables = std::move(here.tables); clearNormal(here); - intersectTablesWithTable(tables, there, seenSetTypes); + intersectTablesWithTable(tables, there, seenTablePropPairs, seenSetTypes); here.tables = std::move(tables); } else if (get(there)) @@ -3243,13 +3286,18 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type // assumption that it is the same as any. return NormalizationResult::True; } + else if (get(t)) + { + // `*no-refine*` means we will never do anything to affect the intersection. + return NormalizationResult::True; + } else if (get(t)) { // if we're intersecting with `~never`, this is equivalent to intersecting with `unknown` // this is a noop since an intersection with `unknown` is trivial. return NormalizationResult::True; } - else if ((FFlag::LuauNormalizeNotUnknownIntersection || FFlag::LuauSolverV2) && get(t)) + else if (get(t)) { // if we're intersecting with `~unknown`, this is equivalent to intersecting with `never` // this means we should clear the type entirely. @@ -3257,7 +3305,7 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type return NormalizationResult::True; } else if (auto nt = get(t)) - return intersectNormalWithTy(here, nt->ty, seenSetTypes); + return intersectNormalWithTy(here, nt->ty, seenTablePropPairs, seenSetTypes); else { // TODO negated unions, intersections, table, and function. @@ -3269,10 +3317,15 @@ NormalizationResult Normalizer::intersectNormalWithTy(NormalizedType& here, Type { here.classes.resetToNever(); } + else if (get(there)) + { + // `*no-refine*` means we will never do anything to affect the intersection. + return NormalizationResult::True; + } else LUAU_ASSERT(!"Unreachable"); - NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenSetTypes); + NormalizationResult res = intersectTyvarsWithTy(tyvars, there, seenTablePropPairs, seenSetTypes); if (res != NormalizationResult::True) return res; here.tyvars = std::move(tyvars); @@ -3420,16 +3473,27 @@ TypeId Normalizer::typeFromNormal(const NormalizedType& norm) return arena->addType(UnionType{std::move(result)}); } -bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isSubtype( + TypeId subTy, + TypeId superTy, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&ice}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime // Subtyping under DCR is not implemented using unification! if (FFlag::LuauSolverV2) { - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; + Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; return subtyping.isSubtype(subTy, superTy, scope).isSubtype; } @@ -3442,16 +3506,27 @@ bool isSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) +bool isSubtype( + TypePackId subPack, + TypePackId superPack, + NotNull scope, + NotNull builtinTypes, + NotNull simplifier, + InternalErrorReporter& ice +) { UnifierSharedState sharedState{&ice}; TypeArena arena; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{ + NotNull{&ice}, NotNull{&limits} + }; // TODO: maybe subtyping checks should not invoke user-defined type function runtime // Subtyping under DCR is not implemented using unification! if (FFlag::LuauSolverV2) { - Subtyping subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&ice}}; + Subtyping subtyping{builtinTypes, NotNull{&arena}, simplifier, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&ice}}; return subtyping.isSubtype(subPack, superPack, scope).isSubtype; } @@ -3464,38 +3539,4 @@ bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull scope, N } } -bool isConsistentSubtype(TypeId subTy, TypeId superTy, NotNull scope, NotNull builtinTypes, InternalErrorReporter& ice) -{ - LUAU_ASSERT(!FFlag::LuauSolverV2); - - UnifierSharedState sharedState{&ice}; - TypeArena arena; - Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; - - u.tryUnify(subTy, superTy); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; -} - -bool isConsistentSubtype( - TypePackId subPack, - TypePackId superPack, - NotNull scope, - NotNull builtinTypes, - InternalErrorReporter& ice -) -{ - LUAU_ASSERT(!FFlag::LuauSolverV2); - - UnifierSharedState sharedState{&ice}; - TypeArena arena; - Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; - Unifier u{NotNull{&normalizer}, scope, Location{}, Covariant}; - - u.tryUnify(subPack, superPack); - const bool ok = u.errors.empty() && u.log.empty(); - return ok; -} - } // namespace Luau diff --git a/Analysis/src/OverloadResolution.cpp b/Analysis/src/OverloadResolution.cpp index 972c9e3a..32858cd1 100644 --- a/Analysis/src/OverloadResolution.cpp +++ b/Analysis/src/OverloadResolution.cpp @@ -16,7 +16,9 @@ namespace Luau OverloadResolver::OverloadResolver( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull reporter, NotNull limits, @@ -24,11 +26,13 @@ OverloadResolver::OverloadResolver( ) : builtinTypes(builtinTypes) , arena(arena) + , simplifier(simplifier) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , scope(scope) , ice(reporter) , limits(limits) - , subtyping({builtinTypes, arena, normalizer, ice}) + , subtyping({builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, ice}) , callLoc(callLocation) { } @@ -199,8 +203,9 @@ std::pair OverloadResolver::checkOverload_ const std::vector* argExprs ) { - FunctionGraphReductionResult result = - reduceTypeFunctions(fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, normalizer, ice, limits}, /*force=*/true); + FunctionGraphReductionResult result = reduceTypeFunctions( + fnTy, callLoc, TypeFunctionContext{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, ice, limits}, /*force=*/true + ); if (!result.errors.empty()) return {OverloadIsNonviable, result.errors}; @@ -401,10 +406,12 @@ void OverloadResolver::add(Analysis analysis, TypeId ty, ErrorVec&& errors) // we wrap calling the overload resolver in a separate function to reduce overall stack pressure in `solveFunctionCall`. // this limits the lifetime of `OverloadResolver`, a large type, to only as long as it is actually needed. -std::optional selectOverload( +static std::optional selectOverload( NotNull builtinTypes, NotNull arena, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull scope, NotNull iceReporter, NotNull limits, @@ -413,8 +420,9 @@ std::optional selectOverload( TypePackId argsPack ) { - OverloadResolver resolver{builtinTypes, arena, normalizer, scope, iceReporter, limits, location}; - auto [status, overload] = resolver.selectOverload(fn, argsPack); + auto resolver = + std::make_unique(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location); + auto [status, overload] = resolver->selectOverload(fn, argsPack); if (status == OverloadResolver::Analysis::Ok) return overload; @@ -428,7 +436,9 @@ std::optional selectOverload( SolveResult solveFunctionCall( NotNull arena, NotNull builtinTypes, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter, NotNull limits, NotNull scope, @@ -437,7 +447,8 @@ SolveResult solveFunctionCall( TypePackId argsPack ) { - std::optional overloadToUse = selectOverload(builtinTypes, arena, normalizer, scope, iceReporter, limits, location, fn, argsPack); + std::optional overloadToUse = + selectOverload(builtinTypes, arena, simplifier, normalizer, typeFunctionRuntime, scope, iceReporter, limits, location, fn, argsPack); if (!overloadToUse) return {SolveResult::NoMatchingOverload}; @@ -450,9 +461,9 @@ SolveResult solveFunctionCall( if (!u2.genericSubstitutions.empty() || !u2.genericPackSubstitutions.empty()) { - Instantiation2 instantiation{arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)}; + auto instantiation = std::make_unique(arena, std::move(u2.genericSubstitutions), std::move(u2.genericPackSubstitutions)); - std::optional subst = instantiation.substitute(resultPack); + std::optional subst = instantiation->substitute(resultPack); if (!subst) return {SolveResult::CodeTooComplex}; diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index c036a7a5..95c1a344 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -4,6 +4,8 @@ #include "Luau/Ast.h" #include "Luau/Module.h" +LUAU_FASTFLAGVARIABLE(LuauExtendedSimpleRequire) + namespace Luau { @@ -65,7 +67,7 @@ struct RequireTracer : AstVisitor return true; } - AstExpr* getDependent(AstExpr* node) + AstExpr* getDependent_DEPRECATED(AstExpr* node) { if (AstExprLocal* expr = node->as()) return locals[expr->local]; @@ -78,50 +80,122 @@ struct RequireTracer : AstVisitor else return nullptr; } + AstNode* getDependent(AstNode* node) + { + if (AstExprLocal* expr = node->as()) + return locals[expr->local]; + else if (AstExprIndexName* expr = node->as()) + return expr->expr; + else if (AstExprIndexExpr* expr = node->as()) + return expr->expr; + else if (AstExprCall* expr = node->as(); expr && expr->self) + return expr->func->as()->expr; + else if (AstExprGroup* expr = node->as()) + return expr->expr; + else if (AstExprTypeAssertion* expr = node->as()) + return expr->annotation; + else if (AstTypeGroup* expr = node->as()) + return expr->type; + else if (AstTypeTypeof* expr = node->as()) + return expr->expr; + else + return nullptr; + } void process() { ModuleInfo moduleContext{currentModuleName}; - // seed worklist with require arguments - work.reserve(requireCalls.size()); - - for (AstExprCall* require : requireCalls) - work.push_back(require->args.data[0]); - - // push all dependent expressions to the work stack; note that the vector is modified during traversal - for (size_t i = 0; i < work.size(); ++i) - if (AstExpr* dep = getDependent(work[i])) - work.push_back(dep); - - // resolve all expressions to a module info - for (size_t i = work.size(); i > 0; --i) + if (FFlag::LuauExtendedSimpleRequire) { - AstExpr* expr = work[i - 1]; + // seed worklist with require arguments + work.reserve(requireCalls.size()); - // when multiple expressions depend on the same one we push it to work queue multiple times - if (result.exprs.contains(expr)) - continue; + for (AstExprCall* require : requireCalls) + work.push_back(require->args.data[0]); - std::optional info; - - if (AstExpr* dep = getDependent(expr)) + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work.size(); ++i) { - const ModuleInfo* context = result.exprs.find(dep); + if (AstNode* dep = getDependent(work[i])) + work.push_back(dep); + } - // locals just inherit their dependent context, no resolution required - if (expr->is()) - info = context ? std::optional(*context) : std::nullopt; + // resolve all expressions to a module info + for (size_t i = work.size(); i > 0; --i) + { + AstNode* expr = work[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstNode* dep = getDependent(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + if (context && expr->is()) + info = *context; // locals just inherit their dependent context, no resolution required + else if (context && (expr->is() || expr->is())) + info = *context; // simple group nodes propagate their value + else if (context && (expr->is() || expr->is())) + info = *context; // typeof type annotations will resolve to the typeof content + else if (AstExpr* asExpr = expr->asExpr()) + info = fileResolver->resolveModule(context, asExpr); + } + else if (AstExpr* asExpr = expr->asExpr()) + { + info = fileResolver->resolveModule(&moduleContext, asExpr); + } + + if (info) + result.exprs[expr] = std::move(*info); + } + } + else + { + // seed worklist with require arguments + work_DEPRECATED.reserve(requireCalls.size()); + + for (AstExprCall* require : requireCalls) + work_DEPRECATED.push_back(require->args.data[0]); + + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work_DEPRECATED.size(); ++i) + if (AstExpr* dep = getDependent_DEPRECATED(work_DEPRECATED[i])) + work_DEPRECATED.push_back(dep); + + // resolve all expressions to a module info + for (size_t i = work_DEPRECATED.size(); i > 0; --i) + { + AstExpr* expr = work_DEPRECATED[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstExpr* dep = getDependent_DEPRECATED(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + // locals just inherit their dependent context, no resolution required + if (expr->is()) + info = context ? std::optional(*context) : std::nullopt; + else + info = fileResolver->resolveModule(context, expr); + } else - info = fileResolver->resolveModule(context, expr); - } - else - { - info = fileResolver->resolveModule(&moduleContext, expr); - } + { + info = fileResolver->resolveModule(&moduleContext, expr); + } - if (info) - result.exprs[expr] = std::move(*info); + if (info) + result.exprs[expr] = std::move(*info); + } } // resolve all requires according to their argument @@ -150,7 +224,8 @@ struct RequireTracer : AstVisitor ModuleName currentModuleName; DenseHashMap locals; - std::vector work; + std::vector work_DEPRECATED; + std::vector work; std::vector requireCalls; }; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index 27894505..db99d827 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -211,6 +211,16 @@ void Scope::inheritRefinements(const ScopePtr& childScope) } } +bool Scope::shouldWarnGlobal(std::string name) const +{ + for (const Scope* current = this; current; current = current->parent.get()) + { + if (current->globalsToWarn.contains(name)) + return true; + } + return false; +} + bool subsumesStrict(Scope* left, Scope* right) { while (right) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 099e6a0d..8a0483e6 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -2,6 +2,7 @@ #include "Luau/Simplify.h" +#include "Luau/Common.h" #include "Luau/DenseHash.h" #include "Luau/RecursionCounter.h" #include "Luau/Set.h" @@ -14,6 +15,7 @@ LUAU_FASTINT(LuauTypeReductionRecursionLimit) LUAU_FASTFLAG(LuauSolverV2) LUAU_DYNAMIC_FASTINTVARIABLE(LuauSimplificationComplexityLimit, 8); +LUAU_FASTFLAGVARIABLE(LuauFlagBasicIntersectFollows); namespace Luau { @@ -29,16 +31,16 @@ struct TypeSimplifier int recursionDepth = 0; - TypeId mkNegation(TypeId ty); + TypeId mkNegation(TypeId ty) const; TypeId intersectFromParts(std::set parts); - TypeId intersectUnionWithType(TypeId unionTy, TypeId right); + TypeId intersectUnionWithType(TypeId left, TypeId right); TypeId intersectUnions(TypeId left, TypeId right); - TypeId intersectNegatedUnion(TypeId unionTy, TypeId right); + TypeId intersectNegatedUnion(TypeId left, TypeId right); - TypeId intersectTypeWithNegation(TypeId a, TypeId b); - TypeId intersectNegations(TypeId a, TypeId b); + TypeId intersectTypeWithNegation(TypeId left, TypeId right); + TypeId intersectNegations(TypeId left, TypeId right); TypeId intersectIntersectionWithType(TypeId left, TypeId right); @@ -46,8 +48,8 @@ struct TypeSimplifier // unions, intersections, or negations. std::optional basicIntersect(TypeId left, TypeId right); - TypeId intersect(TypeId ty, TypeId discriminant); - TypeId union_(TypeId ty, TypeId discriminant); + TypeId intersect(TypeId left, TypeId right); + TypeId union_(TypeId left, TypeId right); TypeId simplify(TypeId ty); TypeId simplify(TypeId ty, DenseHashSet& seen); @@ -571,7 +573,7 @@ Relation relate(TypeId left, TypeId right) return relate(left, right, seen); } -TypeId TypeSimplifier::mkNegation(TypeId ty) +TypeId TypeSimplifier::mkNegation(TypeId ty) const { TypeId result = nullptr; @@ -1064,6 +1066,12 @@ TypeId TypeSimplifier::intersectIntersectionWithType(TypeId left, TypeId right) std::optional TypeSimplifier::basicIntersect(TypeId left, TypeId right) { + if (FFlag::LuauFlagBasicIntersectFollows) + { + left = follow(left); + right = follow(right); + } + if (get(left) && get(right)) return right; if (get(right) && get(left)) @@ -1403,8 +1411,6 @@ TypeId TypeSimplifier::simplify(TypeId ty, DenseHashSet& seen) SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) { - LUAU_ASSERT(FFlag::LuauSolverV2); - TypeSimplifier s{builtinTypes, arena}; // fprintf(stderr, "Intersect %s and %s ...\n", toString(left).c_str(), toString(right).c_str()); @@ -1418,8 +1424,6 @@ SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull< SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull arena, std::set parts) { - LUAU_ASSERT(FFlag::LuauSolverV2); - TypeSimplifier s{builtinTypes, arena}; TypeId res = s.intersectFromParts(std::move(parts)); @@ -1429,8 +1433,6 @@ SimplifyResult simplifyIntersection(NotNull builtinTypes, NotNull< SimplifyResult simplifyUnion(NotNull builtinTypes, NotNull arena, TypeId left, TypeId right) { - LUAU_ASSERT(FFlag::LuauSolverV2); - TypeSimplifier s{builtinTypes, arena}; TypeId res = s.union_(left, right); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 526d8212..e00f0d3d 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -4,13 +4,15 @@ #include "Luau/Common.h" #include "Luau/Clone.h" #include "Luau/TxnLog.h" +#include "Luau/Type.h" #include #include LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTINTVARIABLE(LuauTarjanPreallocationSize, 256) +LUAU_FASTFLAG(LuauSyntheticErrors) namespace Luau { @@ -50,11 +52,33 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a LUAU_ASSERT(ty->persistent); return ty; } - else if constexpr (std::is_same_v) + else if constexpr (std::is_same_v) { LUAU_ASSERT(ty->persistent); return ty; } + else if constexpr (std::is_same_v) + { + if (FFlag::LuauSyntheticErrors) + { + LUAU_ASSERT(ty->persistent || a.synthetic); + + if (ty->persistent) + return ty; + + // While this code intentionally works (and clones) even if `a.synthetic` is `std::nullopt`, + // we still assert above because we consider it a bug to have a non-persistent error type + // without any associated metadata. We should always use the persistent version in such cases. + ErrorType clone = ErrorType{}; + clone.synthetic = a.synthetic; + return dest.addType(clone); + } + else + { + LUAU_ASSERT(ty->persistent); + return ty; + } + } else if constexpr (std::is_same_v) { LUAU_ASSERT(ty->persistent); @@ -74,9 +98,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a FunctionType clone = FunctionType{a.level, a.scope, a.argTypes, a.retTypes, a.definition, a.hasSelf}; clone.generics = a.generics; clone.genericPacks = a.genericPacks; - clone.magicFunction = a.magicFunction; - clone.dcrMagicFunction = a.dcrMagicFunction; - clone.dcrMagicRefinement = a.dcrMagicRefinement; + clone.magic = a.magic; clone.tags = a.tags; clone.argNames = a.argNames; clone.isCheckedFunction = a.isCheckedFunction; @@ -127,7 +149,7 @@ static TypeId shallowClone(TypeId ty, TypeArena& dest, const TxnLog* log, bool a return dest.addType(NegationType{a.ty}); else if constexpr (std::is_same_v) { - TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncBody}; + TypeFunctionInstanceType clone{a.function, a.typeArguments, a.packArguments, a.userFuncName, a.userFuncData}; return dest.addType(std::move(clone)); } else diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp index ee199b66..a4f2ce1e 100644 --- a/Analysis/src/Subtyping.cpp +++ b/Analysis/src/Subtyping.cpp @@ -5,6 +5,7 @@ #include "Luau/Common.h" #include "Luau/Error.h" #include "Luau/Normalize.h" +#include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/Substitution.h" @@ -20,7 +21,8 @@ #include -LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity, false); +LUAU_FASTFLAGVARIABLE(DebugLuauSubtypingCheckPathValidity) +LUAU_FASTFLAGVARIABLE(LuauSubtypingFixTailPack) namespace Luau { @@ -258,43 +260,32 @@ SubtypingResult SubtypingResult::any(const std::vector& results struct ApplyMappedGenerics : Substitution { - using MappedGenerics = DenseHashMap; - using MappedGenericPacks = DenseHashMap; - NotNull builtinTypes; NotNull arena; - MappedGenerics& mappedGenerics; - MappedGenericPacks& mappedGenericPacks; + SubtypingEnvironment& env; - - ApplyMappedGenerics( - NotNull builtinTypes, - NotNull arena, - MappedGenerics& mappedGenerics, - MappedGenericPacks& mappedGenericPacks - ) + ApplyMappedGenerics(NotNull builtinTypes, NotNull arena, SubtypingEnvironment& env) : Substitution(TxnLog::empty(), arena) , builtinTypes(builtinTypes) , arena(arena) - , mappedGenerics(mappedGenerics) - , mappedGenericPacks(mappedGenericPacks) + , env(env) { } bool isDirty(TypeId ty) override { - return mappedGenerics.contains(ty); + return env.containsMappedType(ty); } bool isDirty(TypePackId tp) override { - return mappedGenericPacks.contains(tp); + return env.containsMappedPack(tp); } TypeId clean(TypeId ty) override { - const auto& bounds = mappedGenerics[ty]; + const auto& bounds = env.getMappedTypeBounds(ty); if (bounds.upperBound.empty()) return builtinTypes->unknownType; @@ -307,7 +298,12 @@ struct ApplyMappedGenerics : Substitution TypePackId clean(TypePackId tp) override { - return mappedGenericPacks[tp]; + if (auto it = env.getMappedPackBounds(tp)) + return *it; + + // Clean is only called when isDirty found a pack bound + LUAU_ASSERT(!"Unreachable"); + return nullptr; } bool ignoreChildren(TypeId ty) override @@ -325,19 +321,91 @@ struct ApplyMappedGenerics : Substitution std::optional SubtypingEnvironment::applyMappedGenerics(NotNull builtinTypes, NotNull arena, TypeId ty) { - ApplyMappedGenerics amg{builtinTypes, arena, mappedGenerics, mappedGenericPacks}; + ApplyMappedGenerics amg{builtinTypes, arena, *this}; return amg.substitute(ty); } +const TypeId* SubtypingEnvironment::tryFindSubstitution(TypeId ty) const +{ + if (auto it = substitutions.find(ty)) + return it; + + if (parent) + return parent->tryFindSubstitution(ty); + + return nullptr; +} + +const SubtypingResult* SubtypingEnvironment::tryFindSubtypingResult(std::pair subAndSuper) const +{ + if (auto it = ephemeralCache.find(subAndSuper)) + return it; + + if (parent) + return parent->tryFindSubtypingResult(subAndSuper); + + return nullptr; +} + +bool SubtypingEnvironment::containsMappedType(TypeId ty) const +{ + if (mappedGenerics.contains(ty)) + return true; + + if (parent) + return parent->containsMappedType(ty); + + return false; +} + +bool SubtypingEnvironment::containsMappedPack(TypePackId tp) const +{ + if (mappedGenericPacks.contains(tp)) + return true; + + if (parent) + return parent->containsMappedPack(tp); + + return false; +} + +SubtypingEnvironment::GenericBounds& SubtypingEnvironment::getMappedTypeBounds(TypeId ty) +{ + if (auto it = mappedGenerics.find(ty)) + return *it; + + if (parent) + return parent->getMappedTypeBounds(ty); + + LUAU_ASSERT(!"Use containsMappedType before asking for bounds!"); + return mappedGenerics[ty]; +} + +TypePackId* SubtypingEnvironment::getMappedPackBounds(TypePackId tp) +{ + if (auto it = mappedGenericPacks.find(tp)) + return it; + + if (parent) + return parent->getMappedPackBounds(tp); + + // This fallback is reachable in valid cases, unlike the final part of getMappedTypeBounds + return nullptr; +} + Subtyping::Subtyping( NotNull builtinTypes, NotNull typeArena, + NotNull simplifier, NotNull normalizer, + NotNull typeFunctionRuntime, NotNull iceReporter ) : builtinTypes(builtinTypes) , arena(typeArena) + , simplifier(simplifier) , normalizer(normalizer) + , typeFunctionRuntime(typeFunctionRuntime) , iceReporter(iceReporter) { } @@ -379,7 +447,10 @@ SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy, NotNull scope) { + UnifierCounters& counters = normalizer->sharedState->counters; + RecursionCounter rc(&counters.recursionCount); + + if (counters.recursionLimit > 0 && counters.recursionLimit < counters.recursionCount) + { + SubtypingResult result; + result.normalizationTooComplex = true; + return result; + } + subTy = follow(subTy); superTy = follow(superTy); - if (TypeId* subIt = env.substitutions.find(subTy); subIt && *subIt) + if (const TypeId* subIt = env.tryFindSubstitution(subTy); subIt && *subIt) subTy = *subIt; - if (TypeId* superIt = env.substitutions.find(superTy); superIt && *superIt) + if (const TypeId* superIt = env.tryFindSubstitution(superTy); superIt && *superIt) superTy = *superIt; - SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); + const SubtypingResult* cachedResult = resultCache.find({subTy, superTy}); if (cachedResult) return *cachedResult; - cachedResult = env.ephemeralCache.find({subTy, superTy}); + cachedResult = env.tryFindSubtypingResult({subTy, superTy}); if (cachedResult) return *cachedResult; @@ -700,7 +781,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId std::vector headSlice(begin(superHead), begin(superHead) + headSize); TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); - if (TypePackId* other = env.mappedGenericPacks.find(*subTail)) + if (TypePackId* other = env.getMappedPackBounds(*subTail)) // TODO: TypePath can't express "slice of a pack + its tail". results.push_back(isCovariantWith(env, *other, superTailPack, scope).withSubComponent(TypePath::PackField::Tail)); else @@ -755,7 +836,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId std::vector headSlice(begin(subHead), begin(subHead) + headSize); TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); - if (TypePackId* other = env.mappedGenericPacks.find(*superTail)) + if (TypePackId* other = env.getMappedPackBounds(*superTail)) // TODO: TypePath can't express "slice of a pack + its tail". results.push_back(isContravariantWith(env, subTailPack, *other, scope).withSuperComponent(TypePath::PackField::Tail)); else @@ -778,7 +859,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, TypePackId else return SubtypingResult{false} .withSuperComponent(TypePath::PackField::Tail) - .withError({scope->location, UnexpectedTypePackInSubtyping{*subTail}}); + .withError({scope->location, UnexpectedTypePackInSubtyping{FFlag::LuauSubtypingFixTailPack ? *superTail : *subTail}}); } else return {false}; @@ -1316,6 +1397,7 @@ SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const Tabl SubtypingResult Subtyping::isCovariantWith(SubtypingEnvironment& env, const MetatableType* subMt, const MetatableType* superMt, NotNull scope) { return isCovariantWith(env, subMt->table, superMt->table, scope) + .withBothComponent(TypePath::TypeField::Table) .andAlso(isCovariantWith(env, subMt->metatable, superMt->metatable, scope).withBothComponent(TypePath::TypeField::Metatable)); } @@ -1389,6 +1471,19 @@ SubtypingResult Subtyping::isCovariantWith( result.orElse( isContravariantWith(env, subFunction->argTypes, superFunction->argTypes, scope).withBothComponent(TypePath::PackField::Arguments) ); + + // If subtyping failed in the argument packs, we should check if there's a hidden variadic tail and try ignoring it. + // This might cause subtyping correctly because the sub type here may not have a hidden variadic tail or equivalent. + if (!result.isSubtype) + { + auto [arguments, tail] = flatten(superFunction->argTypes); + + if (auto variadic = get(tail); variadic && variadic->hidden) + { + result.orElse(isContravariantWith(env, subFunction->argTypes, arena->addTypePack(TypePack{arguments}), scope) + .withBothComponent(TypePath::PackField::Arguments)); + } + } } result.andAlso(isCovariantWith(env, subFunction->retTypes, superFunction->retTypes, scope).withBothComponent(TypePath::PackField::Returns)); @@ -1688,6 +1783,9 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe if (!get(subTy)) return false; + if (!env.mappedGenerics.find(subTy) && env.containsMappedType(subTy)) + iceReporter->ice("attempting to modify bounds of a potentially visited generic"); + env.mappedGenerics[subTy].upperBound.insert(superTy); } else @@ -1695,6 +1793,9 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypeId subTy, TypeId supe if (!get(superTy)) return false; + if (!env.mappedGenerics.find(superTy) && env.containsMappedType(superTy)) + iceReporter->ice("attempting to modify bounds of a potentially visited generic"); + env.mappedGenerics[superTy].lowerBound.insert(subTy); } @@ -1740,7 +1841,7 @@ bool Subtyping::bindGeneric(SubtypingEnvironment& env, TypePackId subTp, TypePac if (!get(subTp)) return false; - if (TypePackId* m = env.mappedGenericPacks.find(subTp)) + if (TypePackId* m = env.getMappedPackBounds(subTp)) return *m == superTp; env.mappedGenericPacks[subTp] = superTp; @@ -1761,7 +1862,7 @@ TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) std::pair Subtyping::handleTypeFunctionReductionResult(const TypeFunctionInstanceType* functionInstance, NotNull scope) { - TypeFunctionContext context{arena, builtinTypes, scope, normalizer, iceReporter, NotNull{&limits}}; + TypeFunctionContext context{arena, builtinTypes, scope, simplifier, normalizer, typeFunctionRuntime, iceReporter, NotNull{&limits}}; TypeId function = arena->addType(*functionInstance); FunctionGraphReductionResult result = reduceTypeFunctions(function, {}, context, true); ErrorVec errors; diff --git a/Analysis/src/Symbol.cpp b/Analysis/src/Symbol.cpp index 5e5b9d8c..a5117608 100644 --- a/Analysis/src/Symbol.cpp +++ b/Analysis/src/Symbol.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauSymbolEquality) namespace Luau { @@ -14,7 +15,7 @@ bool Symbol::operator==(const Symbol& rhs) const return local == rhs.local; else if (global.value) return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. - else if (FFlag::LuauSolverV2) + else if (FFlag::LuauSolverV2 || FFlag::LuauSymbolEquality) return !rhs.local && !rhs.global.value; // Reflexivity: we already know `this` Symbol is empty, so check that rhs is. else return false; diff --git a/Analysis/src/TableLiteralInference.cpp b/Analysis/src/TableLiteralInference.cpp index 177396a7..e5d8be04 100644 --- a/Analysis/src/TableLiteralInference.cpp +++ b/Analysis/src/TableLiteralInference.cpp @@ -6,19 +6,15 @@ #include "Luau/Type.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" +#include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" +LUAU_FASTFLAGVARIABLE(LuauDontInPlaceMutateTableType) +LUAU_FASTFLAGVARIABLE(LuauAllowNonSharedTableTypesInLiteral) + namespace Luau { -static bool isLiteral(const AstExpr* expr) -{ - return ( - expr->is() || expr->is() || expr->is() || expr->is() || - expr->is() || expr->is() - ); -} - // A fast approximation of subTy <: superTy static bool fastIsSubtype(TypeId subTy, TypeId superTy) { @@ -243,6 +239,8 @@ TypeId matchLiteralType( return exprType; } + DenseHashSet keysToDelete{nullptr}; + for (const AstExprTable::Item& item : exprTable->items) { if (isRecord(item)) @@ -254,8 +252,19 @@ TypeId matchLiteralType( Property& prop = it->second; - // Table literals always initially result in shared read-write types - LUAU_ASSERT(prop.isShared()); + if (FFlag::LuauAllowNonSharedTableTypesInLiteral) + { + // If we encounter a duplcate property, we may have already + // set it to be read-only. If that's the case, the only thing + // that will definitely crash is trying to access a write + // only property. + LUAU_ASSERT(!prop.isWriteOnly()); + } + else + { + // Table literals always initially result in shared read-write types + LUAU_ASSERT(prop.isShared()); + } TypeId propTy = *prop.readTy; auto it2 = expectedTableTy->props.find(keyStr); @@ -287,7 +296,10 @@ TypeId matchLiteralType( else tableTy->indexer = TableIndexer{expectedTableTy->indexer->indexType, matchedType}; - tableTy->props.erase(keyStr); + if (FFlag::LuauDontInPlaceMutateTableType) + keysToDelete.insert(item.key->as()); + else + tableTy->props.erase(keyStr); } // If it's just an extra property and the expected type @@ -381,15 +393,11 @@ TypeId matchLiteralType( const TypeId* keyTy = astTypes->find(item.key); LUAU_ASSERT(keyTy); TypeId tKey = follow(*keyTy); - if (get(tKey)) - toBlock.push_back(tKey); - + LUAU_ASSERT(!is(tKey)); const TypeId* propTy = astTypes->find(item.value); LUAU_ASSERT(propTy); TypeId tProp = follow(*propTy); - if (get(tProp)) - toBlock.push_back(tProp); - + LUAU_ASSERT(!is(tProp)); // Populate expected types for non-string keys declared with [] (the code below will handle the case where they are strings) if (!item.key->as() && expectedTableTy->indexer) (*astExpectedTypes)[item.key] = expectedTableTy->indexer->indexType; @@ -398,6 +406,16 @@ TypeId matchLiteralType( LUAU_ASSERT(!"Unexpected"); } + if (FFlag::LuauDontInPlaceMutateTableType) + { + for (const auto& key : keysToDelete) + { + const AstArray& s = key->value; + std::string keyStr{s.data, s.data + s.size}; + tableTy->props.erase(keyStr); + } + } + // Keys that the expectedType says we should have, but that aren't // specified by the AST fragment. // diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 4408063f..9b1c20fb 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -269,6 +269,12 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NoRefineType %d", index); + finishNodeLabel(ty); + finishNode(); + } else if constexpr (std::is_same_v) { formatAppend(result, "UnknownType %d", index); @@ -414,7 +420,7 @@ void StateDot::visitChildren(TypePackId tp, int index) finishNodeLabel(tp); finishNode(); } - else if (get(tp)) + else if (get(tp)) { formatAppend(result, "ErrorTypePack %d", index); finishNodeLabel(tp); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index f0850835..91ec3edc 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -20,6 +20,7 @@ #include LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauSyntheticErrors) /* * Enables increasing levels of verbosity for Luau type names when stringifying. @@ -38,7 +39,7 @@ LUAU_FASTFLAG(LuauSolverV2) * 3: Suffix free/generic types with their scope pointer, if present. */ LUAU_FASTINTVARIABLE(DebugLuauVerboseTypeNames, 0) -LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort, false) +LUAU_FASTFLAGVARIABLE(DebugLuauToStringNoLexicalSort) namespace Luau { @@ -856,6 +857,11 @@ struct TypeStringifier state.emit("any"); } + void operator()(TypeId, const NoRefineType&) + { + state.emit("*no-refine*"); + } + void operator()(TypeId, const UnionType& uv) { if (state.hasSeen(&uv)) @@ -865,6 +871,8 @@ struct TypeStringifier return; } + LUAU_ASSERT(uv.options.size() > 1); + bool optional = false; bool hasNonNilDisjunct = false; @@ -873,7 +881,7 @@ struct TypeStringifier { el = follow(el); - if (isNil(el)) + if (state.opts.useQuestionMarks && isNil(el)) { optional = true; continue; @@ -991,7 +999,15 @@ struct TypeStringifier void operator()(TypeId, const ErrorType& tv) { state.result.error = true; - state.emit("*error-type*"); + + if (FFlag::LuauSyntheticErrors && tv.synthetic) + { + state.emit("*error-type<"); + stringify(*tv.synthetic); + state.emit(">*"); + } + else + state.emit("*error-type*"); } void operator()(TypeId, const LazyType& ltv) @@ -1040,6 +1056,7 @@ struct TypeStringifier state.emit(tfitv.userFuncName->value); else state.emit(tfitv.function->name); + state.emit("<"); bool comma = false; @@ -1165,10 +1182,18 @@ struct TypePackStringifier state.unsee(&tp); } - void operator()(TypePackId, const Unifiable::Error& error) + void operator()(TypePackId, const ErrorTypePack& error) { state.result.error = true; - state.emit("*error-type*"); + + if (FFlag::LuauSyntheticErrors && error.synthetic) + { + state.emit("*"); + stringify(*error.synthetic); + state.emit("*"); + } + else + state.emit("*error-type*"); } void operator()(TypePackId, const VariadicTypePack& pack) @@ -1840,6 +1865,8 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) } else if constexpr (std::is_same_v) return "equality: " + tos(c.resultType) + " ~ " + tos(c.assignmentType); + else if constexpr (std::is_same_v) + return "table_check " + tos(c.expectedType) + " :> " + tos(c.exprType); else static_assert(always_false_v, "Non-exhaustive constraint switch"); }; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index a42882ed..ab272587 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,10 @@ #include #include +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauAstTypeGroup2) +LUAU_FASTFLAG(LuauFixDoBlockEndLocation) namespace { @@ -45,11 +49,13 @@ struct Writer virtual void space() = 0; virtual void maybeSpace(const Position& newPos, int reserve) = 0; virtual void write(std::string_view) = 0; + virtual void writeMultiline(std::string_view) = 0; virtual void identifier(std::string_view name) = 0; virtual void keyword(std::string_view) = 0; virtual void symbol(std::string_view) = 0; virtual void literal(std::string_view) = 0; virtual void string(std::string_view) = 0; + virtual void sourceString(std::string_view, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) = 0; }; struct StringWriter : Writer @@ -93,6 +99,32 @@ struct StringWriter : Writer lastChar = ' '; } + void writeMultiline(std::string_view s) override + { + if (s.empty()) + return; + + ss.append(s.data(), s.size()); + lastChar = s[s.size() - 1]; + + size_t index = 0; + size_t numLines = 0; + while (true) + { + auto newlinePos = s.find('\n', index); + if (newlinePos == std::string::npos) + break; + numLines++; + index = newlinePos + 1; + } + + pos.line += unsigned(numLines); + if (numLines > 0) + pos.column = unsigned(s.size()) - unsigned(index); + else + pos.column += unsigned(s.size()); + } + void write(std::string_view s) override { if (s.empty()) @@ -134,10 +166,17 @@ struct StringWriter : Writer void symbol(std::string_view s) override { - if (isDigit(lastChar) && s[0] == '.') - space(); + if (FFlag::LuauStoreCSTData) + { + write(s); + } + else + { + if (isDigit(lastChar) && s[0] == '.') + space(); - write(s); + write(s); + } } void literal(std::string_view s) override @@ -161,14 +200,54 @@ struct StringWriter : Writer write(escape(s)); write(quote); } + + void sourceString(std::string_view s, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) override + { + if (quoteStyle == CstExprConstantString::QuotedRaw) + { + auto blocks = std::string(blockDepth, '='); + write('['); + write(blocks); + write('['); + writeMultiline(s); + write(']'); + write(blocks); + write(']'); + } + else + { + LUAU_ASSERT(blockDepth == 0); + + char quote = '"'; + switch (quoteStyle) + { + case CstExprConstantString::QuotedDouble: + quote = '"'; + break; + case CstExprConstantString::QuotedSingle: + quote = '\''; + break; + case CstExprConstantString::QuotedInterp: + quote = '`'; + break; + default: + LUAU_ASSERT(!"Unhandled quote type"); + } + + write(quote); + writeMultiline(s); + write(quote); + } + } }; class CommaSeparatorInserter { public: - CommaSeparatorInserter(Writer& w) + explicit CommaSeparatorInserter(Writer& w, const Position* commaPosition = nullptr) : first(true) , writer(w) + , commaPosition(commaPosition) { } void operator()() @@ -176,17 +255,25 @@ public: if (first) first = !first; else + { + if (FFlag::LuauStoreCSTData && commaPosition) + { + writer.advance(*commaPosition); + commaPosition++; + } writer.symbol(","); + } } private: bool first; Writer& writer; + const Position* commaPosition; }; -struct Printer +struct Printer_DEPRECATED { - explicit Printer(Writer& writer) + explicit Printer_DEPRECATED(Writer& writer) : writer(writer) { } @@ -242,7 +329,8 @@ struct Printer } else if (typeCount == 1) { - if (unconditionallyParenthesize) + bool shouldParenthesize = unconditionallyParenthesize && (list.types.size == 0 || !list.types.data[0]->is()); + if (FFlag::LuauAstTypeGroup2 ? shouldParenthesize : unconditionallyParenthesize) writer.symbol("("); // Only variadic tail @@ -255,7 +343,7 @@ struct Printer visualizeTypeAnnotation(*list.types.data[0]); } - if (unconditionallyParenthesize) + if (FFlag::LuauAstTypeGroup2 ? shouldParenthesize : unconditionallyParenthesize) writer.symbol(")"); } else @@ -433,6 +521,7 @@ struct Printer visualize(*item.value); } + // Decrement endPos column so that we advance to before the closing `}` brace before writing, rather than after it Position endPos = expr.location.end; if (endPos.column > 0) --endPos.column; @@ -578,7 +667,8 @@ struct Printer writer.keyword("do"); for (const auto& s : block->body) visualize(*s); - writer.advance(block->location.end); + if (!FFlag::LuauFixDoBlockEndLocation) + writer.advance(block->location.end); writeEnd(program.location); } else if (const auto& a = program.as()) @@ -810,14 +900,14 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); - if (o.defaultValue) + if (o->defaultValue) { - writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.maybeSpace(o->defaultValue->location.begin, 2); writer.symbol("="); - visualizeTypeAnnotation(*o.defaultValue); + visualizeTypeAnnotation(*o->defaultValue); } } @@ -825,15 +915,15 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); - if (o.defaultValue) + if (o->defaultValue) { - writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.maybeSpace(o->defaultValue->location.begin, 2); writer.symbol("="); - visualizeTypePackAnnotation(*o.defaultValue, false); + visualizeTypePackAnnotation(*o->defaultValue, false); } } @@ -890,15 +980,15 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); } writer.symbol(">"); @@ -1027,15 +1117,15 @@ struct Printer { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.advance(o.location.begin); - writer.identifier(o.name.value); + writer.advance(o->location.begin); + writer.identifier(o->name.value); writer.symbol("..."); } writer.symbol(">"); @@ -1164,6 +1254,12 @@ struct Printer writer.symbol(")"); } } + else if (const auto& a = typeAnnotation.as()) + { + writer.symbol("("); + visualizeTypeAnnotation(*a->type); + writer.symbol(")"); + } else if (const auto& a = typeAnnotation.as()) { writer.keyword(a->value ? "true" : "false"); @@ -1183,20 +1279,1384 @@ struct Printer } }; +struct Printer +{ + explicit Printer(Writer& writer, CstNodeMap cstNodeMap) + : writer(writer) + , cstNodeMap(std::move(cstNodeMap)) + { + } + + bool writeTypes = false; + Writer& writer; + CstNodeMap cstNodeMap; + + template + T* lookupCstNode(AstNode* astNode) + { + if (const auto cstNode = cstNodeMap[astNode]) + return cstNode->as(); + return nullptr; + } + + void visualize(const AstLocal& local) + { + advance(local.location.begin); + + writer.identifier(local.name.value); + if (writeTypes && local.annotation) + { + // TODO: handle spacing for type annotation + writer.symbol(":"); + visualizeTypeAnnotation(*local.annotation); + } + } + + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) + { + advance(annotation.location.begin); + if (const AstTypePackVariadic* variadicTp = annotation.as()) + { + if (!forVarArg) + writer.symbol("..."); + + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + LUAU_ASSERT(!forVarArg); + visualizeTypeList(explicitTp->typeList, true); + } + else + { + LUAU_ASSERT(!"Unknown TypePackAnnotation kind"); + } + } + + void visualizeTypeList(const AstTypeList& list, bool unconditionallyParenthesize) + { + size_t typeCount = list.types.size + (list.tailType != nullptr ? 1 : 0); + if (typeCount == 0) + { + writer.symbol("("); + writer.symbol(")"); + } + else if (typeCount == 1) + { + bool shouldParenthesize = unconditionallyParenthesize && (list.types.size == 0 || !list.types.data[0]->is()); + if (FFlag::LuauAstTypeGroup2 ? shouldParenthesize : unconditionallyParenthesize) + writer.symbol("("); + + // Only variadic tail + if (list.types.size == 0) + { + visualizeTypePackAnnotation(*list.tailType, false); + } + else + { + visualizeTypeAnnotation(*list.types.data[0]); + } + + if (FFlag::LuauAstTypeGroup2 ? shouldParenthesize : unconditionallyParenthesize) + writer.symbol(")"); + } + else + { + writer.symbol("("); + + bool first = true; + for (const auto& el : list.types) + { + if (first) + first = false; + else + writer.symbol(","); + + visualizeTypeAnnotation(*el); + } + + if (list.tailType) + { + writer.symbol(","); + visualizeTypePackAnnotation(*list.tailType, false); + } + + writer.symbol(")"); + } + } + + bool isIntegerish(double d) + { + if (d <= std::numeric_limits::max() && d >= std::numeric_limits::min()) + return double(int(d)) == d && !(d == 0.0 && signbit(d)); + else + return false; + } + + void visualize(AstExpr& expr) + { + advance(expr.location.begin); + + if (const auto& a = expr.as()) + { + writer.symbol("("); + visualize(*a->expr); + advance(Position{a->location.end.line, a->location.end.column - 1}); + writer.symbol(")"); + } + else if (expr.is()) + { + writer.keyword("nil"); + } + else if (const auto& a = expr.as()) + { + if (a->value) + writer.keyword("true"); + else + writer.keyword("false"); + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.literal(std::string_view(cstNode->value.data, cstNode->value.size)); + } + else + { + if (isinf(a->value)) + { + if (a->value > 0) + writer.literal("1e500"); + else + writer.literal("-1e500"); + } + else if (isnan(a->value)) + writer.literal("0/0"); + else + { + if (isIntegerish(a->value)) + writer.literal(std::to_string(int(a->value))); + else + { + char buffer[100]; + size_t len = snprintf(buffer, sizeof(buffer), "%.17g", a->value); + writer.literal(std::string_view{buffer, len}); + } + } + } + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.sourceString( + std::string_view(cstNode->sourceString.data, cstNode->sourceString.size), cstNode->quoteStyle, cstNode->blockDepth + ); + } + else + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->local->name.value); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->name.value); + } + else if (expr.is()) + { + writer.symbol("..."); + } + else if (const auto& a = expr.as()) + { + visualize(*a->func); + + const auto cstNode = lookupCstNode(a); + + if (cstNode) + { + if (cstNode->openParens) + { + advance(*cstNode->openParens); + writer.symbol("("); + } + } + else + { + writer.symbol("("); + } + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& arg : a->args) + { + comma(); + visualize(*arg); + } + + if (cstNode) + { + if (cstNode->closeParens) + { + advance(*cstNode->closeParens); + writer.symbol(")"); + } + } + else + { + writer.symbol(")"); + } + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + advance(a->opPosition); + writer.symbol(std::string(1, a->op)); + advance(a->indexLocation.begin); + writer.write(a->index.value); + } + else if (const auto& a = expr.as()) + { + const auto cstNode = lookupCstNode(a); + visualize(*a->expr); + if (cstNode) + advance(cstNode->openBracketPosition); + writer.symbol("["); + visualize(*a->index); + if (cstNode) + advance(cstNode->closeBracketPosition); + writer.symbol("]"); + } + else if (const auto& a = expr.as()) + { + writer.keyword("function"); + visualizeFunctionBody(*a); + } + else if (const auto& a = expr.as()) + { + writer.symbol("{"); + + const CstExprTable::Item* cstItem = nullptr; + if (const auto cstNode = lookupCstNode(a)) + { + LUAU_ASSERT(cstNode->items.size == a->items.size); + cstItem = cstNode->items.begin(); + } + + bool first = true; + + for (const auto& item : a->items) + { + if (!cstItem) + { + if (first) + first = false; + else + writer.symbol(","); + } + + switch (item.kind) + { + case AstExprTable::Item::List: + break; + + case AstExprTable::Item::Record: + { + const auto& value = item.key->as()->value; + advance(item.key->location.begin); + writer.identifier(std::string_view(value.data, value.size)); + if (cstItem) + advance(*cstItem->equalsPosition); + else + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + case AstExprTable::Item::General: + { + if (cstItem) + advance(*cstItem->indexerOpenPosition); + writer.symbol("["); + visualize(*item.key); + if (cstItem) + advance(*cstItem->indexerClosePosition); + writer.symbol("]"); + if (cstItem) + advance(*cstItem->equalsPosition); + else + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + default: + LUAU_ASSERT(!"Unknown table item kind"); + } + + advance(item.value->location.begin); + visualize(*item.value); + + if (cstItem) + { + if (cstItem->separator) + { + LUAU_ASSERT(cstItem->separatorPosition); + advance(*cstItem->separatorPosition); + if (cstItem->separator == CstExprTable::Comma) + writer.symbol(","); + else if (cstItem->separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + cstItem++; + } + } + + Position endPos = expr.location.end; + if (endPos.column > 0) + --endPos.column; + + advance(endPos); + + writer.symbol("}"); + advance(expr.location.end); + } + else if (const auto& a = expr.as()) + { + if (const auto cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + + switch (a->op) + { + case AstExprUnary::Not: + writer.keyword("not"); + break; + case AstExprUnary::Minus: + writer.symbol("-"); + break; + case AstExprUnary::Len: + writer.symbol("#"); + break; + } + visualize(*a->expr); + } + else if (const auto& a = expr.as()) + { + visualize(*a->left); + + if (const auto cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + else + { + switch (a->op) + { + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::FloorDiv: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + writer.maybeSpace(a->right->location.begin, 2); + break; + case AstExprBinary::Concat: + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGe: + case AstExprBinary::Or: + writer.maybeSpace(a->right->location.begin, 3); + break; + case AstExprBinary::And: + writer.maybeSpace(a->right->location.begin, 4); + break; + default: + LUAU_ASSERT(!"Unknown Op"); + } + } + + writer.symbol(toString(a->op)); + + visualize(*a->right); + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + + if (writeTypes) + { + if (const auto* cstNode = lookupCstNode(a)) + advance(cstNode->opPosition); + else + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualizeElseIfExpr(*a); + } + else if (const auto& a = expr.as()) + { + const auto* cstNode = lookupCstNode(a); + + writer.symbol("`"); + + size_t index = 0; + + for (const auto& string : a->strings) + { + if (cstNode) + { + if (index > 0) + { + advance(cstNode->stringPositions.data[index]); + writer.symbol("}"); + } + const AstArray sourceString = cstNode->sourceStrings.data[index]; + writer.writeMultiline(std::string_view(sourceString.data, sourceString.size)); + } + else + { + writer.write(escape(std::string_view(string.data, string.size), /* escapeForInterpString = */ true)); + } + + if (index < a->expressions.size) + { + writer.symbol("{"); + visualize(*a->expressions.data[index]); + if (!cstNode) + writer.symbol("}"); + } + + index++; + } + + writer.symbol("`"); + } + else if (const auto& a = expr.as()) + { + writer.symbol("(error-expr"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstExpr"); + } + } + + void writeEnd(const Location& loc) + { + Position endPos = loc.end; + if (endPos.column >= 3) + endPos.column -= 3; + advance(endPos); + writer.keyword("end"); + } + + void advance(const Position& newPos) + { + writer.advance(newPos); + } + + void visualize(AstStat& program) + { + advance(program.location.begin); + + if (const auto& block = program.as()) + { + writer.keyword("do"); + for (const auto& s : block->body) + visualize(*s); + if (const auto cstNode = lookupCstNode(block)) + { + advance(cstNode->endPosition); + writer.keyword("end"); + } + else + { + writer.advance(block->location.end); + writeEnd(program.location); + } + } + else if (const auto& a = program.as()) + { + writer.keyword("if"); + visualizeElseIf(*a); + } + else if (const auto& a = program.as()) + { + writer.keyword("while"); + visualize(*a->condition); + // TODO: what if 'hasDo = false'? + advance(a->doLocation.begin); + writer.keyword("do"); + visualizeBlock(*a->body); + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + writer.keyword("repeat"); + visualizeBlock(*a->body); + if (const auto cstNode = lookupCstNode(a)) + writer.advance(cstNode->untilPosition); + else if (a->condition->location.begin.column > 5) + writer.advance(Position{a->condition->location.begin.line, a->condition->location.begin.column - 6}); + writer.keyword("until"); + visualize(*a->condition); + } + else if (program.is()) + writer.keyword("break"); + else if (program.is()) + writer.keyword("continue"); + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("return"); + + CommaSeparatorInserter comma(writer, cstNode ? cstNode->commaPositions.begin() : nullptr); + for (const auto& expr : a->list) + { + comma(); + visualize(*expr); + } + } + else if (const auto& a = program.as()) + { + visualize(*a->expr); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("local"); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& local : a->vars) + { + varComma(); + visualize(*local); + } + + if (a->equalsSignLocation) + { + advance(a->equalsSignLocation->begin); + writer.symbol("="); + } + + + CommaSeparatorInserter valueComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + for (const auto& value : a->values) + { + valueComma(); + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("for"); + + visualize(*a->var); + if (cstNode) + advance(cstNode->equalsPosition); + writer.symbol("="); + visualize(*a->from); + if (cstNode) + advance(cstNode->endCommaPosition); + writer.symbol(","); + visualize(*a->to); + if (a->step) + { + if (cstNode && cstNode->stepCommaPosition) + advance(*cstNode->stepCommaPosition); + writer.symbol(","); + visualize(*a->step); + } + advance(a->doLocation.begin); + writer.keyword("do"); + visualizeBlock(*a->body); + + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("for"); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& var : a->vars) + { + varComma(); + visualize(*var); + } + + advance(a->inLocation.begin); + writer.keyword("in"); + + CommaSeparatorInserter valComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + + for (const auto& val : a->values) + { + valComma(); + visualize(*val); + } + + advance(a->doLocation.begin); + writer.keyword("do"); + + visualizeBlock(*a->body); + + advance(a->body->location.end); + writer.keyword("end"); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + CommaSeparatorInserter varComma(writer, cstNode ? cstNode->varsCommaPositions.begin() : nullptr); + for (const auto& var : a->vars) + { + varComma(); + visualize(*var); + } + + if (cstNode) + advance(cstNode->equalsPosition); + else + writer.space(); + writer.symbol("="); + + CommaSeparatorInserter valueComma(writer, cstNode ? cstNode->valuesCommaPositions.begin() : nullptr); + for (const auto& value : a->values) + { + valueComma(); + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + visualize(*a->var); + + if (cstNode) + advance(cstNode->opPosition); + + switch (a->op) + { + case AstExprBinary::Add: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("+="); + break; + case AstExprBinary::Sub: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("-="); + break; + case AstExprBinary::Mul: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("*="); + break; + case AstExprBinary::Div: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("/="); + break; + case AstExprBinary::FloorDiv: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 3); + writer.symbol("//="); + break; + case AstExprBinary::Mod: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("%="); + break; + case AstExprBinary::Pow: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("^="); + break; + case AstExprBinary::Concat: + if (!cstNode) + writer.maybeSpace(a->value->location.begin, 3); + writer.symbol("..="); + break; + default: + LUAU_ASSERT(!"Unexpected compound assignment op"); + } + + visualize(*a->value); + } + else if (const auto& a = program.as()) + { + writer.keyword("function"); + visualize(*a->name); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + const auto cstNode = lookupCstNode(a); + + writer.keyword("local"); + + if (cstNode) + advance(cstNode->functionKeywordPosition); + else + writer.space(); + + writer.keyword("function"); + advance(a->name->location.begin); + writer.identifier(a->name->name.value); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + if (writeTypes) + { + const auto* cstNode = lookupCstNode(a); + + if (a->exported) + writer.keyword("export"); + + if (cstNode) + advance(cstNode->typeKeywordPosition); + + writer.keyword("type"); + advance(a->nameLocation.begin); + writer.identifier(a->name.value); + if (a->generics.size > 0 || a->genericPacks.size > 0) + { + if (cstNode) + advance(cstNode->genericsOpenPosition); + writer.symbol("<"); + CommaSeparatorInserter comma(writer, cstNode ? cstNode->genericsCommaPositions.begin() : nullptr); + + for (auto o : a->generics) + { + comma(); + + writer.advance(o->location.begin); + writer.identifier(o->name.value); + + if (o->defaultValue) + { + const auto* genericTypeCstNode = lookupCstNode(o); + + if (genericTypeCstNode) + { + LUAU_ASSERT(genericTypeCstNode->defaultEqualsPosition.has_value()); + advance(*genericTypeCstNode->defaultEqualsPosition); + } + else + writer.maybeSpace(o->defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o->defaultValue); + } + } + + for (auto o : a->genericPacks) + { + comma(); + + const auto* genericTypePackCstNode = lookupCstNode(o); + + writer.advance(o->location.begin); + writer.identifier(o->name.value); + if (genericTypePackCstNode) + advance(genericTypePackCstNode->ellipsisPosition); + writer.symbol("..."); + + if (o->defaultValue) + { + if (cstNode) + { + LUAU_ASSERT(genericTypePackCstNode->defaultEqualsPosition.has_value()); + advance(*genericTypePackCstNode->defaultEqualsPosition); + } + else + writer.maybeSpace(o->defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o->defaultValue, false); + } + } + + if (cstNode) + advance(cstNode->genericsClosePosition); + writer.symbol(">"); + } + if (cstNode) + advance(cstNode->equalsPosition); + else + writer.maybeSpace(a->type->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*a->type); + } + } + else if (const auto& t = program.as()) + { + if (writeTypes) + { + writer.keyword("type function"); + writer.identifier(t->name.value); + visualizeFunctionBody(*t->body); + } + } + else if (const auto& a = program.as()) + { + writer.symbol("(error-stat"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + for (size_t i = 0; i < a->statements.size; i++) + { + writer.symbol(i == 0 && a->expressions.size == 0 ? ": " : ", "); + visualize(*a->statements.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstStat"); + } + + if (program.hasSemicolon) + { + if (FFlag::LuauStoreCSTData) + advance(Position{program.location.end.line, program.location.end.column - 1}); + writer.symbol(";"); + } + } + + void visualizeFunctionBody(AstExprFunction& func) + { + if (func.generics.size > 0 || func.genericPacks.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : func.generics) + { + comma(); + + writer.advance(o->location.begin); + writer.identifier(o->name.value); + } + for (const auto& o : func.genericPacks) + { + comma(); + + writer.advance(o->location.begin); + writer.identifier(o->name.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + writer.symbol("("); + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < func.args.size; ++i) + { + AstLocal* local = func.args.data[i]; + + comma(); + + advance(local->location.begin); + writer.identifier(local->name.value); + if (writeTypes && local->annotation) + { + writer.symbol(":"); + visualizeTypeAnnotation(*local->annotation); + } + } + + if (func.vararg) + { + comma(); + advance(func.varargLocation.begin); + writer.symbol("..."); + + if (func.varargAnnotation) + { + writer.symbol(":"); + visualizeTypePackAnnotation(*func.varargAnnotation, true); + } + } + + writer.symbol(")"); + + if (writeTypes && func.returnAnnotation) + { + writer.symbol(":"); + writer.space(); + + visualizeTypeList(*func.returnAnnotation, false); + } + + visualizeBlock(*func.body); + advance(func.body->location.end); + writer.keyword("end"); + } + + void visualizeBlock(AstStatBlock& block) + { + for (const auto& s : block.body) + visualize(*s); + writer.advance(block.location.end); + } + + void visualizeBlock(AstStat& stat) + { + if (AstStatBlock* block = stat.as()) + visualizeBlock(*block); + else + LUAU_ASSERT(!"visualizeBlock was expecting an AstStatBlock"); + } + + void visualizeElseIf(AstStatIf& elseif) + { + visualize(*elseif.condition); + if (elseif.thenLocation) + advance(elseif.thenLocation->begin); + writer.keyword("then"); + visualizeBlock(*elseif.thenbody); + + if (elseif.elsebody == nullptr) + { + advance(elseif.thenbody->location.end); + writer.keyword("end"); + } + else if (auto elseifelseif = elseif.elsebody->as()) + { + if (elseif.elseLocation) + advance(elseif.elseLocation->begin); + writer.keyword("elseif"); + visualizeElseIf(*elseifelseif); + } + else + { + if (elseif.elseLocation) + advance(elseif.elseLocation->begin); + writer.keyword("else"); + + visualizeBlock(*elseif.elsebody); + advance(elseif.elsebody->location.end); + writer.keyword("end"); + } + } + + void visualizeElseIfExpr(AstExprIfElse& elseif) + { + const auto cstNode = lookupCstNode(&elseif); + + visualize(*elseif.condition); + if (cstNode) + advance(cstNode->thenPosition); + writer.keyword("then"); + visualize(*elseif.trueExpr); + + if (elseif.falseExpr) + { + if (cstNode) + advance(cstNode->elsePosition); + if (auto elseifelseif = elseif.falseExpr->as(); elseifelseif && (!cstNode || cstNode->isElseIf)) + { + writer.keyword("elseif"); + visualizeElseIfExpr(*elseifelseif); + } + else + { + writer.keyword("else"); + visualize(*elseif.falseExpr); + } + } + } + + void visualizeTypeAnnotation(AstType& typeAnnotation) + { + advance(typeAnnotation.location.begin); + if (const auto& a = typeAnnotation.as()) + { + const auto cstNode = lookupCstNode(a); + + if (a->prefix) + { + writer.write(a->prefix->value); + if (cstNode) + advance(*cstNode->prefixPointPosition); + writer.symbol("."); + } + + advance(a->nameLocation.begin); + writer.write(a->name.value); + if (a->parameters.size > 0 || a->hasParameterList) + { + CommaSeparatorInserter comma(writer, cstNode ? cstNode->parametersCommaPositions.begin() : nullptr); + if (cstNode) + advance(cstNode->openParametersPosition); + writer.symbol("<"); + for (auto o : a->parameters) + { + comma(); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack, false); + } + if (cstNode) + advance(cstNode->closeParametersPosition); + writer.symbol(">"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->generics.size > 0 || a->genericPacks.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : a->generics) + { + comma(); + + writer.advance(o->location.begin); + writer.identifier(o->name.value); + } + for (const auto& o : a->genericPacks) + { + comma(); + + writer.advance(o->location.begin); + writer.identifier(o->name.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + { + visualizeTypeList(a->argTypes, true); + } + + writer.symbol("->"); + visualizeTypeList(a->returnTypes, true); + } + else if (const auto& a = typeAnnotation.as()) + { + AstTypeReference* indexType = a->indexer ? a->indexer->indexType->as() : nullptr; + + writer.symbol("{"); + + const auto cstNode = lookupCstNode(a); + if (cstNode) + { + if (cstNode->isArray) + { + LUAU_ASSERT(a->props.size == 0 && indexType && indexType->name == "number"); + if (a->indexer->accessLocation) + { + LUAU_ASSERT(a->indexer->access != AstTableAccess::ReadWrite); + advance(a->indexer->accessLocation->begin); + writer.keyword(a->indexer->access == AstTableAccess::Read ? "read" : "write"); + } + visualizeTypeAnnotation(*a->indexer->resultType); + } + else + { + const AstTableProp* prop = a->props.begin(); + + for (size_t i = 0; i < cstNode->items.size; ++i) + { + CstTypeTable::Item item = cstNode->items.data[i]; + // we store indexer as part of items to preserve property ordering + if (item.kind == CstTypeTable::Item::Kind::Indexer) + { + LUAU_ASSERT(a->indexer); + + if (a->indexer->accessLocation) + { + LUAU_ASSERT(a->indexer->access != AstTableAccess::ReadWrite); + advance(a->indexer->accessLocation->begin); + writer.keyword(a->indexer->access == AstTableAccess::Read ? "read" : "write"); + } + + advance(item.indexerOpenPosition); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + advance(item.indexerClosePosition); + writer.symbol("]"); + advance(item.colonPosition); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + + if (item.separator) + { + LUAU_ASSERT(item.separatorPosition); + advance(*item.separatorPosition); + if (item.separator == CstExprTable::Comma) + writer.symbol(","); + else if (item.separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + } + else + { + if (prop->accessLocation) + { + LUAU_ASSERT(prop->access != AstTableAccess::ReadWrite); + advance(prop->accessLocation->begin); + writer.keyword(prop->access == AstTableAccess::Read ? "read" : "write"); + } + + if (item.kind == CstTypeTable::Item::Kind::StringProperty) + { + advance(item.indexerOpenPosition); + writer.symbol("["); + writer.sourceString( + std::string_view(item.stringInfo->sourceString.data, item.stringInfo->sourceString.size), + item.stringInfo->quoteStyle, + item.stringInfo->blockDepth + ); + advance(item.indexerClosePosition); + writer.symbol("]"); + } + else + { + advance(prop->location.begin); + writer.identifier(prop->name.value); + } + + advance(item.colonPosition); + writer.symbol(":"); + visualizeTypeAnnotation(*prop->type); + + if (item.separator) + { + LUAU_ASSERT(item.separatorPosition); + advance(*item.separatorPosition); + if (item.separator == CstExprTable::Comma) + writer.symbol(","); + else if (item.separator == CstExprTable::Semicolon) + writer.symbol(";"); + } + + ++prop; + } + } + } + } + else + { + if (a->props.size == 0 && indexType && indexType->name == "number") + { + visualizeTypeAnnotation(*a->indexer->resultType); + } + else + { + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + } + } + + Position endPos = a->location.end; + if (endPos.column > 0) + --endPos.column; + advance(endPos); + + writer.symbol("}"); + } + else if (auto a = typeAnnotation.as()) + { + const auto cstNode = lookupCstNode(a); + writer.keyword("typeof"); + if (cstNode) + advance(cstNode->openPosition); + writer.symbol("("); + visualize(*a->expr); + if (cstNode) + advance(cstNode->closePosition); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->types.size == 2) + { + AstType* l = a->types.data[0]; + AstType* r = a->types.data[1]; + + auto lta = l->as(); + if (lta && lta->name == "nil") + std::swap(l, r); + + // it's still possible that we had a (T | U) or (T | nil) and not (nil | T) + auto rta = r->as(); + if (rta && rta->name == "nil") + { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + + writer.symbol("?"); + return; + } + } + + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("|"); + } + + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("&"); + } + + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + writer.symbol("("); + visualizeTypeAnnotation(*a->type); + advance(Position{a->location.end.line, a->location.end.column - 1}); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + writer.keyword(a->value ? "true" : "false"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (const auto cstNode = lookupCstNode(a)) + { + writer.sourceString( + std::string_view(cstNode->sourceString.data, cstNode->sourceString.size), cstNode->quoteStyle, cstNode->blockDepth + ); + } + else + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (typeAnnotation.is()) + { + writer.symbol("%error-type%"); + } + else + { + LUAU_ASSERT(!"Unknown AstType"); + } + } +}; + std::string toString(AstNode* node) { StringWriter writer; writer.pos = node->location.begin; - Printer printer(writer); - printer.writeTypes = true; + if (FFlag::LuauStoreCSTData) + { + Printer printer(writer, CstNodeMap{nullptr}); + printer.writeTypes = true; - if (auto statNode = node->asStat()) - printer.visualize(*statNode); - else if (auto exprNode = node->asExpr()) - printer.visualize(*exprNode); - else if (auto typeNode = node->asType()) - printer.visualizeTypeAnnotation(*typeNode); + if (auto statNode = node->asStat()) + printer.visualize(*statNode); + else if (auto exprNode = node->asExpr()) + printer.visualize(*exprNode); + else if (auto typeNode = node->asType()) + printer.visualizeTypeAnnotation(*typeNode); + } + else + { + Printer_DEPRECATED printer(writer); + printer.writeTypes = true; + + if (auto statNode = node->asStat()) + printer.visualize(*statNode); + else if (auto exprNode = node->asExpr()) + printer.visualize(*exprNode); + else if (auto typeNode = node->asType()) + printer.visualizeTypeAnnotation(*typeNode); + } return writer.str(); } @@ -1206,24 +2666,48 @@ void dump(AstNode* node) printf("%s\n", toString(node).c_str()); } -std::string transpile(AstStatBlock& block) +std::string transpile(AstStatBlock& block, const CstNodeMap& cstNodeMap) { StringWriter writer; - Printer(writer).visualizeBlock(block); + if (FFlag::LuauStoreCSTData) + { + Printer(writer, cstNodeMap).visualizeBlock(block); + } + else + { + Printer_DEPRECATED(writer).visualizeBlock(block); + } + return writer.str(); +} + +std::string transpileWithTypes(AstStatBlock& block, const CstNodeMap& cstNodeMap) +{ + StringWriter writer; + if (FFlag::LuauStoreCSTData) + { + Printer printer(writer, cstNodeMap); + printer.writeTypes = true; + printer.visualizeBlock(block); + } + else + { + Printer_DEPRECATED printer(writer); + printer.writeTypes = true; + printer.visualizeBlock(block); + } return writer.str(); } std::string transpileWithTypes(AstStatBlock& block) { - StringWriter writer; - Printer printer(writer); - printer.writeTypes = true; - printer.visualizeBlock(block); - return writer.str(); + // TODO: remove this interface? + return transpileWithTypes(block, CstNodeMap{nullptr}); } TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { + options.storeCstData = true; + auto allocator = Allocator{}; auto names = AstNameTable{allocator}; ParseResult parseResult = Parser::parse(source.data(), source.size(), names, allocator, options); @@ -1241,9 +2725,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options, bool wi return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; if (withTypes) - return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpileWithTypes(*parseResult.root, parseResult.cstNodeMap)}; - return TranspileResult{transpile(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root, parseResult.cstNodeMap)}; } } // namespace Luau diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index bde7751a..e272c661 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -93,8 +93,8 @@ void TxnLog::concatAsIntersections(TxnLog rhs, NotNull arena) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); + TypeId leftTy = arena->addType((*leftRep)->pending.clone()); + TypeId rightTy = arena->addType(rightRep->pending.clone()); typeVarChanges[ty]->pending.ty = IntersectionType{{leftTy, rightTy}}; } else @@ -170,8 +170,8 @@ void TxnLog::concatAsUnion(TxnLog rhs, NotNull arena) if (auto leftRep = typeVarChanges.find(ty); leftRep && !(*leftRep)->dead) { - TypeId leftTy = arena->addType((*leftRep)->pending); - TypeId rightTy = arena->addType(rightRep->pending); + TypeId leftTy = arena->addType((*leftRep)->pending.clone()); + TypeId rightTy = arena->addType(rightRep->pending.clone()); if (follow(leftTy) == follow(rightTy)) typeVarChanges[ty] = std::move(rightRep); @@ -217,7 +217,7 @@ TxnLog TxnLog::inverse() for (auto& [ty, _rep] : typeVarChanges) { if (!_rep->dead) - inversed.typeVarChanges[ty] = std::make_unique(*ty); + inversed.typeVarChanges[ty] = std::make_unique(ty->clone()); } for (auto& [tp, _rep] : typePackChanges) @@ -292,7 +292,7 @@ PendingType* TxnLog::queue(TypeId ty) auto& pending = typeVarChanges[ty]; if (!pending || (*pending).dead) { - pending = std::make_unique(*ty); + pending = std::make_unique(ty->clone()); pending->pending.owningArena = nullptr; } diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index b024fdd2..bb08856c 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,6 +27,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAGVARIABLE(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -478,24 +479,12 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) return false; } -FreeType::FreeType(TypeLevel level) +// New constructors +FreeType::FreeType(TypeLevel level, TypeId lowerBound, TypeId upperBound) : index(Unifiable::freshIndex()) , level(level) - , scope(nullptr) -{ -} - -FreeType::FreeType(Scope* scope) - : index(Unifiable::freshIndex()) - , level{} - , scope(scope) -{ -} - -FreeType::FreeType(Scope* scope, TypeLevel level) - : index(Unifiable::freshIndex()) - , level(level) - , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) { } @@ -507,6 +496,40 @@ FreeType::FreeType(Scope* scope, TypeId lowerBound, TypeId upperBound) { } +FreeType::FreeType(Scope* scope, TypeLevel level, TypeId lowerBound, TypeId upperBound) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) + , lowerBound(lowerBound) + , upperBound(upperBound) +{ +} + +// Old constructors +FreeType::FreeType(TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(nullptr) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + +FreeType::FreeType(Scope* scope) + : index(Unifiable::freshIndex()) + , level{} + , scope(scope) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + +FreeType::FreeType(Scope* scope, TypeLevel level) + : index(Unifiable::freshIndex()) + , level(level) + , scope(scope) +{ + LUAU_ASSERT(!FFlag::LuauFreeTypesMustHaveBounds); +} + GenericType::GenericType() : index(Unifiable::freshIndex()) , name("g" + std::to_string(index)) @@ -554,12 +577,12 @@ BlockedType::BlockedType() { } -Constraint* BlockedType::getOwner() const +const Constraint* BlockedType::getOwner() const { return owner; } -void BlockedType::setOwner(Constraint* newOwner) +void BlockedType::setOwner(const Constraint* newOwner) { LUAU_ASSERT(owner == nullptr); @@ -569,7 +592,7 @@ void BlockedType::setOwner(Constraint* newOwner) owner = newOwner; } -void BlockedType::replaceOwner(Constraint* newOwner) +void BlockedType::replaceOwner(const Constraint* newOwner) { owner = newOwner; } @@ -999,6 +1022,11 @@ Type& Type::operator=(const Type& rhs) return *this; } +Type Type::clone() const +{ + return *this; +} + TypeId makeFunction( TypeArena& arena, std::optional selfType, @@ -1030,6 +1058,7 @@ BuiltinTypes::BuiltinTypes() , unknownType(arena->addType(Type{UnknownType{}, /*persistent*/ true})) , neverType(arena->addType(Type{NeverType{}, /*persistent*/ true})) , errorType(arena->addType(Type{ErrorType{}, /*persistent*/ true})) + , noRefineType(arena->addType(Type{NoRefineType{}, /*persistent*/ true})) , falsyType(arena->addType(Type{UnionType{{falseType, nilType}}, /*persistent*/ true})) , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) @@ -1039,7 +1068,7 @@ BuiltinTypes::BuiltinTypes() , unknownTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{unknownType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) - , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) + , errorTypePack(arena->addTypePack(TypePackVar{ErrorTypePack{}, /*persistent*/ true})) { freeze(*arena); } diff --git a/Analysis/src/TypeArena.cpp b/Analysis/src/TypeArena.cpp index 6cf81471..e4e9e293 100644 --- a/Analysis/src/TypeArena.cpp +++ b/Analysis/src/TypeArena.cpp @@ -2,7 +2,8 @@ #include "Luau/TypeArena.h" -LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false); +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -22,7 +23,34 @@ TypeId TypeArena::addTV(Type&& tv) return allocated; } -TypeId TypeArena::freshType(TypeLevel level) +TypeId TypeArena::freshType(NotNull builtins, TypeLevel level) +{ + TypeId allocated = types.allocate(FreeType{level, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(NotNull builtins, Scope* scope) +{ + TypeId allocated = types.allocate(FreeType{scope, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(NotNull builtins, Scope* scope, TypeLevel level) +{ + TypeId allocated = types.allocate(FreeType{scope, level, builtins->neverType, builtins->unknownType}); + + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType_DEPRECATED(TypeLevel level) { TypeId allocated = types.allocate(FreeType{level}); @@ -31,7 +59,7 @@ TypeId TypeArena::freshType(TypeLevel level) return allocated; } -TypeId TypeArena::freshType(Scope* scope) +TypeId TypeArena::freshType_DEPRECATED(Scope* scope) { TypeId allocated = types.allocate(FreeType{scope}); @@ -40,7 +68,7 @@ TypeId TypeArena::freshType(Scope* scope) return allocated; } -TypeId TypeArena::freshType(Scope* scope, TypeLevel level) +TypeId TypeArena::freshType_DEPRECATED(Scope* scope, TypeLevel level) { TypeId allocated = types.allocate(FreeType{scope, level}); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index a288cfbe..0d038694 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -145,6 +145,12 @@ public: { return allocator->alloc(Location(), std::nullopt, AstName("any"), std::nullopt, Location()); } + + AstType* operator()(const NoRefineType&) + { + return allocator->alloc(Location(), std::nullopt, AstName("*no-refine*"), std::nullopt, Location()); + } + AstType* operator()(const TableType& ttv) { RecursionCounter counter(&count); @@ -255,24 +261,24 @@ public: if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName(""), std::nullopt, Location()); - AstArray generics; + AstArray generics; generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { if (auto gtv = get(*it)) - generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; + generics.data[numGenerics++] = allocator->alloc(Location(), AstName(gtv->name.c_str()), nullptr); } - AstArray genericPacks; + AstArray genericPacks; genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { if (auto gtv = get(*it)) - genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; + genericPacks.data[numGenericPacks++] = allocator->alloc(Location(), AstName(gtv->name.c_str()), nullptr); } AstArray argTypes; @@ -323,7 +329,7 @@ public: Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation} ); } - AstType* operator()(const Unifiable::Error&) + AstType* operator()(const ErrorType&) { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable"), std::nullopt, Location()); } @@ -380,8 +386,12 @@ public: } AstType* operator()(const NegationType& ntv) { - // FIXME: do the same thing we do with ErrorType - throw InternalCompilerError("Cannot convert NegationType into AstNode"); + AstArray params; + params.size = 1; + params.data = static_cast(allocator->allocate(sizeof(AstType*))); + params.data[0] = AstTypeOrPack{Luau::visit(*this, ntv.ty->ty), nullptr}; + + return allocator->alloc(Location(), std::nullopt, AstName("negate"), std::nullopt, Location(), true, params); } AstType* operator()(const TypeFunctionInstanceType& tfit) { @@ -452,7 +462,7 @@ public: return allocator->alloc(Location(), AstName("free")); } - AstTypePack* operator()(const Unifiable::Error&) const + AstTypePack* operator()(const ErrorTypePack&) const { return allocator->alloc(Location(), AstName("Unifiable")); } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 7023fba9..01db570a 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -7,7 +7,6 @@ #include "Luau/DcrLogger.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" -#include "Luau/InsertionOrderedMap.h" #include "Luau/Instantiation.h" #include "Luau/Metamethods.h" #include "Luau/Normalize.h" @@ -27,11 +26,11 @@ #include "Luau/VisitType.h" #include -#include -#include LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) + namespace Luau { @@ -173,7 +172,7 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor DenseHashSet mentionedFunctions{nullptr}; DenseHashSet mentionedFunctionPacks{nullptr}; - InternalTypeFunctionFinder(std::vector& declStack) + explicit InternalTypeFunctionFinder(std::vector& declStack) { TypeFunctionFinder f; for (TypeId fn : declStack) @@ -266,6 +265,8 @@ struct InternalTypeFunctionFinder : TypeOnceVisitor void check( NotNull builtinTypes, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, DcrLogger* logger, @@ -275,7 +276,7 @@ void check( { LUAU_TIMETRACE_SCOPE("check", "Typechecking"); - TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module}; + TypeChecker2 typeChecker{builtinTypes, simplifier, typeFunctionRuntime, unifierState, limits, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); @@ -292,6 +293,8 @@ void check( TypeChecker2::TypeChecker2( NotNull builtinTypes, + NotNull simplifier, + NotNull typeFunctionRuntime, NotNull unifierState, NotNull limits, DcrLogger* logger, @@ -299,13 +302,15 @@ TypeChecker2::TypeChecker2( Module* module ) : builtinTypes(builtinTypes) + , simplifier(simplifier) + , typeFunctionRuntime(typeFunctionRuntime) , logger(logger) , limits(limits) , ice(unifierState->iceHandler) , sourceModule(sourceModule) , module(module) , normalizer{&module->internalTypes, builtinTypes, unifierState, /* cacheInhabitance */ true} - , _subtyping{builtinTypes, NotNull{&module->internalTypes}, NotNull{&normalizer}, NotNull{unifierState->iceHandler}} + , _subtyping{builtinTypes, NotNull{&module->internalTypes}, simplifier, NotNull{&normalizer}, typeFunctionRuntime, NotNull{unifierState->iceHandler}} , subtyping(&_subtyping) { } @@ -483,19 +488,22 @@ TypeId TypeChecker2::checkForTypeFunctionInhabitance(TypeId instance, Location l return instance; seenTypeFunctionInstances.insert(instance); - ErrorVec errors = reduceTypeFunctions( - instance, - location, - TypeFunctionContext{NotNull{&module->internalTypes}, builtinTypes, stack.back(), NotNull{&normalizer}, ice, limits}, - true - ) - .errors; + ErrorVec errors = + reduceTypeFunctions( + instance, + location, + TypeFunctionContext{ + NotNull{&module->internalTypes}, builtinTypes, stack.back(), simplifier, NotNull{&normalizer}, typeFunctionRuntime, ice, limits + }, + true + ) + .errors; if (!isErrorSuppressing(location, instance)) reportErrors(std::move(errors)); return instance; } -TypePackId TypeChecker2::lookupPack(AstExpr* expr) +TypePackId TypeChecker2::lookupPack(AstExpr* expr) const { // If a type isn't in the type graph, it probably means that a recursion limit was exceeded. // We'll just return anyType in these cases. Typechecking against any is very fast and this @@ -545,7 +553,7 @@ TypeId TypeChecker2::lookupAnnotation(AstType* annotation) return checkForTypeFunctionInhabitance(follow(*ty), annotation->location); } -std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) +std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annotation) const { TypePackId* tp = module->astResolvedTypePacks.find(annotation); if (tp != nullptr) @@ -553,7 +561,7 @@ std::optional TypeChecker2::lookupPackAnnotation(AstTypePack* annota return {}; } -TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) +TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) const { if (TypeId* ty = module->astExpectedTypes.find(expr)) return follow(*ty); @@ -561,7 +569,7 @@ TypeId TypeChecker2::lookupExpectedType(AstExpr* expr) return builtinTypes->anyType; } -TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) +TypePackId TypeChecker2::lookupExpectedPack(AstExpr* expr, TypeArena& arena) const { if (TypeId* ty = module->astExpectedTypes.find(expr)) return arena.addTypePack(TypePack{{follow(*ty)}, std::nullopt}); @@ -585,7 +593,7 @@ TypePackId TypeChecker2::reconstructPack(AstArray exprs, TypeArena& ar return arena.addTypePack(TypePack{head, tail}); } -Scope* TypeChecker2::findInnermostScope(Location location) +Scope* TypeChecker2::findInnermostScope(Location location) const { Scope* bestScope = module->getModuleScope().get(); @@ -1008,7 +1016,8 @@ void TypeChecker2::visit(AstStatForIn* forInStatement) { reportError(OptionalValueAccess{iteratorTy}, forInStatement->values.data[0]->location); } - else if (std::optional iterMmTy = findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) + else if (std::optional iterMmTy = + findMetatableEntry(builtinTypes, module->errors, iteratorTy, "__iter", forInStatement->values.data[0]->location)) { Instantiation instantiation{TxnLog::empty(), &arena, builtinTypes, TypeLevel{}, scope}; @@ -1193,8 +1202,6 @@ void TypeChecker2::visit(AstStatTypeAlias* stat) void TypeChecker2::visit(AstStatTypeFunction* stat) { // TODO: add type checking for user-defined type functions - - reportError(TypeError{stat->location, GenericError{"This syntax is not supported"}}); } void TypeChecker2::visit(AstTypeList types) @@ -1345,7 +1352,17 @@ void TypeChecker2::visit(AstExprGlobal* expr) { NotNull scope = stack.back(); if (!scope->lookup(expr->name)) + { reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); + } + else + { + if (scope->shouldWarnGlobal(expr->name.value) && !warnedGlobals.contains(expr->name.value)) + { + reportError(UnknownSymbol{expr->name.value, UnknownSymbol::Binding}, expr->location); + warnedGlobals.insert(expr->name.value); + } + } } void TypeChecker2::visit(AstExprVarargs* expr) @@ -1433,10 +1450,11 @@ void TypeChecker2::visitCall(AstExprCall* call) TypePackId argsTp = module->internalTypes.addTypePack(args); if (auto ftv = get(follow(*originalCallTy))) { - if (ftv->dcrMagicTypeCheck) + if (ftv->magic) { - ftv->dcrMagicTypeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); - return; + bool usedMagic = ftv->magic->typeCheck(MagicFunctionTypeCheckContext{NotNull{this}, builtinTypes, call, argsTp, scope}); + if (usedMagic) + return; } } @@ -1444,7 +1462,9 @@ void TypeChecker2::visitCall(AstExprCall* call) OverloadResolver resolver{ builtinTypes, NotNull{&module->internalTypes}, + simplifier, NotNull{&normalizer}, + typeFunctionRuntime, NotNull{stack.back()}, ice, limits, @@ -1540,7 +1560,7 @@ void TypeChecker2::visit(AstExprCall* call) visitCall(call); } -std::optional TypeChecker2::tryStripUnionFromNil(TypeId ty) +std::optional TypeChecker2::tryStripUnionFromNil(TypeId ty) const { if (const UnionType* utv = get(ty)) { @@ -1618,8 +1638,7 @@ void TypeChecker2::indexExprMetatableHelper(AstExprIndexExpr* indexExpr, const M indexExprMetatableHelper(indexExpr, mtmt, exprType, indexType); else { - LUAU_ASSERT(tt || get(follow(metaTable->table))); - + // CLI-122161: We're not handling unions correctly (probably). reportError(CannotExtendTable{exprType, CannotExtendTable::Indexer, "indexer??"}, indexExpr->location); } } @@ -1826,11 +1845,10 @@ void TypeChecker2::visit(AstExprFunction* fn) void TypeChecker2::visit(AstExprTable* expr) { - // TODO! for (const AstExprTable::Item& item : expr->items) { if (item.key) - visit(item.key, ValueContext::LValue); + visit(item.key, ValueContext::RValue); visit(item.value, ValueContext::RValue); } } @@ -2078,7 +2096,10 @@ TypeId TypeChecker2::visit(AstExprBinary* expr, AstNode* overrideKey) } else { - expectedRets = module->internalTypes.addTypePack({module->internalTypes.freshType(scope, TypeLevel{})}); + expectedRets = module->internalTypes.addTypePack( + {FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, scope, TypeLevel{}) + : module->internalTypes.freshType_DEPRECATED(scope, TypeLevel{})} + ); } TypeId expectedTy = module->internalTypes.addType(FunctionType(expectedArgs, expectedRets)); @@ -2330,7 +2351,8 @@ TypeId TypeChecker2::flattenPack(TypePackId pack) return *fst; else if (auto ftp = get(pack)) { - TypeId result = module->internalTypes.addType(FreeType{ftp->scope}); + TypeId result = FFlag::LuauFreeTypesMustHaveBounds ? module->internalTypes.freshType(builtinTypes, ftp->scope) + : module->internalTypes.addType(FreeType{ftp->scope}); TypePackId freeTail = module->internalTypes.addTypePack(FreeTypePack{ftp->scope}); TypePack* resultPack = emplaceTypePack(asMutable(pack)); @@ -2339,7 +2361,7 @@ TypeId TypeChecker2::flattenPack(TypePackId pack) return result; } - else if (get(pack)) + else if (get(pack)) return builtinTypes->errorRecoveryType(); else if (finite(pack) && size(pack) == 0) return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil` @@ -2347,30 +2369,30 @@ TypeId TypeChecker2::flattenPack(TypePackId pack) ice->ice("flattenPack got a weird pack!"); } -void TypeChecker2::visitGenerics(AstArray generics, AstArray genericPacks) +void TypeChecker2::visitGenerics(AstArray generics, AstArray genericPacks) { DenseHashSet seen{AstName{}}; - for (const auto& g : generics) + for (const auto* g : generics) { - if (seen.contains(g.name)) - reportError(DuplicateGenericParameter{g.name.value}, g.location); + if (seen.contains(g->name)) + reportError(DuplicateGenericParameter{g->name.value}, g->location); else - seen.insert(g.name); + seen.insert(g->name); - if (g.defaultValue) - visit(g.defaultValue); + if (g->defaultValue) + visit(g->defaultValue); } - for (const auto& g : genericPacks) + for (const auto* g : genericPacks) { - if (seen.contains(g.name)) - reportError(DuplicateGenericParameter{g.name.value}, g.location); + if (seen.contains(g->name)) + reportError(DuplicateGenericParameter{g->name.value}, g->location); else - seen.insert(g.name); + seen.insert(g->name); - if (g.defaultValue) - visit(g.defaultValue); + if (g->defaultValue) + visit(g->defaultValue); } } @@ -2392,6 +2414,8 @@ void TypeChecker2::visit(AstType* ty) return visit(t); else if (auto t = ty->as()) return visit(t); + else if (auto t = ty->as()) + return visit(t->type); } void TypeChecker2::visit(AstTypeReference* ty) @@ -3012,10 +3036,8 @@ PropertyType TypeChecker2::hasIndexTypeFromType( if (tt->indexer) { TypeId indexType = follow(tt->indexer->indexType); - if (isPrim(indexType, PrimitiveType::String)) - return {NormalizationResult::True, {tt->indexer->indexResultType}}; - // If the indexer looks like { [any] : _} - the prop lookup should be allowed! - else if (get(indexType) || get(indexType)) + TypeId givenType = module->internalTypes.addType(SingletonType{StringSingleton{prop}}); + if (isSubtype(givenType, indexType, NotNull{module->getModuleScope().get()}, builtinTypes, simplifier, *ice)) return {NormalizationResult::True, {tt->indexer->indexResultType}}; } diff --git a/Analysis/src/TypeFunction.cpp b/Analysis/src/TypeFunction.cpp index 9ae57fd1..9b5f5ef7 100644 --- a/Analysis/src/TypeFunction.cpp +++ b/Analysis/src/TypeFunction.cpp @@ -2,7 +2,9 @@ #include "Luau/TypeFunction.h" +#include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/Compiler.h" #include "Luau/ConstraintSolver.h" #include "Luau/DenseHash.h" #include "Luau/Instantiation.h" @@ -12,17 +14,25 @@ #include "Luau/Set.h" #include "Luau/Simplify.h" #include "Luau/Subtyping.h" +#include "Luau/TimeTrace.h" #include "Luau/ToString.h" #include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Luau/TypeFunctionReductionGuesser.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypeFunctionRuntimeBuilder.h" #include "Luau/TypeFwd.h" #include "Luau/TypeUtils.h" #include "Luau/Unifier2.h" #include "Luau/VecDeque.h" #include "Luau/VisitType.h" +#include "lua.h" +#include "lualib.h" + #include +#include +#include // used to control emitting CodeTooComplex warnings on type function reduction LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyGraphReductionMaximumSteps, 1'000'000); @@ -35,7 +45,14 @@ LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyApplicationCartesianProductLimit, 5'0 // when this value is set to a negative value, guessing will be totally disabled. LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFamilyUseGuesserDepth, -1); -LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies, false); +LUAU_FASTFLAGVARIABLE(DebugLuauLogTypeFamilies) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAGVARIABLE(LuauMetatableTypeFunctions) +LUAU_FASTFLAGVARIABLE(LuauClipNestedAndRecursiveUnion) +LUAU_FASTFLAGVARIABLE(LuauDoNotGeneralizeInTypeFunctions) +LUAU_FASTFLAGVARIABLE(LuauPreventReentrantTypeFunctionReduction) +LUAU_FASTFLAGVARIABLE(LuauIntersectNotNil) +LUAU_FASTFLAGVARIABLE(LuauSkipNoRefineDuringRefinement) namespace Luau { @@ -164,7 +181,7 @@ struct TypeFunctionReducer return SkipTestResult::Okay; } - SkipTestResult testForSkippability(TypePackId ty) + SkipTestResult testForSkippability(TypePackId ty) const { ty = follow(ty); @@ -206,23 +223,29 @@ struct TypeFunctionReducer template void handleTypeFunctionReduction(T subject, TypeFunctionReductionResult reduction) { + for (auto& message : reduction.messages) + result.messages.emplace_back(location, UserDefinedTypeFunctionError{std::move(message)}); + if (reduction.result) replace(subject, *reduction.result); else { irreducible.insert(subject); - if (reduction.uninhabited || force) + if (reduction.error.has_value()) + result.errors.emplace_back(location, UserDefinedTypeFunctionError{*reduction.error}); + + if (reduction.reductionStatus != Reduction::MaybeOk || force) { if (FFlag::DebugLuauLogTypeFamilies) printf("%s is uninhabited\n", toString(subject, {true}).c_str()); if constexpr (std::is_same_v) - result.errors.push_back(TypeError{location, UninhabitedTypeFunction{subject}}); + result.errors.emplace_back(location, UninhabitedTypeFunction{subject}); else if constexpr (std::is_same_v) - result.errors.push_back(TypeError{location, UninhabitedTypePackFunction{subject}}); + result.errors.emplace_back(location, UninhabitedTypePackFunction{subject}); } - else if (!reduction.uninhabited && !force) + else if (reduction.reductionStatus == Reduction::MaybeOk && !force) { if (FFlag::DebugLuauLogTypeFamilies) printf( @@ -241,7 +264,7 @@ struct TypeFunctionReducer } } - bool done() + bool done() const { return queuedTys.empty() && queuedTps.empty(); } @@ -359,7 +382,6 @@ struct TypeFunctionReducer return; ctx.userFuncName = tfit->userFuncName; - ctx.userFuncBody = tfit->userFuncBody; TypeFunctionReductionResult result = tfit->function->reducer(subject, tfit->typeArguments, tfit->packArguments, NotNull{&ctx}); handleTypeFunctionReduction(subject, result); @@ -400,6 +422,20 @@ struct TypeFunctionReducer } }; +struct LuauTempThreadPopper +{ + explicit LuauTempThreadPopper(lua_State* L) + : L(L) + { + } + ~LuauTempThreadPopper() + { + lua_pop(L, 1); + } + + lua_State* L = nullptr; +}; + static FunctionGraphReductionResult reduceFunctionsInternal( VecDeque queuedTys, VecDeque queuedTps, @@ -413,19 +449,48 @@ static FunctionGraphReductionResult reduceFunctionsInternal( TypeFunctionReducer reducer{std::move(queuedTys), std::move(queuedTps), std::move(shouldGuess), std::move(cyclics), location, ctx, force}; int iterationCount = 0; - while (!reducer.done()) + if (FFlag::LuauPreventReentrantTypeFunctionReduction) { - reducer.step(); + // If we are reducing a type function while reducing a type function, + // we're probably doing something clowny. One known place this can + // occur is type function reduction => overload selection => subtyping + // => back to type function reduction. At worst, if there's a reduction + // that _doesn't_ loop forever and _needs_ reentrancy, we'll fail to + // handle that and potentially emit an error when we didn't need to. + if (ctx.normalizer->sharedState->reentrantTypeReduction) + return {}; - ++iterationCount; - if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + TypeReductionRentrancyGuard _{ctx.normalizer->sharedState}; + while (!reducer.done()) { - reducer.result.errors.push_back(TypeError{location, CodeTooComplex{}}); - break; - } - } + reducer.step(); - return std::move(reducer.result); + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.emplace_back(location, CodeTooComplex{}); + break; + } + } + + return std::move(reducer.result); + } + else + { + while (!reducer.done()) + { + reducer.step(); + + ++iterationCount; + if (iterationCount > DFInt::LuauTypeFamilyGraphReductionMaximumSteps) + { + reducer.result.errors.emplace_back(location, CodeTooComplex{}); + break; + } + } + + return std::move(reducer.result); + } } FunctionGraphReductionResult reduceTypeFunctions(TypeId entrypoint, Location location, TypeFunctionContext ctx, bool force) @@ -498,13 +563,13 @@ static std::optional> tryDistributeTypeFunct ) { // op (a | b) (c | d) ~ (op a (c | d)) | (op b (c | d)) ~ (op a c) | (op a d) | (op b c) | (op b d) - bool uninhabited = false; + Reduction reductionStatus = Reduction::MaybeOk; std::vector blockedTypes; std::vector results; size_t cartesianProductSize = 1; const UnionType* firstUnion = nullptr; - size_t unionIndex; + size_t unionIndex = 0; std::vector arguments = typeParams; for (size_t i = 0; i < arguments.size(); ++i) @@ -527,7 +592,7 @@ static std::optional> tryDistributeTypeFunct // TODO: We'd like to report that the type function application is too complex here. if (size_t(DFInt::LuauTypeFamilyApplicationCartesianProductLimit) <= cartesianProductSize) - return {{std::nullopt, true, {}, {}}}; + return {{std::nullopt, Reduction::Erroneous, {}, {}}}; } if (!firstUnion) @@ -542,21 +607,22 @@ static std::optional> tryDistributeTypeFunct TypeFunctionReductionResult result = f(instance, arguments, packParams, ctx, args...); blockedTypes.insert(blockedTypes.end(), result.blockedTypes.begin(), result.blockedTypes.end()); - uninhabited |= result.uninhabited; + if (result.reductionStatus != Reduction::MaybeOk) + reductionStatus = result.reductionStatus; - if (result.uninhabited || !result.result) + if (reductionStatus != Reduction::MaybeOk || !result.result) break; else results.push_back(*result.result); } - if (uninhabited || !blockedTypes.empty()) - return {{std::nullopt, uninhabited, blockedTypes, {}}}; + if (reductionStatus != Reduction::MaybeOk || !blockedTypes.empty()) + return {{std::nullopt, reductionStatus, blockedTypes, {}}}; if (!results.empty()) { if (results.size() == 1) - return {{results[0], false, {}, {}}}; + return {{results[0], Reduction::MaybeOk, {}, {}}}; TypeId resultTy = ctx->arena->addType(TypeFunctionInstanceType{ NotNull{&builtinTypeFunctions().unionFunc}, @@ -564,7 +630,7 @@ static std::optional> tryDistributeTypeFunct {}, }); - return {{resultTy, false, {}, {}}}; + return {{resultTy, Reduction::MaybeOk, {}, {}}}; } return std::nullopt; @@ -577,15 +643,166 @@ TypeFunctionReductionResult userDefinedTypeFunction( NotNull ctx ) { - if (!ctx->userFuncName || !ctx->userFuncBody) + auto typeFunction = getMutable(instance); + + if (typeFunction->userFuncData.owner.expired()) { - ctx->ice->ice("all user-defined type functions must have an associated function definition"); - return {std::nullopt, true, {}, {}}; + ctx->ice->ice("user-defined type function module has expired"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; } - // TODO: implementation of user-defined type functions goes here + if (!typeFunction->userFuncName || !typeFunction->userFuncData.definition) + { + ctx->ice->ice("all user-defined type functions must have an associated function definition"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } - return {std::nullopt, true, {}, {}}; + // If type functions cannot be evaluated because of errors in the code, we do not generate any additional ones + if (!ctx->typeFunctionRuntime->allowEvaluation) + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; + + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + + // block if we need to + if (isPending(ty, ctx->solver)) + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; + } + + // Ensure that whole type function environment is registered + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + if (std::optional error = ctx->typeFunctionRuntime->registerFunction(definition.first)) + { + // Failure to register at this point means that original definition had to error out and should not have been present in the + // environment + ctx->ice->ice("user-defined type function reference cannot be registered"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + } + + AstName name = typeFunction->userFuncData.definition->name; + + lua_State* global = ctx->typeFunctionRuntime->state.get(); + + if (global == nullptr) + return {std::nullopt, Reduction::Erroneous, {}, {}, format("'%s' type function: cannot be evaluated in this context", name.value)}; + + // Separate sandboxed thread for individual execution and private globals + lua_State* L = lua_newthread(global); + LuauTempThreadPopper popper(global); + + // Build up the environment table of each function we have visible + for (auto& [_, curr] : typeFunction->userFuncData.environment) + { + // Environment table has to be filled only once in the current execution context + if (ctx->typeFunctionRuntime->initialized.find(curr.first)) + continue; + ctx->typeFunctionRuntime->initialized.insert(curr.first); + + lua_pushlightuserdata(L, curr.first); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + + // Build up the environment of the current function, where some might not be visible + lua_getfenv(L, -1); + lua_setreadonly(L, -1, false); + + for (auto& [name, definition] : typeFunction->userFuncData.environment) + { + // Filter visibility based on original scope depth + if (definition.second >= curr.second) + { + lua_pushlightuserdata(L, definition.first); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + break; // Don't have to report an error here, we will visit each function in outer loop + + lua_setfield(L, -2, name.c_str()); + } + } + + lua_setreadonly(L, -1, true); + lua_pop(L, 2); + } + + // Fetch the function we want to evaluate + lua_pushlightuserdata(L, typeFunction->userFuncData.definition); + lua_gettable(L, LUA_REGISTRYINDEX); + + if (!lua_isfunction(L, -1)) + { + ctx->ice->ice("user-defined type function reference cannot be found in the registry"); + return {std::nullopt, Reduction::Erroneous, {}, {}}; + } + + resetTypeFunctionState(L); + + std::unique_ptr runtimeBuilder = std::make_unique(ctx); + + // Push serialized arguments onto the stack + for (auto typeParam : typeParams) + { + TypeId ty = follow(typeParam); + // This is checked at the top of the function, and should still be true. + LUAU_ASSERT(!isPending(ty, ctx->solver)); + + TypeFunctionTypeId serializedTy = serialize(ty, runtimeBuilder.get()); + // Check if there were any errors while serializing + if (runtimeBuilder->errors.size() != 0) + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front()}; + + allocTypeUserData(L, serializedTy->type); + } + + // Set up an interrupt handler for type functions to respect type checking limits and LSP cancellation requests. + lua_callbacks(L)->interrupt = [](lua_State* L, int gc) + { + auto ctx = static_cast(lua_getthreaddata(lua_mainthread(L))); + if (ctx->limits->finishTime && TimeTrace::getClock() > *ctx->limits->finishTime) + throw TimeLimitError(ctx->ice->moduleName); + + if (ctx->limits->cancellationToken && ctx->limits->cancellationToken->requested()) + throw UserCancelError(ctx->ice->moduleName); + }; + + ctx->typeFunctionRuntime->messages.clear(); + + if (auto error = checkResultForError(L, name.value, lua_pcall(L, int(typeParams.size()), 1, 0))) + return {std::nullopt, Reduction::Erroneous, {}, {}, error, ctx->typeFunctionRuntime->messages}; + + // If the return value is not a type userdata, return with error message + if (!isTypeUserData(L, 1)) + { + return { + std::nullopt, + Reduction::Erroneous, + {}, + {}, + format("'%s' type function: returned a non-type value", name.value), + ctx->typeFunctionRuntime->messages + }; + } + + TypeFunctionTypeId retTypeFunctionTypeId = getTypeUserData(L, 1); + + // No errors should be present here since we should've returned already if any were raised during serialization. + LUAU_ASSERT(runtimeBuilder->errors.size() == 0); + + TypeId retTypeId = deserialize(retTypeFunctionTypeId, runtimeBuilder.get()); + + // At least 1 error occurred while deserializing + if (runtimeBuilder->errors.size() > 0) + return {std::nullopt, Reduction::Erroneous, {}, {}, runtimeBuilder->errors.front(), ctx->typeFunctionRuntime->messages}; + + return {retTypeId, Reduction::MaybeOk, {}, {}, std::nullopt, ctx->typeFunctionRuntime->messages}; } TypeFunctionReductionResult notTypeFunction( @@ -604,16 +821,16 @@ TypeFunctionReductionResult notTypeFunction( TypeId ty = follow(typeParams.at(0)); if (ty == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) return *result; // `not` operates on anything and returns a `boolean` always. - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult lenTypeFunction( @@ -632,19 +849,19 @@ TypeFunctionReductionResult lenTypeFunction( TypeId operandTy = follow(typeParams.at(0)); if (operandTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if the operand type is resolved enough, and wait to reduce if not // the use of `typeFromNormal` later necessitates blocking on local types. if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy, /* avoidSealingTables */ true); if (!maybeGeneralized) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; operandTy = *maybeGeneralized; } @@ -653,23 +870,23 @@ TypeFunctionReductionResult lenTypeFunction( // if the type failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy || inhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if the operand type is error suppressing, we can immediately reduce to `number`. if (normTy->shouldSuppressErrors()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; // # always returns a number, even if its operand is never. // if we're checking the length of a string, that works! if (inhabited == NormalizationResult::False || normTy->isSubtypeOfString()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; // we use the normalized operand here in case there was an intersection or union. - TypeId normalizedOperand = ctx->normalizer->typeFromNormal(*normTy); + TypeId normalizedOperand = follow(ctx->normalizer->typeFromNormal(*normTy)); if (normTy->hasTopTable() || get(normalizedOperand)) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; - if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) + if (auto result = tryDistributeTypeFunctionApp(lenTypeFunction, instance, typeParams, packParams, ctx)) return *result; // findMetatableEntry demands the ability to emit errors, so we must give it @@ -678,35 +895,35 @@ TypeFunctionReductionResult lenTypeFunction( std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__len", Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // `len` must return a `number`. - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult unmTypeFunction( @@ -725,18 +942,18 @@ TypeFunctionReductionResult unmTypeFunction( TypeId operandTy = follow(typeParams.at(0)); if (operandTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if the operand type is resolved enough, and wait to reduce if not if (isPending(operandTy, ctx->solver)) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, operandTy); if (!maybeGeneralized) - return {std::nullopt, false, {operandTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {operandTy}, {}}; operandTy = *maybeGeneralized; } @@ -744,21 +961,21 @@ TypeFunctionReductionResult unmTypeFunction( // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if the operand is error suppressing, we can just go ahead and reduce. if (normTy->shouldSuppressErrors()) - return {operandTy, false, {}, {}}; + return {operandTy, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the operation didn't work. if (is(operandTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // If the type is exactly `number`, we can reduce now. if (normTy->isExactlyNumber()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; - if (auto result = tryDistributeTypeFunctionApp(notTypeFunction, instance, typeParams, packParams, ctx)) + if (auto result = tryDistributeTypeFunctionApp(unmTypeFunction, instance, typeParams, packParams, ctx)) return *result; // findMetatableEntry demands the ability to emit errors, so we must give it @@ -767,40 +984,168 @@ TypeFunctionReductionResult unmTypeFunction( std::optional mmType = findMetatableEntry(ctx->builtins, dummy, operandTy, "__unm", Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({operandTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; if (std::optional ret = first(instantiatedMmFtv->retTypes)) - return {*ret, false, {}, {}}; + return {ret, Reduction::MaybeOk, {}, {}}; else - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } -NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) +void dummyStateClose(lua_State*) {} + +TypeFunctionRuntime::TypeFunctionRuntime(NotNull ice, NotNull limits) + : ice(ice) + , limits(limits) + , state(nullptr, dummyStateClose) +{ +} + +TypeFunctionRuntime::~TypeFunctionRuntime() {} + +std::optional TypeFunctionRuntime::registerFunction(AstStatTypeFunction* function) +{ + // If evaluation is disabled, we do not generate additional error messages + if (!allowEvaluation) + return std::nullopt; + + prepareState(); + + lua_State* global = state.get(); + + // Fetch to check if function is already registered + lua_pushlightuserdata(global, function); + lua_gettable(global, LUA_REGISTRYINDEX); + + if (!lua_isnil(global, -1)) + { + lua_pop(global, 1); + return std::nullopt; + } + + lua_pop(global, 1); + + AstName name = function->name; + + // Construct ParseResult containing the type function + Allocator allocator; + AstNameTable names(allocator); + + AstExpr* exprFunction = function->body; + AstArray exprReturns{&exprFunction, 1}; + AstStatReturn stmtReturn{Location{}, exprReturns}; + AstStat* stmtArray[] = {&stmtReturn}; + AstArray stmts{stmtArray, 1}; + AstStatBlock exec{Location{}, stmts}; + ParseResult parseResult{&exec, 1, {}, {}, {}, CstNodeMap{nullptr}}; + + BytecodeBuilder builder; + try + { + compileOrThrow(builder, parseResult, names); + } + catch (CompileError& e) + { + return format("'%s' type function failed to compile with error message: %s", name.value, e.what()); + } + + std::string bytecode = builder.getBytecode(); + + + // Separate sandboxed thread for individual execution and private globals + lua_State* L = lua_newthread(global); + LuauTempThreadPopper popper(global); + + // Create individual environment for the type function + luaL_sandboxthread(L); + + // Do not allow global writes to that environment + lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + + // Load bytecode into Luau state + if (auto error = checkResultForError(L, name.value, luau_load(L, name.value, bytecode.data(), bytecode.size(), 0))) + return error; + + // Execute the global function which should return our user-defined type function + if (auto error = checkResultForError(L, name.value, lua_resume(L, nullptr, 0))) + return error; + + if (!lua_isfunction(L, -1)) + { + lua_pop(L, 1); + return format("Could not find '%s' type function in the global scope", name.value); + } + + // Store resulting function in the registry + lua_pushlightuserdata(global, function); + lua_xmove(L, global, 1); + lua_settable(global, LUA_REGISTRYINDEX); + + return std::nullopt; +} + +void TypeFunctionRuntime::prepareState() +{ + if (state) + return; + + state = StateRef(lua_newstate(typeFunctionAlloc, nullptr), lua_close); + lua_State* L = state.get(); + + lua_setthreaddata(L, this); + + setTypeFunctionEnvironment(L); + + registerTypeUserData(L); + + registerTypesLibrary(L); + + luaL_sandbox(L); + luaL_sandboxthread(L); +} + +TypeFunctionContext::TypeFunctionContext(NotNull cs, NotNull scope, NotNull constraint) + : arena(cs->arena) + , builtins(cs->builtinTypes) + , scope(scope) + , simplifier(cs->simplifier) + , normalizer(cs->normalizer) + , typeFunctionRuntime(cs->typeFunctionRuntime) + , ice(NotNull{&cs->iceReporter}) + , limits(NotNull{&cs->limits}) + , solver(cs.get()) + , constraint(constraint.get()) +{ +} + +NotNull TypeFunctionContext::pushConstraint(ConstraintV&& c) const { LUAU_ASSERT(solver); NotNull newConstraint = solver->pushConstraint(scope, constraint ? constraint->location : Location{}, std::move(c)); @@ -832,30 +1177,30 @@ TypeFunctionReductionResult numericBinopTypeFunction( // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the math operator is unreachable. if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; const Location location = ctx->constraint ? ctx->constraint->location : Location{}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -867,15 +1212,15 @@ TypeFunctionReductionResult numericBinopTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->anyType, false, {}, {}}; + return {ctx->builtins->anyType, Reduction::MaybeOk, {}, {}}; // if we're adding two `number` types, the result is `number`. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) - return {ctx->builtins->numberType, false, {}, {}}; + return {ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(numericBinopTypeFunction, instance, typeParams, packParams, ctx, metamethod)) return *result; @@ -893,32 +1238,56 @@ TypeFunctionReductionResult numericBinopTypeFunction( } if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; TypePackId argPack = ctx->arena->addTypePack({lhsTy, rhsTy}); SolveResult solveResult; if (!reversed) - solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + solveResult = solveFunctionCall( + ctx->arena, + ctx->builtins, + ctx->simplifier, + ctx->normalizer, + ctx->typeFunctionRuntime, + ctx->ice, + ctx->limits, + ctx->scope, + location, + *mmType, + argPack + ); else { TypePack* p = getMutable(argPack); std::swap(p->head.front(), p->head.back()); - solveResult = solveFunctionCall(ctx->arena, ctx->builtins, ctx->normalizer, ctx->ice, ctx->limits, ctx->scope, location, *mmType, argPack); + solveResult = solveFunctionCall( + ctx->arena, + ctx->builtins, + ctx->simplifier, + ctx->normalizer, + ctx->typeFunctionRuntime, + ctx->ice, + ctx->limits, + ctx->scope, + location, + *mmType, + argPack + ); } if (!solveResult.typePackId.has_value()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; TypePack extracted = extendTypePack(*ctx->arena, ctx->builtins, *solveResult.typePackId, 1); if (extracted.head.empty()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {extracted.head.front(), false, {}, {}}; + return {extracted.head.front(), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult addTypeFunction( @@ -1051,24 +1420,24 @@ TypeFunctionReductionResult concatTypeFunction( // isPending of `lhsTy` or `rhsTy` would return true, even if it cycles. We want a different answer for that. if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1079,19 +1448,19 @@ TypeFunctionReductionResult concatTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can reduce to `any` since we should suppress errors in the result of the usage. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->anyType, false, {}, {}}; + return {ctx->builtins->anyType, Reduction::MaybeOk, {}, {}}; - // if we have a `never`, we can never observe that the numeric operator didn't work. + // if we have a `never`, we can never observe that the operator didn't work. if (is(lhsTy) || is(rhsTy)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // if we're concatenating two elements that are either strings or numbers, the result is `string`. if ((normLhsTy->isSubtypeOfString() || normLhsTy->isExactlyNumber()) && (normRhsTy->isSubtypeOfString() || normRhsTy->isExactlyNumber())) - return {ctx->builtins->stringType, false, {}, {}}; + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(concatTypeFunction, instance, typeParams, packParams, ctx)) return *result; @@ -1109,23 +1478,23 @@ TypeFunctionReductionResult concatTypeFunction( } if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; std::vector inferredArgs; if (!reversed) @@ -1136,13 +1505,13 @@ TypeFunctionReductionResult concatTypeFunction( TypePackId inferredArgPack = ctx->arena->addTypePack(std::move(inferredArgs)); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->stringType, false, {}, {}}; + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult andTypeFunction( @@ -1163,27 +1532,27 @@ TypeFunctionReductionResult andTypeFunction( // t1 = and ~> lhs if (follow(rhsTy) == instance && lhsTy != rhsTy) - return {lhsTy, false, {}, {}}; + return {lhsTy, Reduction::MaybeOk, {}, {}}; // t1 = and ~> rhs if (follow(lhsTy) == instance && lhsTy != rhsTy) - return {rhsTy, false, {}, {}}; + return {rhsTy, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1197,7 +1566,7 @@ TypeFunctionReductionResult andTypeFunction( blockedTypes.push_back(ty); for (auto ty : overallResult.blockedTypes) blockedTypes.push_back(ty); - return {overallResult.result, false, std::move(blockedTypes), {}}; + return {overallResult.result, Reduction::MaybeOk, std::move(blockedTypes), {}}; } TypeFunctionReductionResult orTypeFunction( @@ -1218,27 +1587,27 @@ TypeFunctionReductionResult orTypeFunction( // t1 = or ~> lhs if (follow(rhsTy) == instance && lhsTy != rhsTy) - return {lhsTy, false, {}, {}}; + return {lhsTy, Reduction::MaybeOk, {}, {}}; // t1 = or ~> rhs if (follow(lhsTy) == instance && lhsTy != rhsTy) - return {rhsTy, false, {}, {}}; + return {rhsTy, Reduction::MaybeOk, {}, {}}; // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1252,7 +1621,7 @@ TypeFunctionReductionResult orTypeFunction( blockedTypes.push_back(ty); for (auto ty : overallResult.blockedTypes) blockedTypes.push_back(ty); - return {overallResult.result, false, std::move(blockedTypes), {}}; + return {overallResult.result, Reduction::MaybeOk, std::move(blockedTypes), {}}; } static TypeFunctionReductionResult comparisonTypeFunction( @@ -1274,12 +1643,12 @@ static TypeFunctionReductionResult comparisonTypeFunction( TypeId rhsTy = follow(typeParams.at(1)); if (lhsTy == instance || rhsTy == instance) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // Algebra Reduction Rules for comparison type functions // Note that comparing to never tells you nothing about the other operand @@ -1316,15 +1685,15 @@ static TypeFunctionReductionResult comparisonTypeFunction( rhsTy = follow(rhsTy); // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1339,23 +1708,23 @@ static TypeFunctionReductionResult comparisonTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can just go ahead and reduce. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if we have an uninhabited type (e.g. `never`), we can never observe that the comparison didn't work. if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // If both types are some strict subset of `string`, we can reduce now. if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // If both types are exactly `number`, we can reduce now. if (normLhsTy->isExactlyNumber() && normRhsTy->isExactlyNumber()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; if (auto result = tryDistributeTypeFunctionApp(comparisonTypeFunction, instance, typeParams, packParams, ctx, metamethod)) return *result; @@ -1369,34 +1738,34 @@ static TypeFunctionReductionResult comparisonTypeFunction( mmType = findMetatableEntry(ctx->builtins, dummy, rhsTy, metamethod, Location{}); if (!mmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult ltTypeFunction( @@ -1449,20 +1818,20 @@ TypeFunctionReductionResult eqTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(lhsTy, ctx->solver)) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (isPending(rhsTy, ctx->solver)) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; // if either type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional lhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, lhsTy); std::optional rhsMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, rhsTy); if (!lhsMaybeGeneralized) - return {std::nullopt, false, {lhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {lhsTy}, {}}; else if (!rhsMaybeGeneralized) - return {std::nullopt, false, {rhsTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {rhsTy}, {}}; lhsTy = *lhsMaybeGeneralized; rhsTy = *rhsMaybeGeneralized; @@ -1475,15 +1844,15 @@ TypeFunctionReductionResult eqTypeFunction( // if either failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normLhsTy || !normRhsTy || lhsInhabited == NormalizationResult::HitLimits || rhsInhabited == NormalizationResult::HitLimits) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if one of the types is error suppressing, we can just go ahead and reduce. if (normLhsTy->shouldSuppressErrors() || normRhsTy->shouldSuppressErrors()) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if we have a `never`, we can never observe that the comparison didn't work. if (lhsInhabited == NormalizationResult::False || rhsInhabited == NormalizationResult::False) - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // findMetatableEntry demands the ability to emit errors, so we must give it // the necessary state to do that, even if we intend to just eat the errors. @@ -1498,49 +1867,49 @@ TypeFunctionReductionResult eqTypeFunction( if (!mmType) { if (intersectInhabited == NormalizationResult::True) - return {ctx->builtins->booleanType, false, {}, {}}; // if it's inhabited, everything is okay! + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; // if it's inhabited, everything is okay! // we might be in a case where we still want to accept the comparison... if (intersectInhabited == NormalizationResult::False) { // if they're both subtypes of `string` but have no common intersection, the comparison is allowed but always `false`. if (normLhsTy->isSubtypeOfString() && normRhsTy->isSubtypeOfString()) - return {ctx->builtins->falseType, false, {}, {}}; + return {ctx->builtins->falseType, Reduction::MaybeOk, {}, {}}; // if they're both subtypes of `boolean` but have no common intersection, the comparison is allowed but always `false`. if (normLhsTy->isSubtypeOfBooleans() && normRhsTy->isSubtypeOfBooleans()) - return {ctx->builtins->falseType, false, {}, {}}; + return {ctx->builtins->falseType, Reduction::MaybeOk, {}, {}}; } - return {std::nullopt, true, {}, {}}; // if it's not, then this type function is irreducible! + return {std::nullopt, Reduction::Erroneous, {}, {}}; // if it's not, then this type function is irreducible! } mmType = follow(*mmType); if (isPending(*mmType, ctx->solver)) - return {std::nullopt, false, {*mmType}, {}}; + return {std::nullopt, Reduction::MaybeOk, {*mmType}, {}}; const FunctionType* mmFtv = get(*mmType); if (!mmFtv) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; std::optional instantiatedMmType = instantiate(ctx->builtins, ctx->arena, ctx->limits, ctx->scope, *mmType); if (!instantiatedMmType) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; const FunctionType* instantiatedMmFtv = get(*instantiatedMmType); if (!instantiatedMmFtv) - return {ctx->builtins->errorRecoveryType(), false, {}, {}}; + return {ctx->builtins->errorRecoveryType(), Reduction::MaybeOk, {}, {}}; TypePackId inferredArgPack = ctx->arena->addTypePack({lhsTy, rhsTy}); Unifier2 u2{ctx->arena, ctx->builtins, ctx->scope, ctx->ice}; if (!u2.unify(inferredArgPack, instantiatedMmFtv->argTypes)) - return {std::nullopt, true, {}, {}}; // occurs check failed + return {std::nullopt, Reduction::Erroneous, {}, {}}; // occurs check failed - Subtyping subtyping{ctx->builtins, ctx->arena, ctx->normalizer, ctx->ice}; + Subtyping subtyping{ctx->builtins, ctx->arena, ctx->simplifier, ctx->normalizer, ctx->typeFunctionRuntime, ctx->ice}; if (!subtyping.isSubtype(inferredArgPack, instantiatedMmFtv->argTypes, ctx->scope).isSubtype) // TODO: is this the right variance? - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; - return {ctx->builtins->booleanType, false, {}, {}}; + return {ctx->builtins->booleanType, Reduction::MaybeOk, {}, {}}; } // Collect types that prevent us from reducing a particular refinement. @@ -1565,7 +1934,6 @@ struct FindRefinementBlockers : TypeOnceVisitor } }; - TypeFunctionReductionResult refineTypeFunction( TypeId instance, const std::vector& typeParams, @@ -1586,13 +1954,13 @@ TypeFunctionReductionResult refineTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(targetTy, ctx->solver)) - return {std::nullopt, false, {targetTy}, {}}; + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; else { for (auto t : discriminantTypes) { if (isPending(t, ctx->solver)) - return {std::nullopt, false, {t}, {}}; + return {std::nullopt, Reduction::MaybeOk, {t}, {}}; } } // Refine a target type and a discriminant one at a time. @@ -1600,7 +1968,7 @@ TypeFunctionReductionResult refineTypeFunction( auto stepRefine = [&ctx](TypeId target, TypeId discriminant) -> std::pair> { std::vector toBlock; - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional targetMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, target); std::optional discriminantMaybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, discriminant); @@ -1621,47 +1989,60 @@ TypeFunctionReductionResult refineTypeFunction( if (!frb.found.empty()) return {nullptr, {frb.found.begin(), frb.found.end()}}; - /* HACK: Refinements sometimes produce a type T & ~any under the assumption - * that ~any is the same as any. This is so so weird, but refinements needs - * some way to say "I may refine this, but I'm not sure." - * - * It does this by refining on a blocked type and deferring the decision - * until it is unblocked. - * - * Refinements also get negated, so we wind up with types like T & ~*blocked* - * - * We need to treat T & ~any as T in this case. - */ - if (auto nt = get(discriminant)) - if (get(follow(nt->ty))) - return {target, {}}; - - // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the - // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. - if (get(target)) + if (FFlag::DebugLuauEqSatSimplification) { - SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); - if (!result.blockedTypes.empty()) - return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; + auto simplifyResult = eqSatSimplify(ctx->simplifier, ctx->arena->addType(IntersectionType{{target, discriminant}})); + if (simplifyResult) + { + if (ctx->solver) + { + for (TypeId newTf : simplifyResult->newTypeFunctions) + ctx->pushConstraint(ReduceConstraint{newTf}); + } - return {result.result, {}}; + return {simplifyResult->result, {}}; + } + else + return {nullptr, {}}; } + else + { + if (FFlag::LuauSkipNoRefineDuringRefinement) + if (get(discriminant)) + return {target, {}}; + if (auto nt = get(discriminant)) + { + if (get(follow(nt->ty))) + return {target, {}}; + } - // In the general case, we'll still use normalization though. - TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); - std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); - std::shared_ptr normType = ctx->normalizer->normalize(target); + // If the target type is a table, then simplification already implements the logic to deal with refinements properly since the + // type of the discriminant is guaranteed to only ever be an (arbitrarily-nested) table of a single property type. + if (get(target)) + { + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, target, discriminant); + if (!result.blockedTypes.empty()) + return {nullptr, {result.blockedTypes.begin(), result.blockedTypes.end()}}; - // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. - if (!normIntersection || !normType) - return {nullptr, {}}; + return {result.result, {}}; + } - TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); - // include the error type if the target type is error-suppressing and the intersection we computed is not - if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) - resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); + // In the general case, we'll still use normalization though. + TypeId intersection = ctx->arena->addType(IntersectionType{{target, discriminant}}); + std::shared_ptr normIntersection = ctx->normalizer->normalize(intersection); + std::shared_ptr normType = ctx->normalizer->normalize(target); - return {resultTy, {}}; + // if the intersection failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!normIntersection || !normType) + return {nullptr, {}}; + + TypeId resultTy = ctx->normalizer->typeFromNormal(*normIntersection); + // include the error type if the target type is error-suppressing and the intersection we computed is not + if (normType->shouldSuppressErrors() && !normIntersection->shouldSuppressErrors()) + resultTy = ctx->arena->addType(UnionType{{resultTy, ctx->builtins->errorType}}); + + return {resultTy, {}}; + } }; // refine target with each discriminant type in sequence (reverse of insertion order) @@ -1674,15 +2055,15 @@ TypeFunctionReductionResult refineTypeFunction( auto [refined, blocked] = stepRefine(target, discriminant); if (blocked.empty() && refined == nullptr) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; if (!blocked.empty()) - return {std::nullopt, false, blocked, {}}; + return {std::nullopt, Reduction::MaybeOk, blocked, {}}; target = refined; discriminantTypes.pop_back(); } - return {target, false, {}, {}}; + return {target, Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult singletonTypeFunction( @@ -1702,14 +2083,14 @@ TypeFunctionReductionResult singletonTypeFunction( // check to see if both operand types are resolved enough, and wait to reduce if not if (isPending(type, ctx->solver)) - return {std::nullopt, false, {type}, {}}; + return {std::nullopt, Reduction::MaybeOk, {type}, {}}; // if the type is free but has only one remaining reference, we can generalize it to its upper bound here. - if (ctx->solver) + if (ctx->solver && !FFlag::LuauDoNotGeneralizeInTypeFunctions) { std::optional maybeGeneralized = ctx->solver->generalizeFreeType(ctx->scope, type); if (!maybeGeneralized) - return {std::nullopt, false, {type}, {}}; + return {std::nullopt, Reduction::MaybeOk, {type}, {}}; type = *maybeGeneralized; } @@ -1720,12 +2101,49 @@ TypeFunctionReductionResult singletonTypeFunction( // if we have a singleton type or `nil`, which is its own singleton type... if (get(followed) || isNil(followed)) - return {type, false, {}, {}}; + return {type, Reduction::MaybeOk, {}, {}}; // otherwise, we'll return the top type, `unknown`. - return {ctx->builtins->unknownType, false, {}, {}}; + return {ctx->builtins->unknownType, Reduction::MaybeOk, {}, {}}; } +struct CollectUnionTypeOptions : TypeOnceVisitor +{ + NotNull ctx; + DenseHashSet options{nullptr}; + DenseHashSet blockingTypes{nullptr}; + + explicit CollectUnionTypeOptions(NotNull ctx) + : TypeOnceVisitor(/* skipBoundTypes */ true) + , ctx(ctx) + { + } + + bool visit(TypeId ty) override + { + options.insert(ty); + if (isPending(ty, ctx->solver)) + blockingTypes.insert(ty); + return false; + } + + bool visit(TypePackId tp) override + { + return false; + } + + bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override + { + if (tfit.function->name != builtinTypeFunctions().unionFunc.name) + { + options.insert(ty); + blockingTypes.insert(ty); + return false; + } + return true; + } +}; + TypeFunctionReductionResult unionTypeFunction( TypeId instance, const std::vector& typeParams, @@ -1741,7 +2159,36 @@ TypeFunctionReductionResult unionTypeFunction( // if we only have one parameter, there's nothing to do. if (typeParams.size() == 1) - return {follow(typeParams[0]), false, {}, {}}; + return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; + + if (FFlag::LuauClipNestedAndRecursiveUnion) + { + + CollectUnionTypeOptions collector{ctx}; + collector.traverse(instance); + + if (!collector.blockingTypes.empty()) + { + std::vector blockingTypes{collector.blockingTypes.begin(), collector.blockingTypes.end()}; + return {std::nullopt, Reduction::MaybeOk, std::move(blockingTypes), {}}; + } + + TypeId resultTy = ctx->builtins->neverType; + for (auto ty : collector.options) + { + SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); + // This condition might fire if one of the arguments to this type + // function is a free type somewhere deep in a nested union or + // intersection type, even though we ran a pass above to capture + // some blocked types. + if (!result.blockedTypes.empty()) + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + resultTy = result.result; + } + + return {resultTy, Reduction::MaybeOk, {}, {}}; + } // we need to follow all of the type parameters. std::vector types; @@ -1769,12 +2216,12 @@ TypeFunctionReductionResult unionTypeFunction( // if we still have a `lastType` at the end, we're taking the short-circuit and reducing early. if (lastType) - return {lastType, false, {}, {}}; + return {lastType, Reduction::MaybeOk, {}, {}}; // check to see if the operand types are resolved enough, and wait to reduce if not for (auto ty : types) if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; // fold over the types with `simplifyUnion` TypeId resultTy = ctx->builtins->neverType; @@ -1782,12 +2229,12 @@ TypeFunctionReductionResult unionTypeFunction( { SimplifyResult result = simplifyUnion(ctx->builtins, ctx->arena, resultTy, ty); if (!result.blockedTypes.empty()) - return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; resultTy = result.result; } - return {resultTy, false, {}, {}}; + return {resultTy, Reduction::MaybeOk, {}, {}}; } @@ -1806,7 +2253,7 @@ TypeFunctionReductionResult intersectTypeFunction( // if we only have one parameter, there's nothing to do. if (typeParams.size() == 1) - return {follow(typeParams[0]), false, {}, {}}; + return {follow(typeParams[0]), Reduction::MaybeOk, {}, {}}; // we need to follow all of the type parameters. std::vector types; @@ -1814,23 +2261,45 @@ TypeFunctionReductionResult intersectTypeFunction( for (auto ty : typeParams) types.emplace_back(follow(ty)); + // if we only have two parameters and one is `*no-refine*`, we're all done. + if (types.size() == 2 && get(types[1])) + return {types[0], Reduction::MaybeOk, {}, {}}; + else if (types.size() == 2 && get(types[0])) + return {types[1], Reduction::MaybeOk, {}, {}}; + // check to see if the operand types are resolved enough, and wait to reduce if not // if any of them are `never`, the intersection will always be `never`, so we can reduce directly. for (auto ty : types) { if (isPending(ty, ctx->solver)) - return {std::nullopt, false, {ty}, {}}; + return {std::nullopt, Reduction::MaybeOk, {ty}, {}}; else if (get(ty)) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; } // fold over the types with `simplifyIntersection` TypeId resultTy = ctx->builtins->unknownType; for (auto ty : types) { + // skip any `*no-refine*` types. + if (get(ty)) + continue; + SimplifyResult result = simplifyIntersection(ctx->builtins, ctx->arena, resultTy, ty); - if (!result.blockedTypes.empty()) - return {std::nullopt, false, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + + if (FFlag::LuauIntersectNotNil) + { + for (TypeId blockedType : result.blockedTypes) + { + if (!get(blockedType)) + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + } + } + else + { + if (!result.blockedTypes.empty()) + return {std::nullopt, Reduction::MaybeOk, {result.blockedTypes.begin(), result.blockedTypes.end()}, {}}; + } resultTy = result.result; } @@ -1841,10 +2310,10 @@ TypeFunctionReductionResult intersectTypeFunction( if (get(resultTy)) { TypeId intersection = ctx->arena->addType(IntersectionType{typeParams}); - return {intersection, false, {}, {}}; + return {intersection, Reduction::MaybeOk, {}, {}}; } - return {resultTy, false, {}, {}}; + return {resultTy, Reduction::MaybeOk, {}, {}}; } // computes the keys of `ty` into `result` @@ -1944,17 +2413,17 @@ TypeFunctionReductionResult keyofFunctionImpl( // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. if (!normTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if we don't have either just tables or just classes, we've got nothing to get keys of (at least until a future version perhaps adds classes // as well) if (normTy->hasTables() == normTy->hasClasses()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // this is sort of atrocious, but we're trying to reject any type that has not normalized to a table or a union of tables. if (normTy->hasTops() || normTy->hasBooleans() || normTy->hasErrors() || normTy->hasNils() || normTy->hasNumbers() || normTy->hasStrings() || normTy->hasThreads() || normTy->hasBuffers() || normTy->hasFunctions() || normTy->hasTyvars()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // we're going to collect the keys in here Set keys{{}}; @@ -1973,7 +2442,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // collect all the properties from the first class type if (!computeKeysOf(*classesIter, keys, seen, isRaw, ctx)) - return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have a top type! + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; // if it failed, we have a top type! // we need to look at each class to remove any keys that are not common amongst them all while (++classesIter != classesIterEnd) @@ -1986,7 +2455,7 @@ TypeFunctionReductionResult keyofFunctionImpl( if (!computeKeysOf(*classesIter, localKeys, seen, isRaw, ctx)) continue; - for (auto key : keys) + for (auto& key : keys) { // remove any keys that are not present in each class if (!localKeys.contains(key)) @@ -2008,7 +2477,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // collect all the properties from the first table type if (!computeKeysOf(*tablesIter, keys, seen, isRaw, ctx)) - return {ctx->builtins->stringType, false, {}, {}}; // if it failed, we have the top table type! + return {ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; // if it failed, we have the top table type! // we need to look at each tables to remove any keys that are not common amongst them all while (++tablesIter != normTy->tables.end()) @@ -2021,7 +2490,7 @@ TypeFunctionReductionResult keyofFunctionImpl( if (!computeKeysOf(*tablesIter, localKeys, seen, isRaw, ctx)) continue; - for (auto key : keys) + for (auto& key : keys) { // remove any keys that are not present in each table if (!localKeys.contains(key)) @@ -2032,7 +2501,7 @@ TypeFunctionReductionResult keyofFunctionImpl( // if the set of keys is empty, `keyof` is `never` if (keys.empty()) - return {ctx->builtins->neverType, false, {}, {}}; + return {ctx->builtins->neverType, Reduction::MaybeOk, {}, {}}; // everything is validated, we need only construct our big union of singletons now! std::vector singletons; @@ -2045,9 +2514,9 @@ TypeFunctionReductionResult keyofFunctionImpl( // We can take straight take it from the first entry // because it was added into the type arena already. if (singletons.size() == 1) - return {singletons.front(), false, {}, {}}; + return {singletons.front(), Reduction::MaybeOk, {}, {}}; - return {ctx->arena->addType(UnionType{singletons}), false, {}, {}}; + return {ctx->arena->addType(UnionType{singletons}), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult keyofTypeFunction( @@ -2118,7 +2587,7 @@ bool searchPropsAndIndexer( // index into tbl's indexer if (tblIndexer) { - if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, *ctx->ice)) + if (isSubtype(ty, tblIndexer->indexType, ctx->scope, ctx->builtins, ctx->simplifier, *ctx->ice)) { TypeId idxResultTy = follow(tblIndexer->indexResultType); @@ -2193,37 +2662,35 @@ TypeFunctionReductionResult indexFunctionImpl( // if the indexee failed to normalize, we can't reduce, but know nothing about inhabitance. if (!indexeeNormTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // if we don't have either just tables or just classes, we've got nothing to index into if (indexeeNormTy->hasTables() == indexeeNormTy->hasClasses()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // we're trying to reject any type that has not normalized to a table/class or a union of tables/classes. if (indexeeNormTy->hasTops() || indexeeNormTy->hasBooleans() || indexeeNormTy->hasErrors() || indexeeNormTy->hasNils() || indexeeNormTy->hasNumbers() || indexeeNormTy->hasStrings() || indexeeNormTy->hasThreads() || indexeeNormTy->hasBuffers() || indexeeNormTy->hasFunctions() || indexeeNormTy->hasTyvars()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; TypeId indexerTy = follow(typeParams.at(1)); if (isPending(indexerTy, ctx->solver)) - { - return {std::nullopt, false, {indexerTy}, {}}; - } + return {std::nullopt, Reduction::MaybeOk, {indexerTy}, {}}; std::shared_ptr indexerNormTy = ctx->normalizer->normalize(indexerTy); // if the indexer failed to normalize, we can't reduce, but know nothing about inhabitance. if (!indexerNormTy) - return {std::nullopt, false, {}, {}}; + return {std::nullopt, Reduction::MaybeOk, {}, {}}; // we're trying to reject any type that is not a string singleton or primitive (string, number, boolean, thread, nil, function, table, or buffer) if (indexerNormTy->hasTops() || indexerNormTy->hasErrors()) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // indexer can be a union —> break them down into a vector - const std::vector* typesToFind; + const std::vector* typesToFind = nullptr; const std::vector singleType{indexerTy}; if (auto unionTy = get(indexerTy)) typesToFind = &unionTy->options; @@ -2237,7 +2704,7 @@ TypeFunctionReductionResult indexFunctionImpl( LUAU_ASSERT(!indexeeNormTy->hasTables()); if (isRaw) // rawget should never reduce for classes (to match the behavior of the rawget global function) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; // at least one class is guaranteed to be in the iterator by .hasClasses() for (auto classesIter = indexeeNormTy->classes.ordering.begin(); classesIter != indexeeNormTy->classes.ordering.end(); ++classesIter) @@ -2246,7 +2713,7 @@ TypeFunctionReductionResult indexFunctionImpl( if (!classTy) { LUAU_ASSERT(false); // this should not be possible according to normalization's spec - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } for (TypeId ty : *typesToFind) @@ -2275,10 +2742,10 @@ TypeFunctionReductionResult indexFunctionImpl( ErrorVec dummy; std::optional mmType = findMetatableEntry(ctx->builtins, dummy, *classesIter, "__index", Location{}); if (!mmType) // if a metatable does not exist, there is no where else to look - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; if (!tblIndexInto(ty, *mmType, properties, ctx, isRaw)) // if indexer is not in the metatable, we fail to reduce - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } } @@ -2292,7 +2759,7 @@ TypeFunctionReductionResult indexFunctionImpl( { for (TypeId ty : *typesToFind) if (!tblIndexInto(ty, *tablesIter, properties, ctx, isRaw)) - return {std::nullopt, true, {}, {}}; + return {std::nullopt, Reduction::Erroneous, {}, {}}; } } @@ -2309,9 +2776,9 @@ TypeFunctionReductionResult indexFunctionImpl( // If the type being reduced to is a single type, no need to union if (properties.size() == 1) - return {*properties.begin(), false, {}, {}}; + return {*properties.begin(), Reduction::MaybeOk, {}, {}}; - return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), false, {}, {}}; + return {ctx->arena->addType(UnionType{std::vector(properties.begin(), properties.end())}), Reduction::MaybeOk, {}, {}}; } TypeFunctionReductionResult indexTypeFunction( @@ -2346,6 +2813,211 @@ TypeFunctionReductionResult rawgetTypeFunction( return indexFunctionImpl(typeParams, packParams, ctx, /* isRaw */ true); } +TypeFunctionReductionResult setmetatableTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 2 || !packParams.empty()) + { + ctx->ice->ice("setmetatable type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + TypeId targetTy = follow(typeParams.at(0)); + TypeId metatableTy = follow(typeParams.at(1)); + + std::shared_ptr targetNorm = ctx->normalizer->normalize(targetTy); + + // if the operand failed to normalize, we can't reduce, but know nothing about inhabitance. + if (!targetNorm) + return {std::nullopt, Reduction::MaybeOk, {}, {}}; + + // cannot setmetatable on something without table parts. + if (!targetNorm->hasTables()) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // we're trying to reject any type that has not normalized to a table or a union/intersection of tables. + if (targetNorm->hasTops() || targetNorm->hasBooleans() || targetNorm->hasErrors() || targetNorm->hasNils() || targetNorm->hasNumbers() || + targetNorm->hasStrings() || targetNorm->hasThreads() || targetNorm->hasBuffers() || targetNorm->hasFunctions() || targetNorm->hasTyvars() || + targetNorm->hasClasses()) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // if the supposed metatable is not a table, we will fail to reduce. + if (!get(metatableTy) && !get(metatableTy)) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + if (targetNorm->tables.size() == 1) + { + TypeId table = *targetNorm->tables.begin(); + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, table, "__metatable", location); + + // if the `__metatable` metamethod is present, then the table is locked and we cannot `setmetatable` on it. + if (metatableMetamethod) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + TypeId withMetatable = ctx->arena->addType(MetatableType{table, metatableTy}); + + return {withMetatable, Reduction::MaybeOk, {}, {}}; + } + + TypeId result = ctx->builtins->neverType; + + for (auto componentTy : targetNorm->tables) + { + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, componentTy, "__metatable", location); + + // if the `__metatable` metamethod is present, then the table is locked and we cannot `setmetatable` on it. + if (metatableMetamethod) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + TypeId withMetatable = ctx->arena->addType(MetatableType{componentTy, metatableTy}); + SimplifyResult simplified = simplifyUnion(ctx->builtins, ctx->arena, result, withMetatable); + + if (!simplified.blockedTypes.empty()) + { + std::vector blockedTypes{}; + blockedTypes.reserve(simplified.blockedTypes.size()); + for (auto ty : simplified.blockedTypes) + blockedTypes.push_back(ty); + return {std::nullopt, Reduction::MaybeOk, blockedTypes, {}}; + } + + result = simplified.result; + } + + return {result, Reduction::MaybeOk, {}, {}}; +} + +static TypeFunctionReductionResult getmetatableHelper(TypeId targetTy, const Location& location, NotNull ctx) +{ + targetTy = follow(targetTy); + + std::optional metatable = std::nullopt; + bool erroneous = true; + + if (auto table = get(targetTy)) + erroneous = false; + + if (auto mt = get(targetTy)) + { + metatable = mt->metatable; + erroneous = false; + } + + if (auto clazz = get(targetTy)) + { + metatable = clazz->metatable; + erroneous = false; + } + + if (auto primitive = get(targetTy)) + { + metatable = primitive->metatable; + erroneous = false; + } + + if (auto singleton = get(targetTy)) + { + if (get(singleton)) + { + auto primitiveString = get(ctx->builtins->stringType); + metatable = primitiveString->metatable; + } + erroneous = false; + } + + if (erroneous) + return {std::nullopt, Reduction::Erroneous, {}, {}}; + + // findMetatableEntry demands the ability to emit errors, so we must give it + // the necessary state to do that, even if we intend to just eat the errors. + ErrorVec dummy; + + std::optional metatableMetamethod = findMetatableEntry(ctx->builtins, dummy, targetTy, "__metatable", location); + + if (metatableMetamethod) + return {metatableMetamethod, Reduction::MaybeOk, {}, {}}; + + if (metatable) + return {metatable, Reduction::MaybeOk, {}, {}}; + + return {ctx->builtins->nilType, Reduction::MaybeOk, {}, {}}; +} + +TypeFunctionReductionResult getmetatableTypeFunction( + TypeId instance, + const std::vector& typeParams, + const std::vector& packParams, + NotNull ctx +) +{ + if (typeParams.size() != 1 || !packParams.empty()) + { + ctx->ice->ice("getmetatable type function: encountered a type function instance without the required argument structure"); + LUAU_ASSERT(false); + } + + const Location location = ctx->constraint ? ctx->constraint->location : Location{}; + + TypeId targetTy = follow(typeParams.at(0)); + + if (isPending(targetTy, ctx->solver)) + return {std::nullopt, Reduction::MaybeOk, {targetTy}, {}}; + + if (auto ut = get(targetTy)) + { + std::vector options{}; + options.reserve(ut->options.size()); + + for (auto option : ut->options) + { + TypeFunctionReductionResult result = getmetatableHelper(option, location, ctx); + + if (!result.result) + return result; + + options.push_back(*result.result); + } + + return {ctx->arena->addType(UnionType{std::move(options)}), Reduction::MaybeOk, {}, {}}; + } + + if (auto it = get(targetTy)) + { + std::vector parts{}; + parts.reserve(it->parts.size()); + + for (auto part : it->parts) + { + TypeFunctionReductionResult result = getmetatableHelper(part, location, ctx); + + if (!result.result) + return result; + + parts.push_back(*result.result); + } + + return {ctx->arena->addType(IntersectionType{std::move(parts)}), Reduction::MaybeOk, {}, {}}; + } + + return getmetatableHelper(targetTy, location, ctx); +} + + BuiltinTypeFunctions::BuiltinTypeFunctions() : userFunc{"user", userDefinedTypeFunction} , notFunc{"not", notTypeFunction} @@ -2372,6 +3044,8 @@ BuiltinTypeFunctions::BuiltinTypeFunctions() , rawkeyofFunc{"rawkeyof", rawkeyofTypeFunction} , indexFunc{"index", indexTypeFunction} , rawgetFunc{"rawget", rawgetTypeFunction} + , setmetatableFunc{"setmetatable", setmetatableTypeFunction} + , getmetatableFunc{"getmetatable", getmetatableTypeFunction} { } @@ -2418,6 +3092,12 @@ void BuiltinTypeFunctions::addToScope(NotNull arena, NotNull s scope->exportedTypeBindings[indexFunc.name] = mkBinaryTypeFunction(&indexFunc); scope->exportedTypeBindings[rawgetFunc.name] = mkBinaryTypeFunction(&rawgetFunc); + + if (FFlag::LuauMetatableTypeFunctions) + { + scope->exportedTypeBindings[setmetatableFunc.name] = mkBinaryTypeFunction(&setmetatableFunc); + scope->exportedTypeBindings[getmetatableFunc.name] = mkUnaryTypeFunction(&getmetatableFunc); + } } const BuiltinTypeFunctions& builtinTypeFunctions() diff --git a/Analysis/src/TypeFunctionReductionGuesser.cpp b/Analysis/src/TypeFunctionReductionGuesser.cpp index d4a7c7c0..389a797d 100644 --- a/Analysis/src/TypeFunctionReductionGuesser.cpp +++ b/Analysis/src/TypeFunctionReductionGuesser.cpp @@ -3,6 +3,7 @@ #include "Luau/DenseHash.h" #include "Luau/Normalize.h" +#include "Luau/ToString.h" #include "Luau/TypeFunction.h" #include "Luau/Type.h" #include "Luau/TypePack.h" diff --git a/Analysis/src/TypeFunctionRuntime.cpp b/Analysis/src/TypeFunctionRuntime.cpp new file mode 100644 index 00000000..fb33560e --- /dev/null +++ b/Analysis/src/TypeFunctionRuntime.cpp @@ -0,0 +1,2482 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntime.h" + +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/TypeFunction.h" + +#include "lua.h" +#include "lualib.h" + +#include +#include +#include + +LUAU_FASTFLAGVARIABLE(LuauTypeFunFixHydratedClasses) +LUAU_DYNAMIC_FASTINT(LuauTypeFunctionSerdeIterationLimit) +LUAU_FASTFLAGVARIABLE(LuauTypeFunSingletonEquality) +LUAU_FASTFLAGVARIABLE(LuauUserTypeFunTypeofReturnsType) +LUAU_FASTFLAGVARIABLE(LuauTypeFunPrintFix) +LUAU_FASTFLAGVARIABLE(LuauTypeFunReadWriteParents) + +namespace Luau +{ + +constexpr int kTypeUserdataTag = 42; + +void* typeFunctionAlloc(void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + ::operator delete(ptr); + return nullptr; + } + else if (osize == 0) + { + return ::operator new(nsize); + } + else + { + void* data = ::operator new(nsize); + memcpy(data, ptr, nsize < osize ? nsize : osize); + + ::operator delete(ptr); + + return data; + } +} + +std::optional checkResultForError(lua_State* L, const char* typeFunctionName, int luaResult) +{ + switch (luaResult) + { + case LUA_OK: + return std::nullopt; + case LUA_YIELD: + case LUA_BREAK: + return format("'%s' type function errored: unexpected yield or break", typeFunctionName); + default: + if (!lua_gettop(L)) + return format("'%s' type function errored unexpectedly", typeFunctionName); + + if (lua_isstring(L, -1)) + return format("'%s' type function errored at runtime: %s", typeFunctionName, lua_tostring(L, -1)); + + return format("'%s' type function errored at runtime: raised an error of type %s", typeFunctionName, lua_typename(L, -1)); + } +} + +static TypeFunctionRuntime* getTypeFunctionRuntime(lua_State* L) +{ + return static_cast(lua_getthreaddata(lua_mainthread(L))); +} + +TypeFunctionType* allocateTypeFunctionType(lua_State* L, TypeFunctionTypeVariant type) +{ + auto ctx = getTypeFunctionRuntime(L); + return ctx->typeArena.allocate(std::move(type)); +} + +TypeFunctionTypePackVar* allocateTypeFunctionTypePack(lua_State* L, TypeFunctionTypePackVariant type) +{ + auto ctx = getTypeFunctionRuntime(L); + return ctx->typePackArena.allocate(std::move(type)); +} + +// Pushes a new type userdata onto the stack +void allocTypeUserData(lua_State* L, TypeFunctionTypeVariant type) +{ + // allocate a new type userdata + TypeFunctionTypeId* ptr = static_cast(lua_newuserdatatagged(L, sizeof(TypeFunctionTypeId), kTypeUserdataTag)); + *ptr = allocateTypeFunctionType(L, std::move(type)); + + // set the new userdata's metatable to type metatable + luaL_getmetatable(L, "type"); + lua_setmetatable(L, -2); +} + +void deallocTypeUserData(lua_State* L, void* data) +{ + // only non-owning pointers into an arena is stored +} + +bool isTypeUserData(lua_State* L, int idx) +{ + if (!lua_isuserdata(L, idx)) + return false; + + return lua_touserdatatagged(L, idx, kTypeUserdataTag) != nullptr; +} + +TypeFunctionTypeId getTypeUserData(lua_State* L, int idx) +{ + if (auto typ = static_cast(lua_touserdatatagged(L, idx, kTypeUserdataTag))) + return *typ; + + luaL_typeerrorL(L, idx, "type"); +} + +std::optional optionalTypeUserData(lua_State* L, int idx) +{ + if (lua_isnoneornil(L, idx)) + return std::nullopt; + else + return getTypeUserData(L, idx); +} + +// returns a string tag of TypeFunctionTypeId +static std::string getTag(lua_State* L, TypeFunctionTypeId ty) +{ + if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::NilType) + return "nil"; + else if (auto b = get(ty); b && b->type == TypeFunctionPrimitiveType::Type::Boolean) + return "boolean"; + else if (auto n = get(ty); n && n->type == TypeFunctionPrimitiveType::Type::Number) + return "number"; + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::String) + return "string"; + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::Thread) + return "thread"; + else if (auto s = get(ty); s && s->type == TypeFunctionPrimitiveType::Type::Buffer) + return "buffer"; + else if (get(ty)) + return "unknown"; + else if (get(ty)) + return "never"; + else if (get(ty)) + return "any"; + else if (auto s = get(ty)) + return "singleton"; + else if (get(ty)) + return "negation"; + else if (get(ty)) + return "union"; + else if (get(ty)) + return "intersection"; + else if (get(ty)) + return "table"; + else if (get(ty)) + return "function"; + else if (get(ty)) + return "class"; + else if (get(ty)) + return "generic"; + + LUAU_UNREACHABLE(); + luaL_error(L, "VM encountered unexpected type variant when determining tag"); +} + +// Luau: `type.unknown` +// Returns the type instance representing unknown +static int createUnknown(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionUnknownType{}); + + return 1; +} + +// Luau: `type.never` +// Returns the type instance representing never +static int createNever(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionNeverType{}); + + return 1; +} + +// Luau: `type.any` +// Returns the type instance representing any +static int createAny(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionAnyType{}); + + return 1; +} + +// Luau: `type.boolean` +// Returns the type instance representing boolean +static int createBoolean(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Boolean}); + + return 1; +} + +// Luau: `type.number` +// Returns the type instance representing number +static int createNumber(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Number}); + + return 1; +} + +// Luau: `type.string` +// Returns the type instance representing string +static int createString(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::String}); + + return 1; +} + +// Luau: `type.thread` +static int createThread(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Thread}); + + return 1; +} + +// Luau: `type.buffer` +static int createBuffer(lua_State* L) +{ + allocTypeUserData(L, TypeFunctionPrimitiveType{TypeFunctionPrimitiveType::Buffer}); + + return 1; +} + +// Luau: `type.singleton(value: string | boolean | nil) -> type` +// Returns the type instance representing string or boolean singleton or nil +static int createSingleton(lua_State* L) +{ + if (lua_isboolean(L, 1)) // Create boolean singleton + { + bool value = luaL_checkboolean(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionBooleanSingleton{value}}); + + return 1; + } + + // n.b. we cannot use lua_isstring here because lua committed the cardinal sin of calling a number a string + if (lua_type(L, 1) == LUA_TSTRING) // Create string singleton + { + const char* value = luaL_checkstring(L, 1); + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{value}}); + + return 1; + } + + if (lua_isnil(L, 1)) + { + allocTypeUserData(L, TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + + return 1; + } + + luaL_error(L, "types.singleton: can't create singleton from `%s` type", lua_typename(L, 1)); +} + +// Luau: `types.generic(name: string, ispack: boolean?) -> type +// Create a generic type with the specified type. If an optinal boolean is set to true, result is a generic pack +static int createGeneric(lua_State* L) +{ + const char* name = luaL_checkstring(L, 1); + bool isPack = luaL_optboolean(L, 2, false); + + if (strlen(name) == 0) + luaL_error(L, "types.generic: generic name cannot be empty"); + + allocTypeUserData(L, TypeFunctionGenericType{/* isNamed */ true, isPack, name}); + return 1; +} + +// Luau: `self:value() -> type` +// Returns the value of a singleton +static int getSingletonValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.value: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfpt = get(self)) + { + if (tfpt->type != TypeFunctionPrimitiveType::NilType) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + lua_pushnil(L); + return 1; + } + + auto tfst = get(self); + if (!tfst) + luaL_error(L, "type.value: expected self to be a singleton, but got %s instead", getTag(L, self).c_str()); + + if (auto tfbst = get(tfst)) + { + lua_pushboolean(L, tfbst->value); + return 1; + } + + if (auto tfsst = get(tfst)) + { + lua_pushlstring(L, tfsst->value.c_str(), tfsst->value.length()); + return 1; + } + + luaL_error(L, "type.value: can't call `value` method on `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.unionof(...: type) -> type` +// Returns the type instance representing union +static int createUnion(lua_State* L) +{ + // get the number of arguments for union + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.unionof: expected at least 2 types to union, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionUnionType{components}); + + return 1; +} + +// Luau: `types.intersectionof(...: type) -> type` +// Returns the type instance representing intersection +static int createIntersection(lua_State* L) +{ + // get the number of arguments for intersection + int argSize = lua_gettop(L); + if (argSize < 2) + luaL_error(L, "types.intersectionof: expected at least 2 types to intersection, but got %d", argSize); + + std::vector components; + components.reserve(argSize); + + for (int i = 1; i <= argSize; i++) + components.push_back(getTypeUserData(L, i)); + + allocTypeUserData(L, TypeFunctionIntersectionType{components}); + + return 1; +} + +// Luau: `self:components() -> {type}` +// Returns the components of union or intersection +static int getComponents(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.components: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfut = get(self); + if (tfut) + { + int argSize = int(tfut->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfut->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + auto tfit = get(self); + if (tfit) + { + int argSize = int(tfit->components.size()); + + lua_createtable(L, argSize, 0); + for (int i = 0; i < argSize; i++) + { + TypeFunctionTypeId component = tfit->components[i]; + allocTypeUserData(L, component->type); + lua_rawseti(L, -2, i + 1); // Luau is 1-indexed while C++ is 0-indexed + } + + return 1; + } + + luaL_error(L, "type.components: cannot call components of `%s` type", getTag(L, self).c_str()); +} + +// Luau: `types.negationof(arg: type) -> type` +// Returns the type instance representing negation +static int createNegation(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.negationof: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + if (get(arg) || get(arg)) + luaL_error(L, "types.negationof: cannot perform negation on `%s` type", getTag(L, arg).c_str()); + + allocTypeUserData(L, TypeFunctionNegationType{arg}); + + return 1; +} + +// Luau: `self:inner() -> type` +// Returns the type instance being negated +static int getNegatedValue(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.inner: expected 1 argument, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + + if (auto tfnt = get(self); tfnt) + allocTypeUserData(L, tfnt->type->type); + else + luaL_error(L, "type.inner: cannot call inner method on non-negation type: `%s` type", getTag(L, self).c_str()); + + return 1; +} + +// Luau: `types.newtable(props: {[type]: type | { read: type, write: type }}?, indexer: {index: type, readresult: type, writeresult: type}?, +// metatable: type?) -> type` Returns the type instance representing table +static int createTable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "types.newtable: expected 0-3 arguments, but got %d", argumentCount); + + // Parse prop + TypeFunctionTableType::Props props{}; + if (lua_istable(L, 1)) + { + lua_pushnil(L); + while (lua_next(L, 1) != 0) + { + TypeFunctionTypeId key = getTypeUserData(L, -2); + + auto tfst = get(key); + if (!tfst) + luaL_error(L, "types.newtable: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "types.newtable: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (lua_istable(L, -1)) + { + lua_getfield(L, -1, "read"); + std::optional readTy; + if (!lua_isnil(L, -1)) + readTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, -1, "write"); + std::optional writeTy; + if (!lua_isnil(L, -1)) + writeTy = getTypeUserData(L, -1); + lua_pop(L, 1); + + props[tfsst->value] = TypeFunctionProperty{readTy, writeTy}; + } + else + { + TypeFunctionTypeId value = getTypeUserData(L, -1); + props[tfsst->value] = TypeFunctionProperty::rw(value); + } + + lua_pop(L, 1); + } + } + else if (!lua_isnoneornil(L, 1)) + luaL_typeerrorL(L, 1, "table"); + + // Parse indexer + std::optional indexer; + if (lua_istable(L, 2)) + { + // Parse keyType and valueType + lua_getfield(L, 2, "index"); + TypeFunctionTypeId keyType = getTypeUserData(L, -1); + lua_pop(L, 1); + + lua_getfield(L, 2, "readresult"); + TypeFunctionTypeId valueType = getTypeUserData(L, -1); + lua_pop(L, 1); + + indexer = TypeFunctionTableIndexer(keyType, valueType); + } + else if (!lua_isnoneornil(L, 2)) + luaL_typeerrorL(L, 2, "table"); + + // Parse metatable + std::optional metatable = optionalTypeUserData(L, 3); + if (metatable && !get(*metatable)) + luaL_error(L, "types.newtable: expected to be given a table type as a metatable, but got %s instead", getTag(L, *metatable).c_str()); + + allocTypeUserData(L, TypeFunctionTableType{props, indexer, metatable}); + return 1; +} + +// Luau: `self:setproperty(key: type, value: type?)` +// Sets the properties of a table +static int setTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + tftt->props.erase(tfsst->value); + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + tftt->props[tfsst->value] = TypeFunctionProperty::rw(value, value); + + return 0; +} + +// Luau: `self:setreadproperty(key: type, value: type?)` +// Sets the properties of a table +static int setReadTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreadproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setreadproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setreadproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setreadproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's read-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isReadOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the read type. + else if (iter != tftt->props.end()) + iter->second.readTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::readonly(value); + else + iter->second.readTy = value; + + return 0; +} + +// Luau: `self:setwriteproperty(key: type, value: type?)` +// Sets the properties of a table +static int setWriteTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setwriteproperty: expected 2-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setwriteproperty: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.setwriteproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.setwriteproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + auto iter = tftt->props.find(tfsst->value); + + if (argumentCount == 2 || lua_isnil(L, 3)) + { + // if it's write-only, remove it altogether + if (iter != tftt->props.end() && iter->second.isWriteOnly()) + tftt->props.erase(tfsst->value); + // but if it's not, just null out the write type. + else if (iter != tftt->props.end()) + iter->second.writeTy = std::nullopt; + + return 0; + } + + TypeFunctionTypeId value = getTypeUserData(L, 3); + if (iter == tftt->props.end()) + tftt->props[tfsst->value] = TypeFunctionProperty::writeonly(value); + else + iter->second.writeTy = value; + + return 0; +} + +// Luau: `self:readproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int readTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.readproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.readproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.readproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.readproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.readTy) + allocTypeUserData(L, (*prop.readTy)->type); + else + lua_pushnil(L); + + return 1; +} +// +// Luau: `self:writeproperty(key: type) -> type` +// Returns the property of a table associated with the key +static int writeTableProp(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.writeproperty: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = get(self); + if (!tftt) + luaL_error(L, "type.writeproperty: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + auto tfst = get(key); + if (!tfst) + luaL_error(L, "type.writeproperty: expected to be given a singleton type, but got %s instead", getTag(L, key).c_str()); + + auto tfsst = get(tfst); + if (!tfsst) + luaL_error(L, "type.writeproperty: expected to be given a string singleton type, but got %s instead", getTag(L, key).c_str()); + + // Check if key is a valid prop + if (tftt->props.find(tfsst->value) == tftt->props.end()) + { + lua_pushnil(L); + return 1; + } + + auto prop = tftt->props.at(tfsst->value); + if (prop.writeTy) + allocTypeUserData(L, (*prop.writeTy)->type); + else + lua_pushnil(L); + + return 1; +} + +// Luau: `self:setindexer(key: type, value: type)` +// Sets the indexer of the table, if the key type is `never`, the indexer is removed +static int setTableIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 3) + luaL_error(L, "type.setindexer: expected 3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setindexer: expected self to be either a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId key = getTypeUserData(L, 2); + TypeFunctionTypeId value = getTypeUserData(L, 3); + + if (auto tfnt = get(key)) + { + tftt->indexer = std::nullopt; + return 0; + } + + tftt->indexer = TypeFunctionTableIndexer{key, value}; + return 0; +} + +// Luau: `self:setreadindexer(key: type, value: type)` +// Sets the read indexer of the table +static int setTableReadIndexer(lua_State* L) +{ + luaL_error(L, "type.setreadindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setwriteindexer(key: type, value: type)` +// Sets the write indexer of the table +static int setTableWriteIndexer(lua_State* L) +{ + luaL_error(L, "type.setwriteindexer: luau does not yet support separate read/write types for indexers."); +} + +// Luau: `self:setmetatable(arg: type)` +// Sets the metatable of the table +static int setTableMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.setmetatable: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + + auto tftt = getMutable(self); + if (!tftt) + luaL_error(L, "type.setmetatable: expected self to be a table, but got %s instead", getTag(L, self).c_str()); + + TypeFunctionTypeId arg = getTypeUserData(L, 2); + if (!get(arg)) + luaL_error(L, "type.setmetatable: expected the argument to be a table, but got %s instead", getTag(L, self).c_str()); + + tftt->metatable = arg; + + return 0; +} + +static std::tuple, std::vector> getGenerics(lua_State* L, int idx, const char* fname) +{ + std::vector types; + std::vector packs; + + if (lua_istable(L, idx)) + { + lua_pushvalue(L, idx); + + for (int i = 1; i <= lua_objlen(L, -1); i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + TypeFunctionTypeId ty = getTypeUserData(L, -1); + + if (auto gty = get(ty)) + { + if (gty->isPack) + { + packs.push_back(allocateTypeFunctionTypePack(L, TypeFunctionGenericTypePack{gty->isNamed, gty->name})); + } + else + { + if (!packs.empty()) + luaL_error(L, "%s: generic type cannot follow a generic pack", fname); + + types.push_back(ty); + } + } + else + { + luaL_error(L, "%s: table member was not a generic type", fname); + } + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + else if (!lua_isnoneornil(L, idx)) + { + luaL_typeerrorL(L, idx, "table"); + } + + return {types, packs}; +} + +static TypeFunctionTypePackId getTypePack(lua_State* L, int headIdx, int tailIdx) +{ + TypeFunctionTypePackId result = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + + std::vector head; + + if (lua_istable(L, headIdx)) + { + lua_pushvalue(L, headIdx); + + for (int i = 1; i <= lua_objlen(L, -1); i++) + { + lua_pushinteger(L, i); + lua_gettable(L, -2); + + if (lua_isnil(L, -1)) + { + lua_pop(L, 1); + break; + } + + head.push_back(getTypeUserData(L, -1)); + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + std::optional tail; + + if (auto type = optionalTypeUserData(L, tailIdx)) + { + if (auto gty = get(*type); gty && gty->isPack) + tail = allocateTypeFunctionTypePack(L, TypeFunctionGenericTypePack{gty->isNamed, gty->name}); + else + tail = allocateTypeFunctionTypePack(L, TypeFunctionVariadicTypePack{*type}); + } + + if (head.size() == 0 && tail.has_value()) + result = *tail; + else + result = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{head, tail}); + + return result; +} + +static void pushTypePack(lua_State* L, TypeFunctionTypePackId tp) +{ + if (auto tftp = get(tp)) + { + lua_createtable(L, 0, 2); + + if (!tftp->head.empty()) + { + lua_createtable(L, int(tftp->head.size()), 0); + int pos = 1; + + for (auto el : tftp->head) + { + allocTypeUserData(L, el->type); + lua_rawseti(L, -2, pos++); + } + + lua_setfield(L, -2, "head"); + } + + if (tftp->tail.has_value()) + { + if (auto tfvp = get(*tftp->tail)) + allocTypeUserData(L, tfvp->type->type); + else if (auto tfgp = get(*tftp->tail)) + allocTypeUserData(L, TypeFunctionGenericType{tfgp->isNamed, true, tfgp->name}); + else + luaL_error(L, "unsupported type pack type"); + + lua_setfield(L, -2, "tail"); + } + } + else if (auto tfvp = get(tp)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, tfvp->type->type); + lua_setfield(L, -2, "tail"); + } + else if (auto tfgp = get(tp)) + { + lua_createtable(L, 0, 1); + + allocTypeUserData(L, TypeFunctionGenericType{tfgp->isNamed, true, tfgp->name}); + lua_setfield(L, -2, "tail"); + } + else + { + luaL_error(L, "unsupported type pack type"); + } +} + +// Luau: `types.newfunction(parameters: {head: {type}?, tail: type?}, returns: {head: {type}?, tail: type?}, generics: {type}?) -> type` +// Returns the type instance representing a function +static int createFunction(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "types.newfunction: expected 0-3 arguments, but got %d", argumentCount); + + TypeFunctionTypePackId argTypes = nullptr; + + if (lua_istable(L, 1)) + { + lua_getfield(L, 1, "head"); + lua_getfield(L, 1, "tail"); + + argTypes = getTypePack(L, -2, -1); + + lua_pop(L, 2); + } + else if (!lua_isnoneornil(L, 1)) + { + luaL_typeerrorL(L, 1, "table"); + } + else + { + argTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + } + + TypeFunctionTypePackId retTypes = nullptr; + + if (lua_istable(L, 2)) + { + lua_getfield(L, 2, "head"); + lua_getfield(L, 2, "tail"); + + retTypes = getTypePack(L, -2, -1); + + lua_pop(L, 2); + } + else if (!lua_isnoneornil(L, 2)) + { + luaL_typeerrorL(L, 2, "table"); + } + else + { + retTypes = allocateTypeFunctionTypePack(L, TypeFunctionTypePack{}); + } + + auto [genericTypes, genericPacks] = getGenerics(L, 3, "types.newfunction"); + + allocTypeUserData(L, TypeFunctionFunctionType{std::move(genericTypes), std::move(genericPacks), argTypes, retTypes}); + + return 1; +} + +// Luau: `self:setparameters(head: {type}?, tail: type?)` +// Sets the parameters of the function +static int setFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount > 3 || argumentCount < 1) + luaL_error(L, "type.setparameters: expected 1-3, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setparameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + tfft->argTypes = getTypePack(L, 2, 3); + + return 0; +} + +// Luau: `self:parameters() -> {head: {type}?, tail: type?}` +// Returns the parameters of the function +static int getFunctionParameters(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parameters: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.parameters: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + pushTypePack(L, tfft->argTypes); + + return 1; +} + +// Luau: `self:setreturns(head: {type}?, tail: type?)` +// Sets the returns of the function +static int setFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount < 2 || argumentCount > 3) + luaL_error(L, "type.setreturns: expected 1-3 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setreturns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + tfft->retTypes = getTypePack(L, 2, 3); + + return 0; +} + +// Luau: `self:returns() -> {head: {type}?, tail: type?}` +// Returns the returns of the function +static int getFunctionReturns(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.returns: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.returns: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + pushTypePack(L, tfft->retTypes); + + return 1; +} + +// Luau: `self:setgenerics(generics: {type}?)` +static int setFunctionGenerics(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = getMutable(self); + if (!tfft) + luaL_error(L, "type.setgenerics: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + int argumentCount = lua_gettop(L); + if (argumentCount > 3) + luaL_error(L, "type.setgenerics: expected 3 arguments, but got %d", argumentCount); + + auto [genericTypes, genericPacks] = getGenerics(L, 2, "types.setgenerics"); + + tfft->generics = std::move(genericTypes); + tfft->genericPacks = std::move(genericPacks); + + return 0; +} + +// Luau: `self:generics() -> {type}` +static int getFunctionGenerics(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfft = get(self); + if (!tfft) + luaL_error(L, "type.generics: expected self to be a function, but got %s instead", getTag(L, self).c_str()); + + lua_createtable(L, int(tfft->generics.size()) + int(tfft->genericPacks.size()), 0); + + int pos = 1; + + for (const auto& el : tfft->generics) + { + allocTypeUserData(L, el->type); + lua_rawseti(L, -2, pos++); + } + + for (const auto& el : tfft->genericPacks) + { + auto gty = get(el); + LUAU_ASSERT(gty); + allocTypeUserData(L, TypeFunctionGenericType{gty->isNamed, true, gty->name}); + lua_rawseti(L, -2, pos++); + } + + return 1; +} + +// Luau: `self:parent() -> type` +// Returns the parent of a class type +static int getClassParent_DEPRECATED(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfct = get(self); + if (!tfct) + luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); + + // If the parent does not exist, we should return nil + if (!tfct->parent_DEPRECATED) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->parent_DEPRECATED)->type); + + return 1; +} + +// Luau: `self:readparent() -> type` +// Returns the read type of the class' parent +static int getReadParent(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfct = get(self); + if (!tfct) + luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); + + // If the parent does not exist, we should return nil + if (!tfct->readParent) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->readParent)->type); + + return 1; +} +// +// Luau: `self:writeparent() -> type` +// Returns the write type of the class' parent +static int getWriteParent(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.parent: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfct = get(self); + if (!tfct) + luaL_error(L, "type.parent: expected self to be a class, but got %s instead", getTag(L, self).c_str()); + + // If the parent does not exist, we should return nil + if (!tfct->writeParent) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->writeParent)->type); + + return 1; +} + +// Luau: `self:name() -> string?` +// Returns the name of the generic or 'nil' if the generic is unnamed +static int getGenericName(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfgt = get(self); + if (!tfgt) + luaL_error(L, "type.name: expected self to be a generic, but got %s instead", getTag(L, self).c_str()); + + if (tfgt->isNamed) + lua_pushstring(L, tfgt->name.c_str()); + else + lua_pushnil(L); + + return 1; +} + +// Luau: `self:ispack() -> boolean` +// Returns true if the generic is a pack +static int getGenericIsPack(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + auto tfgt = get(self); + if (!tfgt) + luaL_error(L, "type.ispack: expected self to be a generic, but got %s instead", getTag(L, self).c_str()); + + lua_pushboolean(L, tfgt->isPack); + return 1; +} + +// Luau: `self:properties() -> {[type]: { read: type?, write: type? }}` +// Returns the properties of a table or class type +static int getProps(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.properties: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + lua_createtable(L, int(tftt->props.size()), 0); + for (auto& [name, prop] : tftt->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + if (auto tfct = get(self)) + { + lua_createtable(L, int(tfct->props.size()), 0); + for (auto& [name, prop] : tfct->props) + { + allocTypeUserData(L, TypeFunctionSingletonType{TypeFunctionStringSingleton{name}}); + + int size = 0; + if (prop.readTy) + size++; + if (prop.writeTy) + size++; + + lua_createtable(L, 0, size); + if (prop.readTy) + { + allocTypeUserData(L, (*prop.readTy)->type); + lua_setfield(L, -2, "read"); + } + + if (prop.writeTy) + { + allocTypeUserData(L, (*prop.writeTy)->type); + lua_setfield(L, -2, "write"); + } + + lua_settable(L, -3); + } + + return 1; + } + + luaL_error(L, "type.properties: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:indexer() -> {index: type, readresult: type, writeresult: type}?` +// Returns the indexer of a table or class type +static int getIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.indexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 3); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "readresult"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "writeresult"); + } + + return 1; + } + + luaL_error(L, "type.indexer: self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:readindexer() -> {index: type, result: type}?` +// Returns the read indexer of a table or class type +static int getReadIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.readindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.readindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:writeindexer() -> {index: type, result: type}?` +// Returns the write indexer of a table or class type +static int getWriteIndexer(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.writeindexer: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tftt = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tftt->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tftt->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tftt->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + if (auto tfct = get(self)) + { + // if the indexer does not exist, we should return nil + if (!tfct->indexer.has_value()) + lua_pushnil(L); + else + { + lua_createtable(L, 0, 2); + allocTypeUserData(L, tfct->indexer->keyType->type); + lua_setfield(L, -2, "index"); + allocTypeUserData(L, tfct->indexer->valueType->type); + lua_setfield(L, -2, "result"); + } + + return 1; + } + + luaL_error(L, "type.writeindexer: expected self to be either a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:metatable() -> type?` +// Returns the metatable of a table or class type +static int getMetatable(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "type.metatable: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + if (auto tfmt = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfmt->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfmt->metatable)->type); + + return 1; + } + + if (auto tfct = get(self)) + { + // if the metatable does not exist, we should return nil + if (!tfct->metatable.has_value()) + lua_pushnil(L); + else + allocTypeUserData(L, (*tfct->metatable)->type); + + return 1; + } + + luaL_error(L, "type.metatable: expected self to be a table or class, but got %s instead", getTag(L, self).c_str()); +} + +// Luau: `self:is(arg: string) -> boolean` +// Returns true if given argument is a tag of self +static int checkTag(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "type.is: expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + std::string arg = luaL_checkstring(L, 2); + + lua_pushboolean(L, getTag(L, self) == arg); + return 1; +} + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty); // Forward declaration + +// Luau: `types.copy(arg: type) -> type` +// Returns a deep copy of the argument +static int deepCopy(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 1) + luaL_error(L, "types.copy: expected 1 arguments, but got %d", argumentCount); + + TypeFunctionTypeId arg = getTypeUserData(L, 1); + + TypeFunctionTypeId copy = deepClone(NotNull{getTypeFunctionRuntime(L)}, arg); + allocTypeUserData(L, copy->type); + return 1; +} + +// Luau: `self == arg -> boolean` +// Used to set the __eq metamethod +static int isEqualToType(lua_State* L) +{ + int argumentCount = lua_gettop(L); + if (argumentCount != 2) + luaL_error(L, "expected 2 arguments, but got %d", argumentCount); + + TypeFunctionTypeId self = getTypeUserData(L, 1); + TypeFunctionTypeId arg = getTypeUserData(L, 2); + + lua_pushboolean(L, *self == *arg); + return 1; +} + +void registerTypesLibrary(lua_State* L) +{ + luaL_Reg fields[] = { + {"unknown", createUnknown}, + {"never", createNever}, + {"any", createAny}, + {"boolean", createBoolean}, + {"number", createNumber}, + {"string", createString}, + {"thread", createThread}, + {"buffer", createBuffer}, + {nullptr, nullptr} + }; + + luaL_Reg methods[] = { + {"singleton", createSingleton}, + {"negationof", createNegation}, + {"unionof", createUnion}, + {"intersectionof", createIntersection}, + {"newtable", createTable}, + {"newfunction", createFunction}, + {"copy", deepCopy}, + {"generic", createGeneric}, + + {nullptr, nullptr} + }; + + luaL_register(L, "types", methods); + + // Set fields for type userdata + for (luaL_Reg* l = fields; l->name; l++) + { + l->func(L); + lua_setfield(L, -2, l->name); + } + + lua_pop(L, 1); +} + +static int typeUserdataIndex(lua_State* L) +{ + TypeFunctionTypeId self = getTypeUserData(L, 1); + const char* field = luaL_checkstring(L, 2); + + if (strcmp(field, "tag") == 0) + { + lua_pushstring(L, getTag(L, self).c_str()); + return 1; + } + + lua_pushvalue(L, lua_upvalueindex(1)); + lua_getfield(L, -1, field); + return 1; +} + +void registerTypeUserData(lua_State* L) +{ + luaL_Reg typeUserdataMethods[] = { + {"is", checkTag}, + + // Negation type methods + {"inner", getNegatedValue}, + + // Singleton type methods + {"value", getSingletonValue}, + + // Table type methods + {"setproperty", setTableProp}, + {"setreadproperty", setReadTableProp}, + {"setwriteproperty", setWriteTableProp}, + {"readproperty", readTableProp}, + {"writeproperty", writeTableProp}, + {"properties", getProps}, + {"setindexer", setTableIndexer}, + {"setreadindexer", setTableReadIndexer}, + {"setwriteindexer", setTableWriteIndexer}, + {"indexer", getIndexer}, + {"readindexer", getReadIndexer}, + {"writeindexer", getWriteIndexer}, + {"setmetatable", setTableMetatable}, + {"metatable", getMetatable}, + + // Function type methods + {"setparameters", setFunctionParameters}, + {"parameters", getFunctionParameters}, + {"setreturns", setFunctionReturns}, + {"returns", getFunctionReturns}, + {"setgenerics", setFunctionGenerics}, + {"generics", getFunctionGenerics}, + + // Union and Intersection type methods + {"components", getComponents}, + + // Class type methods + {FFlag::LuauTypeFunReadWriteParents ? "readparent" : "parent", FFlag::LuauTypeFunReadWriteParents ? getReadParent : getClassParent_DEPRECATED}, + + // Function type methods (cont.) + {"setgenerics", setFunctionGenerics}, + {"generics", getFunctionGenerics}, + + // Generic type methods + {"name", getGenericName}, + {"ispack", getGenericIsPack}, + + // move this under Class type methods when removing FFlagLuauTypeFunReadWriteParents + {FFlag::LuauTypeFunReadWriteParents ? "writeparent" : nullptr, FFlag::LuauTypeFunReadWriteParents ? getWriteParent : nullptr}, + + {nullptr, nullptr} + }; + + // Create and register metatable for type userdata + luaL_newmetatable(L, "type"); + + if (FFlag::LuauUserTypeFunTypeofReturnsType) + { + lua_pushstring(L, "type"); + lua_setfield(L, -2, "__type"); + } + + // Protect metatable from being changed + lua_pushstring(L, "The metatable is locked"); + lua_setfield(L, -2, "__metatable"); + + lua_pushcfunction(L, isEqualToType, "__eq"); + lua_setfield(L, -2, "__eq"); + + // Indexing will be a dynamic function because some type fields are dynamic + lua_newtable(L); + luaL_register(L, nullptr, typeUserdataMethods); + lua_setreadonly(L, -1, true); + lua_pushcclosure(L, typeUserdataIndex, "__index", 1); + lua_setfield(L, -2, "__index"); + + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + + // Sets up a destructor for the type userdata. + lua_setuserdatadtor(L, kTypeUserdataTag, deallocTypeUserData); +} + +// Used to redirect all the removed global functions to say "this function is unsupported" +static int unsupportedFunction(lua_State* L) +{ + luaL_errorL(L, "this function is not supported in type functions"); + return 0; +} + +static int print(lua_State* L) +{ + std::string result; + + int n = lua_gettop(L); + for (int i = 1; i <= n; i++) + { + size_t l = 0; + const char* s = luaL_tolstring(L, i, &l); // convert to string using __tostring et al + if (i > 1) + { + if (FFlag::LuauTypeFunPrintFix) + result.append(1, '\t'); + else + result.append('\t', 1); + } + result.append(s, l); + lua_pop(L, 1); + } + + auto ctx = getTypeFunctionRuntime(L); + + ctx->messages.push_back(std::move(result)); + + return 0; +} + +// Add libraries / globals for type function environment +void setTypeFunctionEnvironment(lua_State* L) +{ + // Register math library + luaopen_math(L); + lua_pop(L, 1); + + // Register table library + luaopen_table(L); + lua_pop(L, 1); + + // Register string library + luaopen_string(L); + lua_pop(L, 1); + + // Register bit32 library + luaopen_bit32(L); + lua_pop(L, 1); + + // Register utf8 library + luaopen_utf8(L); + lua_pop(L, 1); + + // Register buffer library + luaopen_buffer(L); + lua_pop(L, 1); + + // Register base library + luaopen_base(L); + lua_pop(L, 1); + + // Remove certain global functions from the base library + static const char* unavailableGlobals[] = {"gcinfo", "getfenv", "newproxy", "setfenv", "pcall", "xpcall"}; + for (auto& name : unavailableGlobals) + { + lua_pushcfunction(L, unsupportedFunction, name); + lua_setglobal(L, name); + } + + lua_pushcfunction(L, print, "print"); + lua_setglobal(L, "print"); +} + +void resetTypeFunctionState(lua_State* L) +{ + lua_getglobal(L, "math"); + lua_getfield(L, -1, "randomseed"); + lua_pushnumber(L, 0); + lua_call(L, 1, 0); + lua_pop(L, 1); +} + +/* + * Below are helper methods for __eq + * Same as one from Type.cpp + */ +using SeenSet = std::set>; +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs); +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs); + +bool seenSetContains(SeenSet& seen, const void* lhs, const void* rhs) +{ + if (lhs == rhs) + return true; + + auto p = std::make_pair(lhs, rhs); + if (seen.find(p) != seen.end()) + return true; + + seen.insert(p); + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionSingletonType& lhs, const TypeFunctionSingletonType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + { + const TypeFunctionBooleanSingleton* lp = get(&lhs); + const TypeFunctionBooleanSingleton* rp = get(FFlag::LuauTypeFunSingletonEquality ? &rhs : &lhs); + if (lp && rp) + return lp->value == rp->value; + } + + { + const TypeFunctionStringSingleton* lp = get(&lhs); + const TypeFunctionStringSingleton* rp = get(FFlag::LuauTypeFunSingletonEquality ? &rhs : &lhs); + if (lp && rp) + return lp->value == rp->value; + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionUnionType& lhs, const TypeFunctionUnionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionIntersectionType& lhs, const TypeFunctionIntersectionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.components.size() != rhs.components.size()) + return false; + + auto l = lhs.components.begin(); + auto r = rhs.components.begin(); + + while (l != lhs.components.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionNegationType& lhs, const TypeFunctionNegationType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTableType& lhs, const TypeFunctionTableType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.props.size() != rhs.props.size()) + return false; + + if (bool(lhs.indexer) != bool(rhs.indexer)) + return false; + + if (lhs.indexer && rhs.indexer) + { + if (!areEqual(seen, *lhs.indexer->keyType, *rhs.indexer->keyType)) + return false; + + if (!areEqual(seen, *lhs.indexer->valueType, *rhs.indexer->valueType)) + return false; + } + + auto l = lhs.props.begin(); + auto r = rhs.props.begin(); + + while (l != lhs.props.end()) + { + if ((l->second.readTy && !r->second.readTy) || (!l->second.readTy && r->second.readTy)) + return false; + + if (l->second.readTy && r->second.readTy && !areEqual(seen, **(l->second.readTy), **(r->second.readTy))) + return false; + + if ((l->second.writeTy && !r->second.writeTy) || (!l->second.writeTy && r->second.writeTy)) + return false; + + if (l->second.writeTy && r->second.writeTy && !areEqual(seen, **(l->second.writeTy), **(r->second.writeTy))) + return false; + + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionFunctionType& lhs, const TypeFunctionFunctionType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (lhs.generics.size() != rhs.generics.size()) + return false; + + for (auto l = lhs.generics.begin(), r = rhs.generics.begin(); l != lhs.generics.end() && r != rhs.generics.end(); ++l, ++r) + { + if (!areEqual(seen, **l, **r)) + return false; + } + + if (lhs.genericPacks.size() != rhs.genericPacks.size()) + return false; + + for (auto l = lhs.genericPacks.begin(), r = rhs.genericPacks.begin(); l != lhs.genericPacks.end() && r != rhs.genericPacks.end(); ++l, ++r) + { + if (!areEqual(seen, **l, **r)) + return false; + } + + if (bool(lhs.argTypes) != bool(rhs.argTypes)) + return false; + + if (lhs.argTypes && rhs.argTypes) + { + if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) + return false; + } + + if (bool(lhs.retTypes) != bool(rhs.retTypes)) + return false; + + if (lhs.retTypes && rhs.retTypes) + { + if (!areEqual(seen, *lhs.retTypes, *rhs.retTypes)) + return false; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionClassType& lhs, const TypeFunctionClassType& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + if (FFlag::LuauTypeFunFixHydratedClasses) + return lhs.classTy == rhs.classTy; + else + return lhs.name_DEPRECATED == rhs.name_DEPRECATED; +} + +bool areEqual(SeenSet& seen, const TypeFunctionType& lhs, const TypeFunctionType& rhs) +{ + + if (lhs.type.index() != rhs.type.index()) + return false; + + { + const TypeFunctionPrimitiveType* lp = get(&lhs); + const TypeFunctionPrimitiveType* rp = get(&rhs); + if (lp && rp) + return lp->type == rp->type; + } + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + if (get(&lhs) && get(&rhs)) + return true; + + { + const TypeFunctionSingletonType* lf = get(&lhs); + const TypeFunctionSingletonType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionUnionType* lf = get(&lhs); + const TypeFunctionUnionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionIntersectionType* lf = get(&lhs); + const TypeFunctionIntersectionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionNegationType* lf = get(&lhs); + const TypeFunctionNegationType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionTableType* lt = get(&lhs); + const TypeFunctionTableType* rt = get(&rhs); + if (lt && rt) + return areEqual(seen, *lt, *rt); + } + + { + const TypeFunctionFunctionType* lf = get(&lhs); + const TypeFunctionFunctionType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionClassType* lf = get(&lhs); + const TypeFunctionClassType* rf = get(&rhs); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TypeFunctionGenericType* lg = get(&lhs); + const TypeFunctionGenericType* rg = get(&rhs); + if (lg && rg) + return lg->isNamed == rg->isNamed && lg->isPack == rg->isPack && lg->name == rg->name; + } + + return false; +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePack& lhs, const TypeFunctionTypePack& rhs) +{ + if (lhs.head.size() != rhs.head.size()) + return false; + + auto l = lhs.head.begin(); + auto r = rhs.head.begin(); + + while (l != lhs.head.end()) + { + if (!areEqual(seen, **l, **r)) + return false; + ++l; + ++r; + } + + return true; +} + +bool areEqual(SeenSet& seen, const TypeFunctionVariadicTypePack& lhs, const TypeFunctionVariadicTypePack& rhs) +{ + if (seenSetContains(seen, &lhs, &rhs)) + return true; + + return areEqual(seen, *lhs.type, *rhs.type); +} + +bool areEqual(SeenSet& seen, const TypeFunctionTypePackVar& lhs, const TypeFunctionTypePackVar& rhs) +{ + { + const TypeFunctionTypePack* lb = get(&lhs); + const TypeFunctionTypePack* rb = get(&rhs); + if (lb && rb) + return areEqual(seen, *lb, *rb); + } + + { + const TypeFunctionVariadicTypePack* lv = get(&lhs); + const TypeFunctionVariadicTypePack* rv = get(&rhs); + if (lv && rv) + return areEqual(seen, *lv, *rv); + } + + { + const TypeFunctionGenericTypePack* lg = get(&lhs); + const TypeFunctionGenericTypePack* rg = get(&rhs); + if (lg && rg) + return lg->isNamed == rg->isNamed && lg->name == rg->name; + } + + return false; +} + +bool TypeFunctionType::operator==(const TypeFunctionType& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + +bool TypeFunctionTypePackVar::operator==(const TypeFunctionTypePackVar& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + + +TypeFunctionProperty TypeFunctionProperty::readonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.readTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::writeonly(TypeFunctionTypeId ty) +{ + TypeFunctionProperty p; + p.writeTy = ty; + return p; +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId ty) +{ + return TypeFunctionProperty::rw(ty, ty); +} + +TypeFunctionProperty TypeFunctionProperty::rw(TypeFunctionTypeId read, TypeFunctionTypeId write) +{ + TypeFunctionProperty p; + p.readTy = read; + p.writeTy = write; + return p; +} + +bool TypeFunctionProperty::isReadOnly() const +{ + return readTy && !writeTy; +} + +bool TypeFunctionProperty::isWriteOnly() const +{ + return writeTy && !readTy; +} + +/* + * Below is a helper class for type.copy() + * Forked version of Clone.cpp + */ +using TypeFunctionKind = Variant; + +template +const T* get(const TypeFunctionKind& kind) +{ + return get_if(&kind); +} + +class TypeFunctionCloner +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been cloned, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be TypeFunctionPrimitiveType; `second` is trying to copy `first` + std::vector> queue; + + SeenTypes types{{}}; // Mapping of TypeFunctionTypeIds that have been shallow cloned to TypeFunctionTypeIds + SeenTypePacks packs{{}}; // Mapping of TypeFunctionTypePackIds that have been shallow cloned to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionCloner(TypeFunctionRuntime* typeFunctionRuntime) + : typeFunctionRuntime(typeFunctionRuntime) + { + } + + TypeFunctionTypeId clone(TypeFunctionTypeId ty) + { + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId clone(TypeFunctionTypePackId tp) + { + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + return steps + queue.size() >= (size_t)DFInt::LuauTypeFunctionSerdeIterationLimit; + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + cloneChildren(ty, tfti); + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowClone(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case TypeFunctionPrimitiveType::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case TypeFunctionPrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case TypeFunctionPrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + case TypeFunctionPrimitiveType::Thread: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); + break; + case TypeFunctionPrimitiveType::Buffer: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); + break; + default: + break; + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{{}, {}, emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + target = ty; // Don't copy a class since they are immutable + else if (auto g = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->isNamed, g->isPack, g->name}); + else + LUAU_ASSERT(!"Unknown type"); + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowClone(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else if (auto gPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->isNamed, gPack->name}); + else + LUAU_ASSERT(!"Unknown type"); + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void cloneChildren(TypeFunctionTypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + cloneChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + cloneChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + cloneChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + cloneChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + cloneChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + cloneChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; t1 && t2) + cloneChildren(t1, t2); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + cloneChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + cloneChildren(c1, c2); + else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; g1 && g2) + cloneChildren(g1, g2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionTypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + cloneChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + cloneChildren(vPack1, vPack2); + else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + gPack1 && gPack2) + cloneChildren(gPack1, gPack2); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionKind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + cloneChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + cloneChildren(*tp, *tftp); + else + LUAU_ASSERT(!"Unknown pair?"); // First and argument should always represent the same types + } + + void cloneChildren(TypeFunctionPrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void cloneChildren(TypeFunctionNeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void cloneChildren(TypeFunctionAnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void cloneChildren(TypeFunctionSingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void cloneChildren(TypeFunctionUnionType* u1, TypeFunctionUnionType* u2) + { + for (TypeFunctionTypeId& ty : u1->components) + u2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionIntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (TypeFunctionTypeId& ty : i1->components) + i2->components.push_back(shallowClone(ty)); + } + + void cloneChildren(TypeFunctionNegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowClone(n1->type); + } + + void cloneChildren(TypeFunctionTableType* t1, TypeFunctionTableType* t2) + { + for (auto& [k, p] : t1->props) + { + std::optional readTy; + if (p.readTy) + readTy = shallowClone(*p.readTy); + + std::optional writeTy; + if (p.writeTy) + writeTy = shallowClone(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer.has_value()) + t2->indexer = TypeFunctionTableIndexer(shallowClone(t1->indexer->keyType), shallowClone(t1->indexer->valueType)); + + if (t1->metatable.has_value()) + t2->metatable = shallowClone(*t1->metatable); + } + + void cloneChildren(TypeFunctionFunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->generics.reserve(f1->generics.size()); + for (auto ty : f1->generics) + f2->generics.push_back(shallowClone(ty)); + + f2->genericPacks.reserve(f1->genericPacks.size()); + for (auto tp : f1->genericPacks) + f2->genericPacks.push_back(shallowClone(tp)); + + f2->argTypes = shallowClone(f1->argTypes); + f2->retTypes = shallowClone(f1->retTypes); + } + + void cloneChildren(TypeFunctionClassType* c1, TypeFunctionClassType* c2) + { + // noop. + } + + void cloneChildren(TypeFunctionGenericType* g1, TypeFunctionGenericType* g2) + { + // noop. + } + + void cloneChildren(TypeFunctionTypePack* t1, TypeFunctionTypePack* t2) + { + for (TypeFunctionTypeId& ty : t1->head) + t2->head.push_back(shallowClone(ty)); + + if (t1->tail) + t2->tail = shallowClone(*t1->tail); + } + + void cloneChildren(TypeFunctionVariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowClone(v1->type); + } + + void cloneChildren(TypeFunctionGenericTypePack* g1, TypeFunctionGenericTypePack* g2) + { + // noop. + } +}; + +TypeFunctionTypeId deepClone(NotNull runtime, TypeFunctionTypeId ty) +{ + return TypeFunctionCloner(runtime).clone(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeFunctionRuntimeBuilder.cpp b/Analysis/src/TypeFunctionRuntimeBuilder.cpp new file mode 100644 index 00000000..8a8779b2 --- /dev/null +++ b/Analysis/src/TypeFunctionRuntimeBuilder.cpp @@ -0,0 +1,1034 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/TypeFunctionRuntimeBuilder.h" + +#include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/StringUtils.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypeFunctionRuntime.h" +#include "Luau/TypePack.h" +#include "Luau/ToString.h" + +#include + +// used to control the recursion limit of any operations done by user-defined type functions +// currently, controls serialization, deserialization, and `type.copy` +LUAU_DYNAMIC_FASTINTVARIABLE(LuauTypeFunctionSerdeIterationLimit, 100'000); +LUAU_FASTFLAG(LuauTypeFunFixHydratedClasses) +LUAU_FASTFLAG(LuauTypeFunReadWriteParents) + +namespace Luau +{ + +// Forked version of Clone.cpp +class TypeFunctionSerializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeFunctionTypeIds that have been serialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is PrimitiveType, + // second must be TypeFunctionPrimitiveType; else there should be an error + std::vector> queue; + + SeenTypes types; // Mapping of TypeIds that have been shallow serialized to TypeFunctionTypeIds + SeenTypePacks packs; // Mapping of TypePackIds that have been shallow serialized to TypeFunctionTypePackIds + + int steps = 0; + +public: + explicit TypeFunctionSerializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}) + { + } + + TypeFunctionTypeId serialize(TypeId ty) + { + shallowSerialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(ty).value_or(nullptr); + } + + TypeFunctionTypePackId serialize(TypePackId tp) + { + shallowSerialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + return nullptr; + + return find(tp).value_or(nullptr); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [ty, tfti] = queue.back(); + queue.pop_back(); + + serializeChildren(ty, tfti); + } + } + + std::optional find(TypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(Kind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionRuntimeSerializer"); + return std::nullopt; + } + } + + TypeFunctionTypeId shallowSerialize(TypeId ty) + { + ty = follow(ty); + + if (auto it = find(ty)) + return *it; + + // Create a shallow serialization + TypeFunctionTypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case PrimitiveType::NilType: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::NilType)); + break; + case PrimitiveType::Boolean: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Boolean)); + break; + case PrimitiveType::Number: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Number)); + break; + case PrimitiveType::String: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::String)); + break; + case PrimitiveType::Thread: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Thread)); + break; + case PrimitiveType::Buffer: + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionPrimitiveType(TypeFunctionPrimitiveType::Buffer)); + break; + case PrimitiveType::Function: + case PrimitiveType::Table: + default: + { + std::string error = format("Argument of primitive type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnknownType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNeverType{}); + else if (auto a = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionAnyType{}); + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionBooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionSingletonType{TypeFunctionStringSingleton{ss->value}}); + else + { + std::string error = format("Argument of singleton type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + else if (auto u = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionUnionType{{}}); + else if (auto i = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionIntersectionType{{}}); + else if (auto n = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionNegationType{{}}); + else if (auto t = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto m = get(ty)) + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{{}, std::nullopt, std::nullopt}); + else if (auto f = get(ty)) + { + TypeFunctionTypePackId emptyTypePack = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{}); + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionFunctionType{{}, {}, emptyTypePack, emptyTypePack}); + } + else if (auto c = get(ty)) + { + if (FFlag::LuauTypeFunFixHydratedClasses) + { + // Since there aren't any new class types being created in type functions, we will deserialize by using a direct reference to the + // original class + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, ty}); + } + else + { + state->classesSerialized_DEPRECATED[c->name] = ty; + target = typeFunctionRuntime->typeArena.allocate( + TypeFunctionClassType{{}, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, /* classTy */ nullptr, c->name} + ); + } + } + else if (auto g = get(ty)) + { + Name name = g->name; + + if (!g->explicitName) + name = format("g%d", g->index); + + target = typeFunctionRuntime->typeArena.allocate(TypeFunctionGenericType{g->explicitName, false, name}); + } + else + { + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypeFunctionTypePackId shallowSerialize(TypePackId tp) + { + tp = follow(tp); + + if (auto it = find(tp)) + return *it; + + // Create a shallow serialization + TypeFunctionTypePackId target = {}; + if (auto tPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionTypePack{{}}); + else if (auto vPack = get(tp)) + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionVariadicTypePack{}); + else if (auto gPack = get(tp)) + { + Name name = gPack->name; + + if (!gPack->explicitName) + name = format("g%d", gPack->index); + + target = typeFunctionRuntime->typePackArena.allocate(TypeFunctionGenericTypePack{gPack->explicitName, name}); + } + else + { + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void serializeChildren(const TypeId ty, TypeFunctionTypeId tfti) + { + if (auto [p1, p2] = std::tuple{get(ty), getMutable(tfti)}; p1 && p2) + serializeChildren(p1, p2); + else if (auto [u1, u2] = std::tuple{get(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [n1, n2] = std::tuple{get(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [a1, a2] = std::tuple{get(ty), getMutable(tfti)}; a1 && a2) + serializeChildren(a1, a2); + else if (auto [s1, s2] = std::tuple{get(ty), getMutable(tfti)}; s1 && s2) + serializeChildren(s1, s2); + else if (auto [u1, u2] = std::tuple{get(ty), getMutable(tfti)}; u1 && u2) + serializeChildren(u1, u2); + else if (auto [i1, i2] = std::tuple{get(ty), getMutable(tfti)}; i1 && i2) + serializeChildren(i1, i2); + else if (auto [n1, n2] = std::tuple{get(ty), getMutable(tfti)}; n1 && n2) + serializeChildren(n1, n2); + else if (auto [t1, t2] = std::tuple{get(ty), getMutable(tfti)}; t1 && t2) + serializeChildren(t1, t2); + else if (auto [m1, m2] = std::tuple{get(ty), getMutable(tfti)}; m1 && m2) + serializeChildren(m1, m2); + else if (auto [f1, f2] = std::tuple{get(ty), getMutable(tfti)}; f1 && f2) + serializeChildren(f1, f2); + else if (auto [c1, c2] = std::tuple{get(ty), getMutable(tfti)}; c1 && c2) + serializeChildren(c1, c2); + else if (auto [g1, g2] = std::tuple{get(ty), getMutable(tfti)}; g1 && g2) + serializeChildren(g1, g2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type %s is not currently serializable by type functions", toString(ty).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(const TypePackId tp, TypeFunctionTypePackId tftp) + { + if (auto [tPack1, tPack2] = std::tuple{get(tp), getMutable(tftp)}; tPack1 && tPack2) + serializeChildren(tPack1, tPack2); + else if (auto [vPack1, vPack2] = std::tuple{get(tp), getMutable(tftp)}; vPack1 && vPack2) + serializeChildren(vPack1, vPack2); + else if (auto [gPack1, gPack2] = std::tuple{get(tp), getMutable(tftp)}; gPack1 && gPack2) + serializeChildren(gPack1, gPack2); + else + { // Either this or ty and tfti do not represent the same type + std::string error = format("Argument of type pack %s is not currently serializable by type functions", toString(tp).c_str()); + state->errors.push_back(error); + } + } + + void serializeChildren(Kind kind, TypeFunctionKind tfkind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + serializeChildren(*ty, *tfty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + serializeChildren(*tp, *tftp); + else + state->ctx->ice->ice("Serializing user defined type function arguments: kind and tfkind do not represent the same type"); + } + + void serializeChildren(const PrimitiveType* p1, TypeFunctionPrimitiveType* p2) + { + // noop. + } + + void serializeChildren(const UnknownType* u1, TypeFunctionUnknownType* u2) + { + // noop. + } + + void serializeChildren(const NeverType* n1, TypeFunctionNeverType* n2) + { + // noop. + } + + void serializeChildren(const AnyType* a1, TypeFunctionAnyType* a2) + { + // noop. + } + + void serializeChildren(const SingletonType* s1, TypeFunctionSingletonType* s2) + { + // noop. + } + + void serializeChildren(const UnionType* u1, TypeFunctionUnionType* u2) + { + for (const TypeId& ty : u1->options) + u2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(const IntersectionType* i1, TypeFunctionIntersectionType* i2) + { + for (const TypeId& ty : i1->parts) + i2->components.push_back(shallowSerialize(ty)); + } + + void serializeChildren(const NegationType* n1, TypeFunctionNegationType* n2) + { + n2->type = shallowSerialize(n1->ty); + } + + void serializeChildren(const TableType* t1, TypeFunctionTableType* t2) + { + for (const auto& [k, p] : t1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + t2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (t1->indexer) + t2->indexer = TypeFunctionTableIndexer(shallowSerialize(t1->indexer->indexType), shallowSerialize(t1->indexer->indexResultType)); + } + + void serializeChildren(const MetatableType* m1, TypeFunctionTableType* m2) + { + // Serialize main part of the metatable immediately + if (auto tableTy = get(m1->table)) + serializeChildren(tableTy, m2); + + m2->metatable = shallowSerialize(m1->metatable); + } + + void serializeChildren(const FunctionType* f1, TypeFunctionFunctionType* f2) + { + f2->generics.reserve(f1->generics.size()); + for (auto ty : f1->generics) + f2->generics.push_back(shallowSerialize(ty)); + + f2->genericPacks.reserve(f1->genericPacks.size()); + for (auto tp : f1->genericPacks) + f2->genericPacks.push_back(shallowSerialize(tp)); + + f2->argTypes = shallowSerialize(f1->argTypes); + f2->retTypes = shallowSerialize(f1->retTypes); + } + + void serializeChildren(const ClassType* c1, TypeFunctionClassType* c2) + { + for (const auto& [k, p] : c1->props) + { + std::optional readTy = std::nullopt; + if (p.readTy) + readTy = shallowSerialize(*p.readTy); + + std::optional writeTy = std::nullopt; + if (p.writeTy) + writeTy = shallowSerialize(*p.writeTy); + + c2->props[k] = TypeFunctionProperty{readTy, writeTy}; + } + + if (c1->indexer) + c2->indexer = TypeFunctionTableIndexer(shallowSerialize(c1->indexer->indexType), shallowSerialize(c1->indexer->indexResultType)); + + if (c1->metatable) + c2->metatable = shallowSerialize(*c1->metatable); + + if (c1->parent) + { + TypeFunctionTypeId parent = shallowSerialize(*c1->parent); + + if (FFlag::LuauTypeFunReadWriteParents) + { + // we don't yet have read/write parents in the type inference engine. + c2->readParent = parent; + c2->writeParent = parent; + } + else + { + c2->parent_DEPRECATED = parent; + } + } + } + + void serializeChildren(const GenericType* g1, TypeFunctionGenericType* g2) + { + // noop. + } + + void serializeChildren(const TypePack* t1, TypeFunctionTypePack* t2) + { + for (const TypeId& ty : t1->head) + t2->head.push_back(shallowSerialize(ty)); + + if (t1->tail.has_value()) + t2->tail = shallowSerialize(*t1->tail); + } + + void serializeChildren(const VariadicTypePack* v1, TypeFunctionVariadicTypePack* v2) + { + v2->type = shallowSerialize(v1->ty); + } + + void serializeChildren(const GenericTypePack* v1, TypeFunctionGenericTypePack* v2) + { + // noop. + } +}; + +template +struct SerializedGeneric +{ + bool isNamed = false; + std::string name; + T type = nullptr; +}; + +struct SerializedFunctionScope +{ + size_t oldQueueSize = 0; + TypeFunctionFunctionType* function = nullptr; +}; + +// Complete inverse of TypeFunctionSerializer +class TypeFunctionDeserializer +{ + using SeenTypes = DenseHashMap; + using SeenTypePacks = DenseHashMap; + + TypeFunctionRuntimeBuilderState* state = nullptr; + NotNull typeFunctionRuntime; + + // A queue of TypeIds that have been deserialized, but whose interior types hasn't + // been updated to point to itself. Once all of its interior types + // has been updated, it gets removed from the queue. + + // queue.back() should always return two of same type in their respective sides + // For example `auto [first, second] = queue.back()`: if first is TypeFunctionPrimitiveType, + // second must be PrimitiveType; else there should be an error + std::vector> queue; + + // Generic types and packs currently in scope + // Generics are resolved by name even if runtime generic type pointers are different + // Multiple names mapping to the same generic can be in scope for nested generic functions + std::vector> genericTypes; + std::vector> genericPacks; + + // To track when generics go out of scope, we have a list of queue positions at which a specific function has introduced generics + std::vector functionScopes; + + SeenTypes types; // Mapping of TypeFunctionTypeIds that have been shallow deserialized to TypeIds + SeenTypePacks packs; // Mapping of TypeFunctionTypePackIds that have been shallow deserialized to TypePackIds + + int steps = 0; + +public: + explicit TypeFunctionDeserializer(TypeFunctionRuntimeBuilderState* state) + : state(state) + , typeFunctionRuntime(state->ctx->typeFunctionRuntime) + , queue({}) + , types({}) + , packs({}) + { + } + + TypeId deserialize(TypeFunctionTypeId ty) + { + shallowDeserialize(ty); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypeId error = state->ctx->builtins->errorRecoveryType(); + types[ty] = error; + return error; + } + + return find(ty).value_or(state->ctx->builtins->errorRecoveryType()); + } + + TypePackId deserialize(TypeFunctionTypePackId tp) + { + shallowDeserialize(tp); + run(); + + if (hasExceededIterationLimit() || state->errors.size() != 0) + { + TypePackId error = state->ctx->builtins->errorRecoveryTypePack(); + packs[tp] = error; + return error; + } + + return find(tp).value_or(state->ctx->builtins->errorRecoveryTypePack()); + } + +private: + bool hasExceededIterationLimit() const + { + if (DFInt::LuauTypeFunctionSerdeIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(DFInt::LuauTypeFunctionSerdeIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit() || state->errors.size() != 0) + break; + + auto [tfti, ty] = queue.back(); + queue.pop_back(); + + deserializeChildren(tfti, ty); + + // If we have completed working on all children of a function, remove the generic parameters from scope + if (!functionScopes.empty() && queue.size() == functionScopes.back().oldQueueSize && state->errors.empty()) + { + closeFunctionScope(functionScopes.back().function); + functionScopes.pop_back(); + } + } + } + + std::optional find(TypeFunctionTypeId ty) const + { + if (auto result = types.find(ty)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionTypePackId tp) const + { + if (auto result = packs.find(tp)) + return *result; + + return std::nullopt; + } + + std::optional find(TypeFunctionKind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind found at TypeFunctionDeserializer"); + return std::nullopt; + } + } + + void closeFunctionScope(TypeFunctionFunctionType* f) + { + if (!f->generics.empty()) + { + LUAU_ASSERT(genericTypes.size() >= f->generics.size()); + genericTypes.erase(genericTypes.begin() + int(genericTypes.size() - f->generics.size()), genericTypes.end()); + } + + if (!f->genericPacks.empty()) + { + LUAU_ASSERT(genericPacks.size() >= f->genericPacks.size()); + genericPacks.erase(genericPacks.begin() + int(genericPacks.size() - f->genericPacks.size()), genericPacks.end()); + } + } + + TypeId shallowDeserialize(TypeFunctionTypeId ty) + { + if (auto it = find(ty)) + return *it; + + // Create a shallow deserialization + TypeId target = {}; + if (auto p = get(ty)) + { + switch (p->type) + { + case TypeFunctionPrimitiveType::Type::NilType: + target = state->ctx->builtins->nilType; + break; + case TypeFunctionPrimitiveType::Type::Boolean: + target = state->ctx->builtins->booleanType; + break; + case TypeFunctionPrimitiveType::Type::Number: + target = state->ctx->builtins->numberType; + break; + case TypeFunctionPrimitiveType::Type::String: + target = state->ctx->builtins->stringType; + break; + case TypeFunctionPrimitiveType::Type::Thread: + target = state->ctx->builtins->threadType; + break; + case TypeFunctionPrimitiveType::Type::Buffer: + target = state->ctx->builtins->bufferType; + break; + default: + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + } + else if (auto u = get(ty)) + target = state->ctx->builtins->unknownType; + else if (auto n = get(ty)) + target = state->ctx->builtins->neverType; + else if (auto a = get(ty)) + target = state->ctx->builtins->anyType; + else if (auto s = get(ty)) + { + if (auto bs = get(s)) + target = state->ctx->arena->addType(SingletonType{BooleanSingleton{bs->value}}); + else if (auto ss = get(s)) + target = state->ctx->arena->addType(SingletonType{StringSingleton{ss->value}}); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + else if (auto u = get(ty)) + target = state->ctx->arena->addTV(Type(UnionType{{}})); + else if (auto i = get(ty)) + target = state->ctx->arena->addTV(Type(IntersectionType{{}})); + else if (auto n = get(ty)) + target = state->ctx->arena->addType(NegationType{state->ctx->builtins->unknownType}); + else if (auto t = get(ty); t && !t->metatable.has_value()) + target = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + else if (auto m = get(ty); m && m->metatable.has_value()) + { + TypeId emptyTable = state->ctx->arena->addType(TableType{TableType::Props{}, std::nullopt, TypeLevel{}, TableState::Sealed}); + target = state->ctx->arena->addType(MetatableType{emptyTable, emptyTable}); + } + else if (auto f = get(ty)) + { + TypePackId emptyTypePack = state->ctx->arena->addTypePack(TypePack{}); + target = state->ctx->arena->addType(FunctionType{emptyTypePack, emptyTypePack, {}, false}); + } + else if (auto c = get(ty)) + { + if (FFlag::LuauTypeFunFixHydratedClasses) + { + target = c->classTy; + } + else + { + if (auto result = state->classesSerialized_DEPRECATED.find(c->name_DEPRECATED)) + target = *result; + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious class type is being deserialized"); + } + } + else if (auto g = get(ty)) + { + if (g->isPack) + { + state->errors.push_back(format("Generic type pack '%s...' cannot be placed in a type position", g->name.c_str())); + return nullptr; + } + else + { + auto it = std::find_if( + genericTypes.rbegin(), + genericTypes.rend(), + [&](const SerializedGeneric& el) + { + return g->isNamed == el.isNamed && g->name == el.name; + } + ); + + if (it == genericTypes.rend()) + { + state->errors.push_back(format("Generic type '%s' is not in a scope of the active generic function", g->name.c_str())); + return nullptr; + } + + target = it->type; + } + } + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + + types[ty] = target; + queue.emplace_back(ty, target); + return target; + } + + TypePackId shallowDeserialize(TypeFunctionTypePackId tp) + { + if (auto it = find(tp)) + return *it; + + // Create a shallow deserialization + TypePackId target = {}; + if (auto tPack = get(tp)) + { + target = state->ctx->arena->addTypePack(TypePack{}); + } + else if (auto vPack = get(tp)) + { + target = state->ctx->arena->addTypePack(VariadicTypePack{}); + } + else if (auto gPack = get(tp)) + { + auto it = std::find_if( + genericPacks.rbegin(), + genericPacks.rend(), + [&](const SerializedGeneric& el) + { + return gPack->isNamed == el.isNamed && gPack->name == el.name; + } + ); + + if (it == genericPacks.rend()) + { + state->errors.push_back(format("Generic type pack '%s...' is not in a scope of the active generic function", gPack->name.c_str())); + return nullptr; + } + + target = it->type; + } + else + { + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + packs[tp] = target; + queue.emplace_back(tp, target); + return target; + } + + void deserializeChildren(TypeFunctionTypeId tfti, TypeId ty) + { + if (auto [p1, p2] = std::tuple{getMutable(ty), getMutable(tfti)}; p1 && p2) + deserializeChildren(p2, p1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [a1, a2] = std::tuple{getMutable(ty), getMutable(tfti)}; a1 && a2) + deserializeChildren(a2, a1); + else if (auto [s1, s2] = std::tuple{getMutable(ty), getMutable(tfti)}; s1 && s2) + deserializeChildren(s2, s1); + else if (auto [u1, u2] = std::tuple{getMutable(ty), getMutable(tfti)}; u1 && u2) + deserializeChildren(u2, u1); + else if (auto [i1, i2] = std::tuple{getMutable(ty), getMutable(tfti)}; i1 && i2) + deserializeChildren(i2, i1); + else if (auto [n1, n2] = std::tuple{getMutable(ty), getMutable(tfti)}; n1 && n2) + deserializeChildren(n2, n1); + else if (auto [t1, t2] = std::tuple{getMutable(ty), getMutable(tfti)}; + t1 && t2 && !t2->metatable.has_value()) + deserializeChildren(t2, t1); + else if (auto [m1, m2] = std::tuple{getMutable(ty), getMutable(tfti)}; + m1 && m2 && m2->metatable.has_value()) + deserializeChildren(m2, m1); + else if (auto [f1, f2] = std::tuple{getMutable(ty), getMutable(tfti)}; f1 && f2) + deserializeChildren(f2, f1); + else if (auto [c1, c2] = std::tuple{getMutable(ty), getMutable(tfti)}; c1 && c2) + deserializeChildren(c2, c1); + else if (auto [g1, g2] = std::tuple{getMutable(ty), getMutable(tfti)}; g1 && g2) + deserializeChildren(g2, g1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionTypePackId tftp, TypePackId tp) + { + if (auto [tPack1, tPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; tPack1 && tPack2) + deserializeChildren(tPack2, tPack1); + else if (auto [vPack1, vPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; + vPack1 && vPack2) + deserializeChildren(vPack2, vPack1); + else if (auto [gPack1, gPack2] = std::tuple{getMutable(tp), getMutable(tftp)}; gPack1 && gPack2) + deserializeChildren(gPack2, gPack1); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: mysterious type is being deserialized"); + } + + void deserializeChildren(TypeFunctionKind tfkind, Kind kind) + { + if (auto [ty, tfty] = std::tuple{get(kind), get(tfkind)}; ty && tfty) + deserializeChildren(*tfty, *ty); + else if (auto [tp, tftp] = std::tuple{get(kind), get(tfkind)}; tp && tftp) + deserializeChildren(*tftp, *tp); + else + state->ctx->ice->ice("Deserializing user defined type function arguments: tfkind and kind do not represent the same type"); + } + + void deserializeChildren(TypeFunctionPrimitiveType* p2, PrimitiveType* p1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnknownType* u2, UnknownType* u1) + { + // noop. + } + + void deserializeChildren(TypeFunctionNeverType* n2, NeverType* n1) + { + // noop. + } + + void deserializeChildren(TypeFunctionAnyType* a2, AnyType* a1) + { + // noop. + } + + void deserializeChildren(TypeFunctionSingletonType* s2, SingletonType* s1) + { + // noop. + } + + void deserializeChildren(TypeFunctionUnionType* u2, UnionType* u1) + { + for (TypeFunctionTypeId& ty : u2->components) + u1->options.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionIntersectionType* i2, IntersectionType* i1) + { + for (TypeFunctionTypeId& ty : i2->components) + i1->parts.push_back(shallowDeserialize(ty)); + } + + void deserializeChildren(TypeFunctionNegationType* n2, NegationType* n1) + { + n1->ty = shallowDeserialize(n2->type); + } + + void deserializeChildren(TypeFunctionTableType* t2, TableType* t1) + { + for (const auto& [k, p] : t2->props) + { + if (p.readTy && p.writeTy) + t1->props[k] = Property::rw(shallowDeserialize(*p.readTy), shallowDeserialize(*p.writeTy)); + else if (p.readTy) + t1->props[k] = Property::readonly(shallowDeserialize(*p.readTy)); + else if (p.writeTy) + t1->props[k] = Property::writeonly(shallowDeserialize(*p.writeTy)); + } + + if (t2->indexer.has_value()) + t1->indexer = TableIndexer(shallowDeserialize(t2->indexer->keyType), shallowDeserialize(t2->indexer->valueType)); + } + + void deserializeChildren(TypeFunctionTableType* m2, MetatableType* m1) + { + TypeFunctionTypeId temp = typeFunctionRuntime->typeArena.allocate(TypeFunctionTableType{m2->props, m2->indexer}); + m1->table = shallowDeserialize(temp); + + if (m2->metatable.has_value()) + m1->metatable = shallowDeserialize(*m2->metatable); + } + + void deserializeChildren(TypeFunctionFunctionType* f2, FunctionType* f1) + { + functionScopes.push_back({queue.size(), f2}); + + std::set> genericNames; + + // Introduce generic function parameters into scope + for (auto ty : f2->generics) + { + auto gty = get(ty); + LUAU_ASSERT(gty && !gty->isPack); + + std::pair nameKey = std::make_pair(gty->isNamed, gty->name); + + // Duplicates are not allowed + if (genericNames.find(nameKey) != genericNames.end()) + { + state->errors.push_back(format("Duplicate type parameter '%s'", gty->name.c_str())); + return; + } + + genericNames.insert(nameKey); + + TypeId mapping = state->ctx->arena->addTV(Type(gty->isNamed ? GenericType{state->ctx->scope.get(), gty->name} : GenericType{})); + genericTypes.push_back({gty->isNamed, gty->name, mapping}); + } + + for (auto tp : f2->genericPacks) + { + auto gtp = get(tp); + LUAU_ASSERT(gtp); + + std::pair nameKey = std::make_pair(gtp->isNamed, gtp->name); + + // Duplicates are not allowed + if (genericNames.find(nameKey) != genericNames.end()) + { + state->errors.push_back(format("Duplicate type parameter '%s'", gtp->name.c_str())); + return; + } + + genericNames.insert(nameKey); + + TypePackId mapping = + state->ctx->arena->addTypePack(TypePackVar(gtp->isNamed ? GenericTypePack{state->ctx->scope.get(), gtp->name} : GenericTypePack{})); + genericPacks.push_back({gtp->isNamed, gtp->name, mapping}); + } + + f1->generics.reserve(f2->generics.size()); + for (auto ty : f2->generics) + f1->generics.push_back(shallowDeserialize(ty)); + + f1->genericPacks.reserve(f2->genericPacks.size()); + for (auto tp : f2->genericPacks) + f1->genericPacks.push_back(shallowDeserialize(tp)); + + if (f2->argTypes) + f1->argTypes = shallowDeserialize(f2->argTypes); + + if (f2->retTypes) + f1->retTypes = shallowDeserialize(f2->retTypes); + } + + void deserializeChildren(TypeFunctionClassType* c2, ClassType* c1) + { + // noop. + } + + void deserializeChildren(TypeFunctionGenericType* g2, GenericType* g1) + { + // noop. + } + + void deserializeChildren(TypeFunctionTypePack* t2, TypePack* t1) + { + for (TypeFunctionTypeId& ty : t2->head) + t1->head.push_back(shallowDeserialize(ty)); + + if (t2->tail.has_value()) + t1->tail = shallowDeserialize(*t2->tail); + } + + void deserializeChildren(TypeFunctionVariadicTypePack* v2, VariadicTypePack* v1) + { + v1->ty = shallowDeserialize(v2->type); + } + + void deserializeChildren(TypeFunctionGenericTypePack* v2, GenericTypePack* v1) + { + // noop. + } +}; + +TypeFunctionTypeId serialize(TypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionSerializer(state).serialize(ty); +} + +TypeId deserialize(TypeFunctionTypeId ty, TypeFunctionRuntimeBuilderState* state) +{ + return TypeFunctionDeserializer(state).deserialize(ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6b2e861d..73f8b1be 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -23,18 +23,18 @@ #include #include -LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) +LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 165) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 20000) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 300) LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification) LUAU_FASTFLAG(LuauInstantiateInSubtyping) -LUAU_FASTFLAGVARIABLE(LuauRemoveBadRelationalOperatorWarning, false) -LUAU_FASTFLAGVARIABLE(LuauOkWithIteratingOverTableProperties, false) -LUAU_FASTFLAGVARIABLE(LuauAcceptIndexingTableUnionsIntersections, false) +LUAU_FASTFLAGVARIABLE(LuauOldSolverCreatesChildScopePointers) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -265,11 +265,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo ScopePtr parentScope = environmentScope.value_or(globalScope); ScopePtr moduleScope = std::make_shared(parentScope); - if (module.cyclic) - moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); - else - moduleScope->returnType = freshTypePack(moduleScope); - + moduleScope->returnType = freshTypePack(moduleScope); moduleScope->varargPack = anyTypePack; currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope)); @@ -767,8 +763,12 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& state struct Demoter : Substitution { - Demoter(TypeArena* arena) + TypeArena* arena = nullptr; + NotNull builtins; + Demoter(TypeArena* arena, NotNull builtins) : Substitution(TxnLog::empty(), arena) + , arena(arena) + , builtins(builtins) { } @@ -794,7 +794,8 @@ struct Demoter : Substitution { auto ftv = get(ty); LUAU_ASSERT(ftv); - return addType(FreeType{demotedLevel(ftv->level)}); + return FFlag::LuauFreeTypesMustHaveBounds ? arena->freshType(builtins, demotedLevel(ftv->level)) + : addType(FreeType{demotedLevel(ftv->level)}); } TypePackId clean(TypePackId tp) override @@ -841,7 +842,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatReturn& retur } } - Demoter demoter{¤tModule->internalTypes}; + Demoter demoter{¤tModule->internalTypes, builtinTypes}; demoter.demote(expectedTypes); TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; @@ -958,7 +959,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assig else if (auto tail = valueIter.tail()) { TypePackId tailPack = follow(*tail); - if (get(tailPack)) + if (get(tailPack)) right = errorRecoveryType(scope); else if (auto vtp = get(tailPack)) right = vtp->ty; @@ -1238,7 +1239,7 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) iterTy = freshType(scope); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location); } - else if (get(callRetPack) || !first(callRetPack)) + else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) unify(errorRecoveryType(scope), var, scope, forin.location); @@ -1284,20 +1285,11 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) for (size_t i = 2; i < varTypes.size(); ++i) unify(nilType, varTypes[i], scope, forin.location); } - else if (isNonstrictMode() || FFlag::LuauOkWithIteratingOverTableProperties) + else { for (TypeId var : varTypes) unify(unknownType, var, scope, forin.location); } - else - { - TypeId varTy = errorRecoveryType(loopScope); - - for (TypeId var : varTypes) - unify(varTy, var, scope, forin.location); - - reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); - } return check(loopScope, *forin.body); } @@ -1975,7 +1967,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp *asMutable(varargPack) = TypePack{{head}, tail}; return WithPredicate{head}; } - if (get(varargPack)) + if (get(varargPack)) return WithPredicate{errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return WithPredicate{vtp->ty}; @@ -2005,7 +1997,7 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp unify(pack, retPack, scope, expr.location); return {head, std::move(result.predicates)}; } - if (get(retPack)) + if (get(retPack)) return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; @@ -2804,34 +2796,19 @@ TypeId TypeChecker::checkRelationalOperation( { reportErrors(state.errors); - if (FFlag::LuauRemoveBadRelationalOperatorWarning) + // The original version of this check also produced this error when we had a union type. + // However, the old solver does not readily have the ability to discern if the union is comparable. + // This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type. + // The new solver has much more powerful logic for resolving relational operators, but for now, + // we need to be conservative in the old solver to deliver a reasonable developer experience. + if (!isEquality && state.errors.empty() && isBoolean(leftType)) { - // The original version of this check also produced this error when we had a union type. - // However, the old solver does not readily have the ability to discern if the union is comparable. - // This is the case when the lhs is e.g. a union of singletons and the rhs is the combined type. - // The new solver has much more powerful logic for resolving relational operators, but for now, - // we need to be conservative in the old solver to deliver a reasonable developer experience. - if (!isEquality && state.errors.empty() && isBoolean(leftType)) - { - reportError( - expr.location, - GenericError{ - format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) - } - ); - } - } - else - { - if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) - { - reportError( - expr.location, - GenericError{ - format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) - } - ); - } + reportError( + expr.location, + GenericError{ + format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str()) + } + ); } return booleanType; @@ -2896,7 +2873,7 @@ TypeId TypeChecker::checkRelationalOperation( std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location, /* addErrors= */ true); if (metamethod) { - if (const FunctionType* ftv = get(*metamethod)) + if (const FunctionType* ftv = get(follow(*metamethod))) { if (isEquality) { @@ -3507,7 +3484,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } } - if (FFlag::LuauAcceptIndexingTableUnionsIntersections) { // We're going to have a whole vector. std::vector tableTypes{}; @@ -3658,57 +3634,6 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex return addType(IntersectionType{{resultTypes.begin(), resultTypes.end()}}); } - else - { - TableType* exprTable = getMutableTableType(exprType); - if (!exprTable) - { - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return errorRecoveryType(scope); - } - - if (value) - { - const auto& it = exprTable->props.find(value->value.data); - if (it != exprTable->props.end()) - { - return it->second.type(); - } - else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) - { - TypeId resultType = freshType(scope); - Property& property = exprTable->props[value->value.data]; - property.setType(resultType); - property.location = expr.index->location; - return resultType; - } - } - - if (exprTable->indexer) - { - const TableIndexer& indexer = *exprTable->indexer; - unify(indexType, indexer.indexType, scope, expr.index->location); - return indexer.indexResultType; - } - else if ((ctx == ValueContext::LValue && exprTable->state == TableState::Unsealed) || exprTable->state == TableState::Free) - { - TypeId indexerType = freshType(exprTable->level); - unify(indexType, indexerType, scope, expr.location); - TypeId indexResultType = freshType(exprTable->level); - - exprTable->indexer = TableIndexer{anyIfNonstrict(indexerType), anyIfNonstrict(indexResultType)}; - return indexResultType; - } - else - { - /* - * If we use [] indexing to fetch a property from a sealed table that - * has no indexer, we have no idea if it will work so we just return any - * and hope for the best. - */ - return anyType; - } - } } // Answers the question: "Can I define another function with this name?" @@ -4163,7 +4088,7 @@ void TypeChecker::checkArgumentList( if (argIter.tail()) { TypePackId tail = *argIter.tail(); - if (state.log.getMutable(tail)) + if (state.log.getMutable(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. while (paramIter != endIter) @@ -4248,7 +4173,7 @@ void TypeChecker::checkArgumentList( } TypePackId tail = state.log.follow(*paramIter.tail()); - if (state.log.getMutable(tail)) + if (state.log.getMutable(tail)) { // Function is variadic. Ok. return; @@ -4384,7 +4309,7 @@ WithPredicate TypeChecker::checkExprPackHelper(const ScopePtr& scope WithPredicate argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); TypePackId argPack = argListResult.type; - if (get(argPack)) + if (get(argPack)) return WithPredicate{errorRecoveryTypePack(scope)}; TypePack* args = nullptr; @@ -4490,7 +4415,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } } - Demoter demoter{¤tModule->internalTypes}; + Demoter demoter{¤tModule->internalTypes, builtinTypes}; demoter.demote(expectedTypes); return expectedTypes; @@ -4588,10 +4513,10 @@ std::unique_ptr> TypeChecker::checkCallOverload( // When this function type has magic functions and did return something, we select that overload instead. // TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution. - if (ftv->magicFunction) + if (ftv->magic) { // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 - if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + if (std::optional> ret = ftv->magic->handleOldSolver(*this, scope, expr, argListResult)) return std::make_unique>(std::move(*ret)); } @@ -4974,7 +4899,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module TypePackId modulePack = module->returnType; - if (get(modulePack)) + if (get(modulePack)) return errorRecoveryType(scope); std::optional moduleType = first(modulePack); @@ -5063,17 +4988,17 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. - Unifier child = state.makeChildUnifier(); - child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); - if (!child.errors.empty()) + std::unique_ptr child = state.makeChildUnifier(); + child->tryUnify(subTy, superTy, /*isFunctionCall*/ false); + if (!child->errors.empty()) { - TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); + TypeId instantiated = instantiate(scope, subTy, state.location, &child->log); if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - state.log.concat(std::move(child.log)); + state.log.concat(std::move(child->log)); - state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); + state.errors.insert(state.errors.end(), child->errors.begin(), child->errors.end()); } else { @@ -5082,7 +5007,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, c } else { - state.log.concat(std::move(child.log)); + state.log.concat(std::move(child->log)); } } } @@ -5287,6 +5212,13 @@ LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& locati ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) { ScopePtr scope = std::make_shared(parent, subLevel); + if (FFlag::LuauOldSolverCreatesChildScopePointers) + { + scope->location = location; + scope->returnType = parent->returnType; + parent->children.emplace_back(scope.get()); + } + currentModule->scopes.push_back(std::make_pair(location, scope)); return scope; } @@ -5297,6 +5229,12 @@ ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& locatio ScopePtr scope = std::make_shared(parent); scope->level = parent->level; scope->varargPack = parent->varargPack; + if (FFlag::LuauOldSolverCreatesChildScopePointers) + { + scope->location = location; + scope->returnType = parent->returnType; + parent->children.emplace_back(scope.get()); + } currentModule->scopes.push_back(std::make_pair(location, scope)); return scope; @@ -5342,7 +5280,8 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.addType(Type(FreeType(level))); + return FFlag::LuauFreeTypesMustHaveBounds ? currentModule->internalTypes.freshType(builtinTypes, level) + : currentModule->internalTypes.addType(Type(FreeType(level))); } TypeId TypeChecker::singletonType(bool value) @@ -5787,6 +5726,12 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& un = annotation.as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (un->types.size == 1) + return resolveType(scope, *un->types.data[0]); + } + std::vector types; for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); @@ -5795,12 +5740,22 @@ TypeId TypeChecker::resolveTypeWorker(const ScopePtr& scope, const AstType& anno } else if (const auto& un = annotation.as()) { + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (un->types.size == 1) + return resolveType(scope, *un->types.data[0]); + } + std::vector types; for (AstType* ann : un->types) types.push_back(resolveType(scope, *ann)); return addType(IntersectionType{types}); } + else if (const auto& g = annotation.as()) + { + return resolveType(scope, *g->type); + } else if (const auto& tsb = annotation.as()) { return singletonType(tsb->value); @@ -5958,8 +5913,8 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, - const AstArray& genericPackNames, + const AstArray& genericNames, + const AstArray& genericPackNames, bool useCache ) { @@ -5969,14 +5924,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( std::vector generics; - for (const AstGenericType& generic : genericNames) + for (const AstGenericType* generic : genericNames) { std::optional defaultValue; - if (generic.defaultValue) - defaultValue = resolveType(scope, *generic.defaultValue); + if (generic->defaultValue) + defaultValue = resolveType(scope, *generic->defaultValue); - Name n = generic.name.value; + Name n = generic->name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic types have the same name. @@ -6005,14 +5960,14 @@ GenericTypeDefinitions TypeChecker::createGenericTypes( std::vector genericPacks; - for (const AstGenericTypePack& genericPack : genericPackNames) + for (const AstGenericTypePack* genericPack : genericPackNames) { std::optional defaultValue; - if (genericPack.defaultValue) - defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + if (genericPack->defaultValue) + defaultValue = resolveTypePack(scope, *genericPack->defaultValue); - Name n = genericPack.name.value; + Name n = genericPack->name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic types have the same name. @@ -6418,7 +6373,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r } // We're only interested in the root class of any classes. - if (auto ctv = get(type); !ctv || ctv->parent != builtinTypes->classType) + if (auto ctv = get(type); !ctv || (ctv->parent != builtinTypes->classType && !hasTag(type, kTypeofRootTag))) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. diff --git a/Analysis/src/TypePath.cpp b/Analysis/src/TypePath.cpp index d2113ee3..baf7bb11 100644 --- a/Analysis/src/TypePath.cpp +++ b/Analysis/src/TypePath.cpp @@ -14,7 +14,7 @@ #include LUAU_FASTFLAG(LuauSolverV2); - +LUAU_FASTFLAGVARIABLE(LuauDisableNewSolverAssertsInMixedMode); // Maximum number of steps to follow when traversing a path. May not always // equate to the number of components in a path, depending on the traversal // logic. @@ -156,14 +156,16 @@ Path PathBuilder::build() PathBuilder& PathBuilder::readProp(std::string name) { - LUAU_ASSERT(FFlag::LuauSolverV2); + if (!FFlag::LuauDisableNewSolverAssertsInMixedMode) + LUAU_ASSERT(FFlag::LuauSolverV2); components.push_back(Property{std::move(name), true}); return *this; } PathBuilder& PathBuilder::writeProp(std::string name) { - LUAU_ASSERT(FFlag::LuauSolverV2); + if (!FFlag::LuauDisableNewSolverAssertsInMixedMode) + LUAU_ASSERT(FFlag::LuauSolverV2); components.push_back(Property{std::move(name), false}); return *this; } @@ -415,6 +417,14 @@ struct TraversalState switch (field) { + case TypePath::TypeField::Table: + if (auto mt = get(current)) + { + updateCurrent(mt->table); + return true; + } + + return false; case TypePath::TypeField::Metatable: if (auto currentType = get(current)) { @@ -561,6 +571,9 @@ std::string toString(const TypePath::Path& path, bool prefixDot) switch (c) { + case TypePath::TypeField::Table: + result << "table"; + break; case TypePath::TypeField::Metatable: result << "metatable"; break; diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index f1c60f06..bf8cf533 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,12 +5,16 @@ #include "Luau/Normalize.h" #include "Luau/Scope.h" #include "Luau/ToString.h" +#include "Luau/Type.h" #include "Luau/TypeInfer.h" #include LUAU_FASTFLAG(LuauSolverV2); - +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete); +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) +LUAU_FASTFLAG(LuauDisableNewSolverAssertsInMixedMode) namespace Luau { @@ -317,9 +321,11 @@ TypePack extendTypePack( { FreeType ft{ftp->scope, builtinTypes->neverType, builtinTypes->unknownType}; t = arena.addType(ft); + if (FFlag::LuauTrackInteriorFreeTypesOnScope) + trackInteriorFreeType(ftp->scope, t); } else - t = arena.freshType(ftp->scope); + t = FFlag::LuauFreeTypesMustHaveBounds ? arena.freshType(builtinTypes, ftp->scope) : arena.freshType_DEPRECATED(ftp->scope); } newPack.head.push_back(t); @@ -331,7 +337,7 @@ TypePack extendTypePack( return result; } - else if (const Unifiable::Error* etp = getMutable(pack)) + else if (auto etp = getMutable(pack)) { while (result.head.size() < length) result.head.push_back(builtinTypes->errorRecoveryType()); @@ -426,7 +432,7 @@ TypeId stripNil(NotNull builtinTypes, TypeArena& arena, TypeId ty) ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty) { - LUAU_ASSERT(FFlag::LuauSolverV2); + LUAU_ASSERT(FFlag::LuauSolverV2 || FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete); std::shared_ptr normType = normalizer->normalize(ty); if (!normType) @@ -479,4 +485,87 @@ ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId return result; } +bool isLiteral(const AstExpr* expr) +{ + return ( + expr->is() || expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() + ); +} +/** + * Visitor which, given an expression and a mapping from expression to TypeId, + * determines if there are any literal expressions that contain blocked types. + * This is used for bi-directional inference: we want to "apply" a type from + * a function argument or a type annotation to a literal. + */ +class BlockedTypeInLiteralVisitor : public AstVisitor +{ +public: + explicit BlockedTypeInLiteralVisitor(NotNull> astTypes, NotNull> toBlock) + : astTypes_{astTypes} + , toBlock_{toBlock} + { + } + bool visit(AstNode*) override + { + return false; + } + + bool visit(AstExpr* e) override + { + auto ty = astTypes_->find(e); + if (ty && (get(follow(*ty)) != nullptr)) + { + toBlock_->push_back(*ty); + } + return isLiteral(e) || e->is(); + } + +private: + NotNull> astTypes_; + NotNull> toBlock_; +}; + +std::vector findBlockedTypesIn(AstExprTable* expr, NotNull> astTypes) +{ + std::vector toBlock; + BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; + expr->visit(&v); + return toBlock; +} + +std::vector findBlockedArgTypesIn(AstExprCall* expr, NotNull> astTypes) +{ + std::vector toBlock; + BlockedTypeInLiteralVisitor v{astTypes, NotNull{&toBlock}}; + for (auto arg : expr->args) + { + if (isLiteral(arg) || arg->is()) + { + arg->visit(&v); + } + } + return toBlock; +} + +void trackInteriorFreeType(Scope* scope, TypeId ty) +{ + if (FFlag::LuauDisableNewSolverAssertsInMixedMode) + LUAU_ASSERT(FFlag::LuauTrackInteriorFreeTypesOnScope); + else + LUAU_ASSERT(FFlag::LuauSolverV2 && FFlag::LuauTrackInteriorFreeTypesOnScope); + for (; scope; scope = scope->parent.get()) + { + if (scope->interiorFreeTypes) + { + scope->interiorFreeTypes->push_back(ty); + return; + } + } + // There should at least be *one* generalization constraint per module + // where `interiorFreeTypes` is present, which would be the one made + // by ConstraintGenerator::visitModuleRoot. + LUAU_ASSERT(!"No scopes in parent chain had a present `interiorFreeTypes` member."); +} + } // namespace Luau diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index 2ceb97aa..d9f3947f 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" +#include "Luau/TypeFwd.h" +#include "Luau/TypePack.h" namespace Luau { @@ -13,12 +15,17 @@ int freshIndex() return ++nextIndex; } -Error::Error() +template +Error::Error() : index(++nextIndex) { } -int Error::nextIndex = 0; +template +int Error::nextIndex = 0; + +template struct Error; +template struct Error; } // namespace Unifiable } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index fa7ff876..47b7cc41 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,12 +17,12 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping, false) +LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping) +LUAU_FASTFLAGVARIABLE(LuauTransitiveSubtyping) LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) -LUAU_FASTFLAGVARIABLE(LuauUnifierShouldNotCopyError, false) -LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart, false) +LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering) +LUAU_FASTFLAGVARIABLE(LuauUnifierRecursionOnRestart) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) namespace Luau { @@ -33,38 +33,20 @@ struct PromoteTypeLevels final : TypeOnceVisitor const TypeArena* typeArena = nullptr; TypeLevel minLevel; - Scope* outerScope = nullptr; - bool useScopes; - - PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes) + PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) : log(log) , typeArena(typeArena) , minLevel(minLevel) - , outerScope(outerScope) - , useScopes(useScopes) { } template void promote(TID ty, T* t) { - if (useScopes && !t) - return; - LUAU_ASSERT(t); - if (useScopes) - { - if (subsumesStrict(outerScope, t->scope)) - log.changeScope(ty, NotNull{outerScope}); - } - else - { - if (minLevel.subsumesStrict(t->level)) - { - log.changeLevel(ty, minLevel); - } - } + if (minLevel.subsumesStrict(t->level)) + log.changeLevel(ty, minLevel); } bool visit(TypeId ty) override @@ -141,23 +123,23 @@ struct PromoteTypeLevels final : TypeOnceVisitor } }; -static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypeId ty) +static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (ty->owningArena != typeArena) return; - PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; + PromoteTypeLevels ptl{log, typeArena, minLevel}; ptl.traverse(ty); } -void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, Scope* outerScope, bool useScopes, TypePackId tp) +void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (tp->owningArena != typeArena) return; - PromoteTypeLevels ptl{log, typeArena, minLevel, outerScope, useScopes}; + PromoteTypeLevels ptl{log, typeArena, minLevel}; ptl.traverse(tp); } @@ -370,12 +352,9 @@ static std::optional> getTableMatchT } template -static bool subsumes(bool useScopes, TY_A* left, TY_B* right) +static bool subsumes(TY_A* left, TY_B* right) { - if (useScopes) - return subsumes(left->scope, right->scope); - else - return left->level.subsumes(right->level); + return left->level.subsumes(right->level); } TypeMismatch::Context Unifier::mismatchContext() @@ -464,7 +443,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool auto superFree = log.getMutable(superTy); auto subFree = log.getMutable(subTy); - if (superFree && subFree && subsumes(useNewSolver, superFree, subFree)) + if (superFree && subFree && subsumes(superFree, subFree)) { if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); @@ -475,7 +454,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - if (subsumes(useNewSolver, superFree, subFree)) + if (subsumes(superFree, subFree)) { log.changeLevel(subTy, superFree->level); } @@ -489,7 +468,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // Unification can't change the level of a generic. auto subGeneric = log.getMutable(subTy); - if (subGeneric && !subsumes(useNewSolver, subGeneric, superFree)) + if (subGeneric && !subsumes(subGeneric, superFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic subtype escaping scope"}); @@ -498,7 +477,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(superTy, subTy, /* reversed = */ true)) { - promoteTypeLevels(log, types, superFree->level, superFree->scope, useNewSolver, subTy); + promoteTypeLevels(log, types, superFree->level, subTy); Widen widen{types, builtinTypes}; log.replace(superTy, BoundType(widen(subTy))); @@ -515,7 +494,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto superGeneric = log.getMutable(superTy); - if (superGeneric && !subsumes(useNewSolver, superGeneric, subFree)) + if (superGeneric && !subsumes(superGeneric, subFree)) { // TODO: a more informative error message? CLI-39912 reportError(location, GenericError{"Generic supertype escaping scope"}); @@ -524,7 +503,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursCheck(subTy, superTy, /* reversed = */ false)) { - promoteTypeLevels(log, types, subFree->level, subFree->scope, useNewSolver, superTy); + promoteTypeLevels(log, types, subFree->level, superTy); log.replace(subTy, BoundType(superTy)); } @@ -536,7 +515,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool auto superGeneric = log.getMutable(superTy); auto subGeneric = log.getMutable(subTy); - if (superGeneric && subGeneric && subsumes(useNewSolver, superGeneric, subGeneric)) + if (superGeneric && subGeneric && subsumes(superGeneric, subGeneric)) { if (!occursCheck(subTy, superTy, /* reversed = */ false)) log.replace(subTy, BoundType(superTy)); @@ -750,25 +729,22 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ for (TypeId type : subUnion->options) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(type, superTy); - if (useNewSolver) - logs.push_back(std::move(innerState.log)); - - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) unificationTooComplex = e; - else if (innerState.failure) + else if (innerState->failure) { // If errors were suppressed, we store the log up, so we can commit it if no other option succeeds. - if (innerState.errors.empty()) - logs.push_back(std::move(innerState.log)); + if (innerState->errors.empty()) + logs.push_back(std::move(innerState->log)); // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' else if (!firstFailedOption && !isNil(type)) - firstFailedOption = {innerState.errors.front()}; + firstFailedOption = {innerState->errors.front()}; failed = true; - errorsSuppressed &= innerState.errors.empty(); + errorsSuppressed &= innerState->errors.empty(); } } @@ -863,26 +839,21 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp for (size_t i = 0; i < uv->options.size(); ++i) { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; - Unifier innerState = makeChildUnifier(); - innerState.normalize = false; - innerState.tryUnify_(subTy, type, isFunctionCall); + std::unique_ptr innerState = makeChildUnifier(); + innerState->normalize = false; + innerState->tryUnify_(subTy, type, isFunctionCall); - if (!innerState.failure) + if (!innerState->failure) { found = true; - if (useNewSolver) - logs.push_back(std::move(innerState.log)); - else - { - log.concat(std::move(innerState.log)); - break; - } + log.concat(std::move(innerState->log)); + break; } - else if (innerState.errors.empty()) + else if (innerState->errors.empty()) { errorsSuppressed = true; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { unificationTooComplex = e; } @@ -891,13 +862,10 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp failedOptionCount++; if (!failedOption) - failedOption = {innerState.errors.front()}; + failedOption = {innerState->errors.front()}; } } - if (useNewSolver) - log.concatAsUnion(combineLogsIntoUnion(std::move(logs)), NotNull{types}); - if (unificationTooComplex) { reportError(*unificationTooComplex); @@ -907,25 +875,25 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp // It is possible that T <: A | B even though T innerState = makeChildUnifier(); std::shared_ptr subNorm = normalizer->normalize(subTy); std::shared_ptr superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) return reportError(location, NormalizationTooComplex{}); else if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - innerState.tryUnifyNormalizedTypes( + innerState->tryUnifyNormalizedTypes( subTy, superTy, *subNorm, *superNorm, "None of the union options are compatible. For example:", *failedOption ); else - innerState.tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); + innerState->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the union options are compatible"); - if (!innerState.failure) - log.concat(std::move(innerState.log)); - else if (errorsSuppressed || innerState.errors.empty()) + if (!innerState->failure) + log.concat(std::move(innerState->log)); + else if (errorsSuppressed || innerState->errors.empty()) failure = true; else - reportError(std::move(innerState.errors.front())); + reportError(std::move(innerState->errors.front())); } else if (!found && normalize) { @@ -964,27 +932,21 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I // T <: A & B if and only if T <: A and T <: B for (TypeId type : uv->parts) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) unificationTooComplex = e; - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) { if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; + firstFailedOption = {innerState->errors.front()}; } - if (useNewSolver) - logs.push_back(std::move(innerState.log)); - else - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } - if (useNewSolver) - log.concat(combineLogsIntoIntersection(std::move(logs))); - if (unificationTooComplex) reportError(*unificationTooComplex); else if (firstFailedOption) @@ -1032,62 +994,38 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* } } - if (useNewSolver && normalize) - { - // Sometimes a negation type is inside one of the types, e.g. { p: number } & { p: ~number }. - NegationTypeFinder finder; - finder.traverse(subTy); - - if (finder.found) - { - // It is possible that A & B <: T even though A subNorm = normalizer->normalize(subTy); - std::shared_ptr superNorm = normalizer->normalize(superTy); - if (subNorm && superNorm) - tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, "none of the intersection parts are compatible"); - else - reportError(location, NormalizationTooComplex{}); - - return; - } - } - std::vector logs; for (size_t i = 0; i < uv->parts.size(); ++i) { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; - Unifier innerState = makeChildUnifier(); - innerState.normalize = false; - innerState.tryUnify_(type, superTy, isFunctionCall); + std::unique_ptr innerState = makeChildUnifier(); + innerState->normalize = false; + innerState->tryUnify_(type, superTy, isFunctionCall); // TODO: This sets errorSuppressed to true if any of the parts is error-suppressing, // in paricular any & T is error-suppressing. Really, errorSuppressed should be true if // all of the parts are error-suppressing, but that fails to typecheck lua-apps. - if (innerState.errors.empty()) + if (innerState->errors.empty()) { found = true; - errorsSuppressed = innerState.failure; - if (useNewSolver || innerState.failure) - logs.push_back(std::move(innerState.log)); + errorsSuppressed = innerState->failure; + if (innerState->failure) + logs.push_back(std::move(innerState->log)); else { errorsSuppressed = false; - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { unificationTooComplex = e; } } - if (useNewSolver) - log.concat(combineLogsIntoIntersection(std::move(logs))); - else if (errorsSuppressed) + if (errorsSuppressed) log.concat(std::move(logs.front())); if (unificationTooComplex) @@ -1201,24 +1139,6 @@ void Unifier::tryUnifyNormalizedTypes( } } - if (useNewSolver) - { - for (TypeId superTable : superNorm.tables) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify(subClass, superTable); - - if (innerState.errors.empty()) - { - found = true; - log.concat(std::move(innerState.log)); - break; - } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - return reportError(*e); - } - } - if (!found) { return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); @@ -1236,17 +1156,17 @@ void Unifier::tryUnifyNormalizedTypes( break; } - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.tryUnify(subTable, superTable); + innerState->tryUnify(subTable, superTable); - if (innerState.errors.empty()) + if (innerState->errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); } if (!found) @@ -1259,15 +1179,15 @@ void Unifier::tryUnifyNormalizedTypes( return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); for (TypeId superFun : superNorm.functions.parts) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); const FunctionType* superFtv = get(superFun); if (!superFtv) return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); - TypePackId tgt = innerState.tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); - innerState.tryUnify_(tgt, superFtv->retTypes); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - else if (auto e = hasUnificationTooComplex(innerState.errors)) + TypePackId tgt = innerState->tryApplyOverloadedFunction(subTy, subNorm.functions, superFtv->argTypes); + innerState->tryUnify_(tgt, superFtv->retTypes); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + else if (auto e = hasUnificationTooComplex(innerState->errors)) return reportError(*e); else return reportError(location, TypeMismatch{superTy, subTy, reason, error, mismatchContext()}); @@ -1305,17 +1225,17 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized { if (!firstFun) firstFun = ftv; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(args, ftv->argTypes); - if (innerState.errors.empty()) + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(args, ftv->argTypes); + if (innerState->errors.empty()) { - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); if (result) { - innerState.log.clear(); - innerState.tryUnify_(*result, ftv->retTypes); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + innerState->log.clear(); + innerState->tryUnify_(*result, ftv->retTypes); + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); // Annoyingly, since we don't support intersection of generic type packs, // the intersection may fail. We rather arbitrarily use the first matching overload // in that case. @@ -1325,7 +1245,7 @@ TypePackId Unifier::tryApplyOverloadedFunction(TypeId function, const Normalized else result = ftv->retTypes; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) + else if (auto e = hasUnificationTooComplex(innerState->errors)) { reportError(*e); return builtinTypes->errorRecoveryTypePack(args); @@ -1503,26 +1423,20 @@ struct WeirdIter } }; -void Unifier::enableNewSolver() -{ - useNewSolver = true; - log.useScopes = true; -} - ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { - Unifier s = makeChildUnifier(); - s.tryUnify_(subTy, superTy); + std::unique_ptr s = makeChildUnifier(); + s->tryUnify_(subTy, superTy); - return s.errors; + return s->errors; } ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { - Unifier s = makeChildUnifier(); - s.tryUnify_(subTy, superTy, isFunctionCall); + std::unique_ptr s = makeChildUnifier(); + s->tryUnify_(subTy, superTy, isFunctionCall); - return s.errors; + return s->errors; } void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) @@ -1588,8 +1502,6 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (!occursCheck(superTp, subTp, /* reversed = */ true)) { Widen widen{types, builtinTypes}; - if (useNewSolver) - promoteTypeLevels(log, types, superFree->level, superFree->scope, /*useScopes*/ true, subTp); log.replace(superTp, Unifiable::Bound(widen(subTp))); } } @@ -1597,8 +1509,6 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { if (!occursCheck(subTp, superTp, /* reversed = */ false)) { - if (useNewSolver) - promoteTypeLevels(log, types, subFree->level, subFree->scope, /*useScopes*/ true, superTp); log.replace(subTp, Unifiable::Bound(superTp)); } } @@ -1617,9 +1527,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal log.replace(subTp, Unifiable::Bound(superTp)); } } - else if (log.getMutable(superTp)) + else if (log.getMutable(superTp)) tryUnifyWithAny(subTp, superTp); - else if (log.getMutable(subTp)) + else if (log.getMutable(subTp)) tryUnifyWithAny(superTp, subTp); else if (log.getMutable(superTp)) tryUnifyVariadics(subTp, superTp, false); @@ -1649,7 +1559,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (FFlag::LuauSolverV2) return freshType(NotNull{types}, builtinTypes, scope); else - return types->freshType(scope, level); + return FFlag::LuauFreeTypesMustHaveBounds ? types->freshType(builtinTypes, scope, level) : types->freshType_DEPRECATED(scope, level); }; const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); @@ -1688,74 +1598,28 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (useNewSolver) + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) { - if (subIter.tail() && superIter.tail()) - tryUnify_(*subIter.tail(), *superIter.tail()); - else if (subIter.tail()) - { - const TypePackId subTail = log.follow(*subIter.tail()); - - if (log.get(subTail)) - tryUnify_(subTail, emptyTp); - else if (log.get(subTail)) - reportError(location, TypePackMismatch{subTail, emptyTp}); - else if (log.get(subTail) || log.get(subTail)) - { - // Nothing. This is ok. - } - else - { - ice("Unexpected subtype tail pack " + toString(subTail), location); - } - } - else if (superIter.tail()) - { - const TypePackId superTail = log.follow(*superIter.tail()); - - if (log.get(superTail)) - tryUnify_(emptyTp, superTail); - else if (log.get(superTail)) - reportError(location, TypePackMismatch{emptyTp, superTail}); - else if (log.get(superTail) || log.get(superTail)) - { - // Nothing. This is ok. - } - else - { - ice("Unexpected supertype tail pack " + toString(superTail), location); - } - } - else - { - // Nothing. This is ok. - } + tryUnify_(*subTpv->tail, *superTpv->tail); } - else + else if (lFreeTail) { - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) - { + tryUnify_(emptyTp, *superTpv->tail); + } + else if (rFreeTail) + { + tryUnify_(emptyTp, *subTpv->tail); + } + else if (subTpv->tail && superTpv->tail) + { + if (log.getMutable(superIter.packId)) + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + else if (log.getMutable(subIter.packId)) + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + else tryUnify_(*subTpv->tail, *superTpv->tail); - } - else if (lFreeTail) - { - tryUnify_(emptyTp, *superTpv->tail); - } - else if (rFreeTail) - { - tryUnify_(emptyTp, *subTpv->tail); - } - else if (subTpv->tail && superTpv->tail) - { - if (log.getMutable(superIter.packId)) - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - else if (log.getMutable(subIter.packId)) - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - else - tryUnify_(*subTpv->tail, *superTpv->tail); - } } break; @@ -1885,9 +1749,9 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal // generic methods in tables to be marked read-only. if (FFlag::LuauInstantiateInSubtyping && shouldInstantiate) { - Instantiation instantiation{&log, types, builtinTypes, scope->level, scope}; + std::unique_ptr instantiation = std::make_unique(&log, types, builtinTypes, scope->level, scope); - std::optional instantiated = instantiation.substitute(subTy); + std::optional instantiated = instantiation->substitute(subTy); if (instantiated.has_value()) { subFunction = log.getMutable(*instantiated); @@ -1931,54 +1795,54 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (!isFunctionCall) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + innerState->ctx = CountMismatch::Arg; + innerState->tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - bool reported = !innerState.errors.empty(); + bool reported = !innerState->errors.empty(); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + else if (!innerState->errors.empty() && innerState->firstPackErrorPos) reportError( location, TypeMismatch{ superTy, subTy, - format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), + format("Argument #%d type is not compatible.", *innerState->firstPackErrorPos), + innerState->errors.front(), mismatchContext() } ); - else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + else if (!innerState->errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()}); - innerState.ctx = CountMismatch::FunctionResult; - innerState.tryUnify_(subFunction->retTypes, superFunction->retTypes); + innerState->ctx = CountMismatch::FunctionResult; + innerState->tryUnify_(subFunction->retTypes, superFunction->retTypes); if (!reported) { - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) - reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front(), mismatchContext()}); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + else if (!innerState->errors.empty() && size(superFunction->retTypes) == 1 && finite(superFunction->retTypes)) + reportError(location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState->errors.front(), mismatchContext()}); + else if (!innerState->errors.empty() && innerState->firstPackErrorPos) reportError( location, TypeMismatch{ superTy, subTy, - format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front(), + format("Return #%d type is not compatible.", *innerState->firstPackErrorPos), + innerState->errors.front(), mismatchContext() } ); - else if (!innerState.errors.empty()) - reportError(location, TypeMismatch{superTy, subTy, "", innerState.errors.front(), mismatchContext()}); + else if (!innerState->errors.empty()) + reportError(location, TypeMismatch{superTy, subTy, "", innerState->errors.front(), mismatchContext()}); } - log.concat(std::move(innerState.log)); + log.concat(std::move(innerState->log)); } else { @@ -2116,14 +1980,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (!literalProperties || !literalProperties->contains(name)) variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(r->second.type(), prop.type()); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(r->second.type(), prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { @@ -2133,14 +1997,14 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (!literalProperties || !literalProperties->contains(name)) variance = Invariant; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexResultType, prop.type()); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subTable->indexer->indexResultType, prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (subTable->state == TableState::Unsealed && isOptional(prop.type())) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` @@ -2211,20 +2075,20 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (!literalProperties || !literalProperties->contains(name)) variance = Invariant; - Unifier innerState = makeChildUnifier(); - if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) - innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); + std::unique_ptr innerState = makeChildUnifier(); + if (FFlag::LuauFixIndexerSubtypingOrdering) + innerState->tryUnify_(prop.type(), superTable->indexer->indexResultType); else { // Incredibly, the old solver depends on this bug somehow. - innerState.tryUnify_(superTable->indexer->indexResultType, prop.type()); + innerState->tryUnify_(superTable->indexer->indexResultType, prop.type()); } - checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (superTable->state == TableState::Unsealed) { @@ -2295,22 +2159,22 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, Resetter resetter{&variance}; variance = Invariant; - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); - innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + innerState->tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); - bool reported = !innerState.errors.empty(); + bool reported = !innerState->errors.empty(); - checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, "[indexer key]", superTy, subTy); - innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + innerState->tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); if (!reported) - checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); + checkChildUnifierTypeMismatch(innerState->errors, "[indexer value]", superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + if (innerState->errors.empty()) + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (superTable->indexer) { @@ -2409,13 +2273,13 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) if (auto it = mttv->props.find("__index"); it != mttv->props.end()) { TypeId ty = it->second.type(); - Unifier child = makeChildUnifier(); - child.tryUnify_(ty, superTy); + std::unique_ptr child = makeChildUnifier(); + child->tryUnify_(ty, superTy); // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check - TypeId newSuperTy = child.log.follow(superTy); + TypeId newSuperTy = child->log.follow(superTy); if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) { @@ -2423,16 +2287,16 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) return; } - if (auto e = hasUnificationTooComplex(child.errors)) + if (auto e = hasUnificationTooComplex(child->errors)) reportError(*e); - else if (!child.errors.empty()) - fail(child.errors.front()); + else if (!child->errors.empty()) + fail(child->errors.front()); - log.concat(std::move(child.log)); + log.concat(std::move(child->log)); // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype - if (child.errors.empty()) + if (child->errors.empty()) log.replace(superTy, BoundType{subTy}); return; @@ -2477,19 +2341,19 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (const MetatableType* subMetatable = log.getMutable(subTy)) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subMetatable->table, superMetatable->table); - innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(subMetatable->table, superMetatable->table); + innerState->tryUnify_(subMetatable->metatable, superMetatable->metatable); - if (auto e = hasUnificationTooComplex(innerState.errors)) + if (auto e = hasUnificationTooComplex(innerState->errors)) reportError(*e); - else if (!innerState.errors.empty()) + else if (!innerState->errors.empty()) reportError( - location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} + location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState->errors.front(), mismatchContext()} ); - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else if (TableType* subTable = log.getMutable(subTy)) { @@ -2497,49 +2361,8 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { case TableState::Free: { - if (useNewSolver) - { - Unifier innerState = makeChildUnifier(); - bool missingProperty = false; - - for (const auto& [propName, prop] : subTable->props) - { - if (std::optional mtPropTy = findTablePropertyRespectingMeta(superTy, propName)) - { - innerState.tryUnify(prop.type(), *mtPropTy); - } - else - { - reportError(mismatchError); - missingProperty = true; - break; - } - } - - if (const TableType* superTable = log.get(log.follow(superMetatable->table))) - { - // TODO: Unify indexers. - } - - if (auto e = hasUnificationTooComplex(innerState.errors)) - reportError(*e); - else if (!innerState.errors.empty()) - reportError(TypeError{ - location, - TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front(), mismatchContext()} - }); - else if (!missingProperty) - { - log.concat(std::move(innerState.log)); - log.bindTable(subTy, superTy); - failure |= innerState.failure; - } - } - else - { - tryUnify_(subTy, superMetatable->table); - log.bindTable(subTy, superTy); - } + tryUnify_(subTy, superMetatable->table); + log.bindTable(subTy, superTy); break; } @@ -2619,15 +2442,15 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) } else { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(classProp->type(), prop.type()); + std::unique_ptr innerState = makeChildUnifier(); + innerState->tryUnify_(classProp->type(), prop.type()); - checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + checkChildUnifierTypeMismatch(innerState->errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) + if (innerState->errors.empty()) { - log.concat(std::move(innerState.log)); - failure |= innerState.failure; + log.concat(std::move(innerState->log)); + failure |= innerState->failure; } else { @@ -2663,9 +2486,9 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) return reportError(location, NormalizationTooComplex{}); // T state = makeChildUnifier(); + state->tryUnifyNormalizedTypes(subTy, superTy, *subNorm, *superNorm, ""); + if (state->errors.empty()) reportError(location, TypeMismatch{superTy, subTy, mismatchContext()}); } @@ -2742,7 +2565,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever else log.replace(tail, BoundTypePack{superTp}); } - else if (get(tail)) + else if (get(tail)) { // Nothing to do here. } @@ -2846,7 +2669,7 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { - LUAU_ASSERT(get(anyTp)); + LUAU_ASSERT(get(anyTp)); const TypeId anyTy = builtinTypes->errorRecoveryType(); @@ -2865,18 +2688,9 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(builtinTypes, errors, lhsType, name, location); } -TxnLog Unifier::combineLogsIntoIntersection(std::vector logs) -{ - LUAU_ASSERT(useNewSolver); - TxnLog result(useNewSolver); - for (TxnLog& log : logs) - result.concatAsIntersections(std::move(log), NotNull{types}); - return result; -} - TxnLog Unifier::combineLogsIntoUnion(std::vector logs) { - TxnLog result(useNewSolver); + TxnLog result; for (TxnLog& log : logs) result.concatAsUnion(std::move(log), NotNull{types}); return result; @@ -2890,27 +2704,27 @@ bool Unifier::occursCheck(TypeId needle, TypeId haystack, bool reversed) if (occurs) { - Unifier innerState = makeChildUnifier(); + std::unique_ptr innerState = makeChildUnifier(); if (const UnionType* ut = get(haystack)) { if (reversed) - innerState.tryUnifyUnionWithType(haystack, ut, needle); + innerState->tryUnifyUnionWithType(haystack, ut, needle); else - innerState.tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); + innerState->tryUnifyTypeWithUnion(needle, haystack, ut, /* cacheEnabled = */ false, /* isFunction = */ false); } else if (const IntersectionType* it = get(haystack)) { if (reversed) - innerState.tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); + innerState->tryUnifyIntersectionWithType(haystack, it, needle, /* cacheEnabled = */ false, /* isFunction = */ false); else - innerState.tryUnifyTypeWithIntersection(needle, haystack, it); + innerState->tryUnifyTypeWithIntersection(needle, haystack, it); } else { - innerState.failure = true; + innerState->failure = true; } - if (innerState.failure) + if (innerState->failure) { reportError(location, OccursCheckFailed{}); log.replace(needle, BoundType{builtinTypes->errorRecoveryType()}); @@ -2974,10 +2788,7 @@ bool Unifier::occursCheck(TypePackId needle, TypePackId haystack, bool reversed) if (occurs) { reportError(location, OccursCheckFailed{}); - if (FFlag::LuauUnifierShouldNotCopyError) - log.replace(needle, BoundTypePack{builtinTypes->errorRecoveryTypePack()}); - else - log.replace(needle, *builtinTypes->errorRecoveryTypePack()); + log.replace(needle, BoundTypePack{builtinTypes->errorRecoveryTypePack()}); } return occurs; @@ -3001,7 +2812,7 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - while (!log.getMutable(haystack)) + while (!log.getMutable(haystack)) { if (needle == haystack) return true; @@ -3018,14 +2829,11 @@ bool Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ return false; } -Unifier Unifier::makeChildUnifier() +std::unique_ptr Unifier::makeChildUnifier() { - Unifier u = Unifier{normalizer, scope, location, variance, &log}; - u.normalize = normalize; - u.checkInhabited = checkInhabited; - - if (useNewSolver) - u.enableNewSolver(); + std::unique_ptr u = std::make_unique(normalizer, scope, location, variance, &log); + u->normalize = normalize; + u->checkInhabited = checkInhabited; return u; } diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 5ea11ad0..e63856d3 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -908,7 +908,7 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePack RecursionLimiter _ra(&recursionCount, recursionLimit); - while (!getMutable(haystack)) + while (!getMutable(haystack)) { if (needle == haystack) return OccursCheckResult::Fail; diff --git a/Ast/include/Luau/Allocator.h b/Ast/include/Luau/Allocator.h new file mode 100644 index 00000000..bd7d423f --- /dev/null +++ b/Ast/include/Luau/Allocator.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Location.h" +#include "Luau/DenseHash.h" +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +class Allocator +{ +public: + Allocator(); + Allocator(Allocator&&); + + Allocator& operator=(Allocator&&) = delete; + + ~Allocator(); + + void* allocate(size_t size); + + template + T* alloc(Args&&... args) + { + static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); + + T* t = static_cast(allocate(sizeof(T))); + new (t) T(std::forward(args)...); + return t; + } + +private: + struct Page + { + Page* next; + + alignas(8) char data[8192]; + }; + + Page* root; + size_t offset; +}; + +} // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 099ece2b..34f0072e 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -120,20 +120,6 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName -struct AstGenericType -{ - AstName name; - Location location; - AstType* defaultValue = nullptr; -}; - -struct AstGenericTypePack -{ - AstName name; - Location location; - AstTypePack* defaultValue = nullptr; -}; - extern int gAstRttiIndex; template @@ -253,6 +239,32 @@ public: bool hasSemicolon; }; +class AstGenericType : public AstNode +{ +public: + LUAU_RTTI(AstGenericType) + + explicit AstGenericType(const Location& location, AstName name, AstType* defaultValue = nullptr); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstType* defaultValue = nullptr; +}; + +class AstGenericTypePack : public AstNode +{ +public: + LUAU_RTTI(AstGenericTypePack) + + explicit AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue = nullptr); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstTypePack* defaultValue = nullptr; +}; + class AstExprGroup : public AstExpr { public: @@ -316,16 +328,18 @@ public: enum QuoteStyle { - Quoted, + QuotedSimple, + QuotedRaw, Unquoted }; - AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle); void visit(AstVisitor* visitor) override; + bool isQuoted() const; AstArray value; - QuoteStyle quoteStyle = Quoted; + QuoteStyle quoteStyle; }; class AstExprLocal : public AstExpr @@ -422,8 +436,8 @@ public: AstExprFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, @@ -441,8 +455,8 @@ public: bool hasNativeAttribute() const; AstArray attributes; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; std::optional returnAnnotation; @@ -855,8 +869,8 @@ public: const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported ); @@ -865,8 +879,8 @@ public: AstName name; Location nameLocation; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -876,13 +890,14 @@ class AstStatTypeFunction : public AstStat public: LUAU_RTTI(AstStatTypeFunction); - AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body); + AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body, bool exported); void visit(AstVisitor* visitor) override; AstName name; Location nameLocation; AstExprFunction* body; + bool exported; }; class AstStatDeclareGlobal : public AstStat @@ -908,8 +923,8 @@ public: const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -922,8 +937,8 @@ public: const AstArray& attributes, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -939,8 +954,8 @@ public: AstArray attributes; AstName name; Location nameLocation; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; bool vararg = false; @@ -1071,8 +1086,8 @@ public: AstTypeFunction( const Location& location, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -1081,8 +1096,8 @@ public: AstTypeFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -1093,8 +1108,8 @@ public: bool isCheckedFunction() const; AstArray attributes; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; @@ -1201,6 +1216,18 @@ public: const AstArray value; }; +class AstTypeGroup : public AstType +{ +public: + LUAU_RTTI(AstTypeGroup) + + explicit AstTypeGroup(const Location& location, AstType* type); + + void visit(AstVisitor* visitor) override; + + AstType* type; +}; + class AstTypePack : public AstNode { public: @@ -1261,6 +1288,16 @@ public: return visit(static_cast(node)); } + virtual bool visit(class AstGenericType* node) + { + return visit(static_cast(node)); + } + + virtual bool visit(class AstGenericTypePack* node) + { + return visit(static_cast(node)); + } + virtual bool visit(class AstExpr* node) { return visit(static_cast(node)); @@ -1467,6 +1504,10 @@ public: { return visit(static_cast(node)); } + virtual bool visit(class AstTypeGroup* node) + { + return visit(static_cast(node)); + } virtual bool visit(class AstTypeError* node) { return visit(static_cast(node)); @@ -1490,6 +1531,7 @@ public: } }; +bool isLValue(const AstExpr*); AstName getIdentifier(AstExpr*); Location getLocation(const AstTypeList& typeList); @@ -1520,4 +1562,4 @@ struct hash } }; -} // namespace std \ No newline at end of file +} // namespace std diff --git a/Ast/include/Luau/Cst.h b/Ast/include/Luau/Cst.h new file mode 100644 index 00000000..95211f14 --- /dev/null +++ b/Ast/include/Luau/Cst.h @@ -0,0 +1,385 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" + +#include + +namespace Luau +{ + +extern int gCstRttiIndex; + +template +struct CstRtti +{ + static const int value; +}; + +template +const int CstRtti::value = ++gCstRttiIndex; + +#define LUAU_CST_RTTI(Class) \ + static int CstClassIndex() \ + { \ + return CstRtti::value; \ + } + +class CstNode +{ +public: + explicit CstNode(int classIndex) + : classIndex(classIndex) + { + } + + template + bool is() const + { + return classIndex == T::CstClassIndex(); + } + template + T* as() + { + return classIndex == T::CstClassIndex() ? static_cast(this) : nullptr; + } + template + const T* as() const + { + return classIndex == T::CstClassIndex() ? static_cast(this) : nullptr; + } + + const int classIndex; +}; + +class CstExprConstantNumber : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprConstantNumber) + + explicit CstExprConstantNumber(const AstArray& value); + + AstArray value; +}; + +class CstExprConstantString : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprConstantNumber) + + enum QuoteStyle + { + QuotedSingle, + QuotedDouble, + QuotedRaw, + QuotedInterp, + }; + + CstExprConstantString(AstArray sourceString, QuoteStyle quoteStyle, unsigned int blockDepth); + + AstArray sourceString; + QuoteStyle quoteStyle; + unsigned int blockDepth; +}; + +class CstExprCall : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprCall) + + CstExprCall(std::optional openParens, std::optional closeParens, AstArray commaPositions); + + std::optional openParens; + std::optional closeParens; + AstArray commaPositions; +}; + +class CstExprIndexExpr : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprIndexExpr) + + CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition); + + Position openBracketPosition; + Position closeBracketPosition; +}; + +class CstExprTable : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprTable) + + enum Separator + { + Comma, + Semicolon, + }; + + struct Item + { + std::optional indexerOpenPosition; // '[', only if Kind == General + std::optional indexerClosePosition; // ']', only if Kind == General + std::optional equalsPosition; // only if Kind != List + std::optional separator; // may be missing for last Item + std::optional separatorPosition; + }; + + explicit CstExprTable(const AstArray& items); + + AstArray items; +}; + +// TODO: Shared between unary and binary, should we split? +class CstExprOp : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprOp) + + explicit CstExprOp(Position opPosition); + + Position opPosition; +}; + +class CstExprTypeAssertion : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprTypeAssertion) + + explicit CstExprTypeAssertion(Position opPosition); + + Position opPosition; +}; + +class CstExprIfElse : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprIfElse) + + CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf); + + Position thenPosition; + Position elsePosition; + bool isElseIf; +}; + +class CstExprInterpString : public CstNode +{ +public: + LUAU_CST_RTTI(CstExprInterpString) + + explicit CstExprInterpString(AstArray> sourceStrings, AstArray stringPositions); + + AstArray> sourceStrings; + AstArray stringPositions; +}; + +class CstStatDo : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatDo) + + explicit CstStatDo(Position endPosition); + + Position endPosition; +}; + +class CstStatRepeat : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatRepeat) + + explicit CstStatRepeat(Position untilPosition); + + Position untilPosition; +}; + +class CstStatReturn : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatReturn) + + explicit CstStatReturn(AstArray commaPositions); + + AstArray commaPositions; +}; + +class CstStatLocal : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatLocal) + + CstStatLocal(AstArray varsCommaPositions, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + AstArray valuesCommaPositions; +}; + +class CstStatFor : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatFor) + + CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional stepCommaPosition); + + Position equalsPosition; + Position endCommaPosition; + std::optional stepCommaPosition; +}; + +class CstStatForIn : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatForIn) + + CstStatForIn(AstArray varsCommaPositions, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + AstArray valuesCommaPositions; +}; + +class CstStatAssign : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatAssign) + + CstStatAssign(AstArray varsCommaPositions, Position equalsPosition, AstArray valuesCommaPositions); + + AstArray varsCommaPositions; + Position equalsPosition; + AstArray valuesCommaPositions; +}; + +class CstStatCompoundAssign : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatCompoundAssign) + + explicit CstStatCompoundAssign(Position opPosition); + + Position opPosition; +}; + +class CstStatLocalFunction : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatLocalFunction) + + explicit CstStatLocalFunction(Position functionKeywordPosition); + + Position functionKeywordPosition; +}; + +class CstGenericType : public CstNode +{ +public: + LUAU_CST_RTTI(CstGenericType) + + CstGenericType(std::optional defaultEqualsPosition); + + std::optional defaultEqualsPosition; +}; + +class CstGenericTypePack : public CstNode +{ +public: + LUAU_CST_RTTI(CstGenericTypePack) + + CstGenericTypePack(Position ellipsisPosition, std::optional defaultEqualsPosition); + + Position ellipsisPosition; + std::optional defaultEqualsPosition; +}; + +class CstStatTypeAlias : public CstNode +{ +public: + LUAU_CST_RTTI(CstStatTypeAlias) + + CstStatTypeAlias( + Position typeKeywordPosition, + Position genericsOpenPosition, + AstArray genericsCommaPositions, + Position genericsClosePosition, + Position equalsPosition + ); + + Position typeKeywordPosition; + Position genericsOpenPosition; + AstArray genericsCommaPositions; + Position genericsClosePosition; + Position equalsPosition; +}; + +class CstTypeReference : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeReference) + + CstTypeReference( + std::optional prefixPointPosition, + Position openParametersPosition, + AstArray parametersCommaPositions, + Position closeParametersPosition + ); + + std::optional prefixPointPosition; + Position openParametersPosition; + AstArray parametersCommaPositions; + Position closeParametersPosition; +}; + +class CstTypeTable : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeTable) + + struct Item + { + enum struct Kind + { + Indexer, + Property, + StringProperty, + }; + + Kind kind; + Position indexerOpenPosition; // '[', only if Kind != Property + Position indexerClosePosition; // ']' only if Kind != Property + Position colonPosition; + std::optional separator; // may be missing for last Item + std::optional separatorPosition; + + CstExprConstantString* stringInfo = nullptr; // only if Kind == StringProperty + }; + + CstTypeTable(AstArray items, bool isArray); + + AstArray items; + bool isArray = false; +}; + +class CstTypeTypeof : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeTypeof) + + CstTypeTypeof(Position openPosition, Position closePosition); + + Position openPosition; + Position closePosition; +}; + +class CstTypeSingletonString : public CstNode +{ +public: + LUAU_CST_RTTI(CstTypeSingletonString) + + CstTypeSingletonString(AstArray sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth); + + AstArray sourceString; + CstExprConstantString::QuoteStyle quoteStyle; + unsigned int blockDepth; +}; + +} // namespace Luau \ No newline at end of file diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index f6ac28ad..3570a35c 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Allocator.h" #include "Luau/Ast.h" #include "Luau/Location.h" #include "Luau/DenseHash.h" @@ -11,40 +12,6 @@ namespace Luau { -class Allocator -{ -public: - Allocator(); - Allocator(Allocator&&); - - Allocator& operator=(Allocator&&) = delete; - - ~Allocator(); - - void* allocate(size_t size); - - template - T* alloc(Args&&... args) - { - static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); - - T* t = static_cast(allocate(sizeof(T))); - new (t) T(std::forward(args)...); - return t; - } - -private: - struct Page - { - Page* next; - - char data[8192]; - }; - - Page* root; - size_t offset; -}; - struct Lexeme { enum Type @@ -120,6 +87,12 @@ struct Lexeme Reserved_END }; + enum struct QuoteStyle + { + Single, + Double, + }; + Type type; Location location; @@ -144,6 +117,8 @@ public: Lexeme(const Location& location, Type type, const char* name); unsigned int getLength() const; + unsigned int getBlockDepth() const; + QuoteStyle getQuoteStyle() const; std::string toString() const; }; @@ -186,7 +161,7 @@ private: class Lexer { public: - Lexer(const char* buffer, std::size_t bufferSize, AstNameTable& names); + Lexer(const char* buffer, std::size_t bufferSize, AstNameTable& names, Position startPosition = {0, 0}); void setSkipComments(bool skip); void setReadNames(bool read); @@ -212,6 +187,11 @@ public: static bool fixupQuotedString(std::string& data); static void fixupMultilineString(std::string& data); + unsigned int getOffset() const + { + return offset; + } + private: char peekch() const; char peekch(unsigned int lookahead) const; diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h index 3fc8921a..95d4c78a 100644 --- a/Ast/include/Luau/Location.h +++ b/Ast/include/Luau/Location.h @@ -14,12 +14,37 @@ struct Position { } - bool operator==(const Position& rhs) const; - bool operator!=(const Position& rhs) const; - bool operator<(const Position& rhs) const; - bool operator>(const Position& rhs) const; - bool operator<=(const Position& rhs) const; - bool operator>=(const Position& rhs) const; + bool operator==(const Position& rhs) const + { + return this->column == rhs.column && this->line == rhs.line; + } + + bool operator!=(const Position& rhs) const + { + return !(*this == rhs); + } + bool operator<(const Position& rhs) const + { + if (line == rhs.line) + return column < rhs.column; + else + return line < rhs.line; + } + bool operator>(const Position& rhs) const + { + if (line == rhs.line) + return column > rhs.column; + else + return line > rhs.line; + } + bool operator<=(const Position& rhs) const + { + return *this == rhs || *this < rhs; + } + bool operator>=(const Position& rhs) const + { + return *this == rhs || *this > rhs; + } void shift(const Position& start, const Position& oldEnd, const Position& newEnd); }; @@ -52,8 +77,14 @@ struct Location { } - bool operator==(const Location& rhs) const; - bool operator!=(const Location& rhs) const; + bool operator==(const Location& rhs) const + { + return this->begin == rhs.begin && this->end == rhs.end; + } + bool operator!=(const Location& rhs) const + { + return !(*this == rhs); + } bool encloses(const Location& l) const; bool overlaps(const Location& l) const; diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h index 01f2a74f..ac8e9348 100644 --- a/Ast/include/Luau/ParseOptions.h +++ b/Ast/include/Luau/ParseOptions.h @@ -1,6 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +#include + namespace Luau { @@ -12,10 +17,20 @@ enum class Mode Definition, // Type definition module, has special parsing rules }; +struct FragmentParseResumeSettings +{ + DenseHashMap localMap{AstName()}; + std::vector localStack; + Position resumePosition; +}; + struct ParseOptions { bool allowDeclarationSyntax = false; bool captureComments = false; + std::optional parseFragment = std::nullopt; + bool storeCstData = false; + bool noErrorLimit = false; }; } // namespace Luau diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h index 9c0a9527..7803dc55 100644 --- a/Ast/include/Luau/ParseResult.h +++ b/Ast/include/Luau/ParseResult.h @@ -10,6 +10,7 @@ namespace Luau { class AstStatBlock; +class CstNode; class ParseError : public std::exception { @@ -55,6 +56,8 @@ struct Comment Location location; }; +using CstNodeMap = DenseHashMap; + struct ParseResult { AstStatBlock* root; @@ -64,6 +67,21 @@ struct ParseResult std::vector errors; std::vector commentLocations; + + CstNodeMap cstNodeMap{nullptr}; +}; + +struct ParseExprResult +{ + AstExpr* expr; + size_t lines = 0; + + std::vector hotcomments; + std::vector errors; + + std::vector commentLocations; + + CstNodeMap cstNodeMap{nullptr}; }; static constexpr const char* kParseNameError = "%error-id%"; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 4e49028a..cfe7d08c 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -8,6 +8,7 @@ #include "Luau/StringUtils.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include "Luau/Cst.h" #include #include @@ -62,6 +63,14 @@ public: ParseOptions options = ParseOptions() ); + static ParseExprResult parseExpr( + const char* buffer, + std::size_t bufferSize, + AstNameTable& names, + Allocator& allocator, + ParseOptions options = ParseOptions() + ); + private: struct Name; struct Binding; @@ -116,7 +125,7 @@ private: AstStat* parseFor(); // funcname ::= Name {`.' Name} [`:' Name] - AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); + AstExpr* parseFunctionName(Location start_DEPRECATED, bool& hasself, AstName& debugname); // function funcname funcbody LUAU_FORCEINLINE AstStat* parseFunctionStat(const AstArray& attributes = {nullptr, 0}); @@ -143,10 +152,10 @@ private: AstStat* parseReturn(); // type Name `=' Type - AstStat* parseTypeAlias(const Location& start, bool exported); + AstStat* parseTypeAlias(const Location& start, bool exported, Position typeKeywordPosition); // type function Name ... end - AstStat* parseTypeFunction(const Location& start); + AstStat* parseTypeFunction(const Location& start, bool exported); AstDeclaredClassProp parseDeclaredClassMethod(); @@ -173,14 +182,18 @@ private: ); // explist ::= {exp `,'} exp - void parseExprList(TempVector& result); + void parseExprList(TempVector& result, TempVector* commaPositions = nullptr); // binding ::= Name [`:` Type] Binding parseBinding(); // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. - std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); + std::tuple parseBindingList( + TempVector& result, + bool allowDot3 = false, + TempVector* commaPositions = nullptr + ); AstType* parseOptionalType(); @@ -201,14 +214,24 @@ private: std::optional parseOptionalReturnType(); std::pair parseReturnType(); - AstTableIndexer* parseTableIndexer(AstTableAccess access, std::optional accessLocation); + struct TableIndexerResult + { + AstTableIndexer* node; + Position indexerOpenPosition; + Position indexerClosePosition; + Position colonPosition; + }; + + TableIndexerResult parseTableIndexer(AstTableAccess access, std::optional accessLocation); + // Remove with FFlagLuauStoreCSTData + AstTableIndexer* parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional accessLocation); AstTypeOrPack parseFunctionType(bool allowPack, const AstArray& attributes); AstType* parseFunctionTypeTail( const Lexeme& begin, const AstArray& attributes, - AstArray generics, - AstArray genericPacks, + AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation @@ -217,7 +240,7 @@ private: AstType* parseTableType(bool inDeclarationContext = false); AstTypeOrPack parseSimpleType(bool allowPack, bool inDeclarationContext = false); - AstTypeOrPack parseTypeOrPack(); + AstTypeOrPack parseSimpleTypeOrPack(); AstType* parseType(bool inDeclarationContext = false); AstTypePack* parseTypePack(); @@ -259,6 +282,8 @@ private: // args ::= `(' [explist] `)' | tableconstructor | String AstExpr* parseFunctionArgs(AstExpr* func, bool self); + std::optional tableSeparator(); + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -277,12 +302,21 @@ private: Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); + std::pair, AstArray> parseGenericTypeList( + bool withDefaultValues, + Position* openPosition = nullptr, + TempVector* commaPositions = nullptr, + Position* closePosition = nullptr + ); // `<' Type[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams( + Position* openingPosition = nullptr, + TempVector* commaPositions = nullptr, + Position* closingPosition = nullptr + ); - std::optional> parseCharArray(); + std::optional> parseCharArray(AstArray* originalString = nullptr); AstExpr* parseString(); AstExpr* parseNumber(); @@ -292,6 +326,9 @@ private: void restoreLocals(unsigned int offset); + /// Returns string quote style and block depth + std::pair extractStringDetails(); + // check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure bool expectAndConsume(char value, const char* context = nullptr); bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); @@ -423,6 +460,7 @@ private: MatchLexeme endMismatchSuspect; std::vector functionStack; + size_t typeFunctionDepth = 0; DenseHashMap localMap; std::vector localStack; @@ -434,6 +472,7 @@ private: std::vector scratchAttr; std::vector scratchStat; std::vector> scratchString; + std::vector> scratchString2; std::vector scratchExpr; std::vector scratchExprAux; std::vector scratchName; @@ -441,15 +480,20 @@ private: std::vector scratchBinding; std::vector scratchLocal; std::vector scratchTableTypeProps; + std::vector scratchCstTableTypeProps; std::vector scratchType; std::vector scratchTypeOrPack; std::vector scratchDeclaredClassProps; std::vector scratchItem; + std::vector scratchCstItem; std::vector scratchArgName; - std::vector scratchGenericTypes; - std::vector scratchGenericTypePacks; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; + std::vector scratchPosition; std::string scratchData; + + CstNodeMap cstNodeMap; }; -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index bd2ca86b..2259f21c 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -7,6 +7,7 @@ #include #include +#include LUAU_FASTFLAG(DebugLuauTimeTracing) diff --git a/Ast/src/Allocator.cpp b/Ast/src/Allocator.cpp new file mode 100644 index 00000000..c7614d8c --- /dev/null +++ b/Ast/src/Allocator.cpp @@ -0,0 +1,66 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Allocator.h" + +namespace Luau +{ + +Allocator::Allocator() + : root(static_cast(operator new(sizeof(Page)))) + , offset(0) +{ + root->next = nullptr; +} + +Allocator::Allocator(Allocator&& rhs) + : root(rhs.root) + , offset(rhs.offset) +{ + rhs.root = nullptr; + rhs.offset = 0; +} + +Allocator::~Allocator() +{ + Page* page = root; + + while (page) + { + Page* next = page->next; + + operator delete(page); + + page = next; + } +} + +void* Allocator::allocate(size_t size) +{ + constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); + + if (root) + { + uintptr_t data = reinterpret_cast(root->data); + uintptr_t result = (data + offset + align - 1) & ~(align - 1); + if (result + size <= data + sizeof(root->data)) + { + offset = result - data + size; + return reinterpret_cast(result); + } + } + + // allocate new page + size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); + void* pageData = operator new(offsetof(Page, data) + pageSize); + + Page* page = static_cast(pageData); + + page->next = root; + + root = page; + offset = size; + + return page->data; +} + +} // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index ff7c7cc6..ab42ec8c 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,8 +3,6 @@ #include "Luau/Common.h" -LUAU_FASTFLAG(LuauNativeAttribute); - namespace Luau { @@ -30,6 +28,38 @@ void AstAttr::visit(AstVisitor* visitor) int gAstRttiIndex = 0; +AstGenericType::AstGenericType(const Location& location, AstName name, AstType* defaultValue) + : AstNode(ClassIndex(), location) + , name(name) + , defaultValue(defaultValue) +{ +} + +void AstGenericType::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + if (defaultValue) + defaultValue->visit(visitor); + } +} + +AstGenericTypePack::AstGenericTypePack(const Location& location, AstName name, AstTypePack* defaultValue) + : AstNode(ClassIndex(), location) + , name(name) + , defaultValue(defaultValue) +{ +} + +void AstGenericTypePack::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + if (defaultValue) + defaultValue->visit(visitor); + } +} + AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) : AstExpr(ClassIndex(), location) , expr(expr) @@ -87,6 +117,11 @@ void AstExprConstantString::visit(AstVisitor* visitor) visitor->visit(this); } +bool AstExprConstantString::isQuoted() const +{ + return quoteStyle == QuoteStyle::QuotedSimple || quoteStyle == QuoteStyle::QuotedRaw; +} + AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue) : AstExpr(ClassIndex(), location) , local(local) @@ -182,8 +217,8 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) AstExprFunction::AstExprFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstLocal* self, const AstArray& args, bool vararg, @@ -234,8 +269,6 @@ void AstExprFunction::visit(AstVisitor* visitor) bool AstExprFunction::hasNativeAttribute() const { - LUAU_ASSERT(FFlag::LuauNativeAttribute); - for (const auto attribute : attributes) { if (attribute->type == AstAttr::Type::Native) @@ -720,8 +753,8 @@ AstStatTypeAlias::AstStatTypeAlias( const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported ) @@ -739,27 +772,32 @@ void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (const AstGenericType& el : generics) + for (AstGenericType* el : generics) { - if (el.defaultValue) - el.defaultValue->visit(visitor); + el->visit(visitor); } - for (const AstGenericTypePack& el : genericPacks) + for (AstGenericTypePack* el : genericPacks) { - if (el.defaultValue) - el.defaultValue->visit(visitor); + el->visit(visitor); } type->visit(visitor); } } -AstStatTypeFunction::AstStatTypeFunction(const Location& location, const AstName& name, const Location& nameLocation, AstExprFunction* body) +AstStatTypeFunction::AstStatTypeFunction( + const Location& location, + const AstName& name, + const Location& nameLocation, + AstExprFunction* body, + bool exported +) : AstStat(ClassIndex(), location) , name(name) , nameLocation(nameLocation) , body(body) + , exported(exported) { } @@ -787,8 +825,8 @@ AstStatDeclareFunction::AstStatDeclareFunction( const Location& location, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -814,8 +852,8 @@ AstStatDeclareFunction::AstStatDeclareFunction( const AstArray& attributes, const AstName& name, const Location& nameLocation, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, bool vararg, @@ -962,8 +1000,8 @@ void AstTypeTable::visit(AstVisitor* visitor) AstTypeFunction::AstTypeFunction( const Location& location, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -982,8 +1020,8 @@ AstTypeFunction::AstTypeFunction( AstTypeFunction::AstTypeFunction( const Location& location, const AstArray& attributes, - const AstArray& generics, - const AstArray& genericPacks, + const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes @@ -1083,6 +1121,18 @@ void AstTypeSingletonString::visit(AstVisitor* visitor) visitor->visit(this); } +AstTypeGroup::AstTypeGroup(const Location& location, AstType* type) + : AstType(ClassIndex(), location) + , type(type) +{ +} + +void AstTypeGroup::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + type->visit(visitor); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) @@ -1141,6 +1191,11 @@ void AstTypePackGeneric::visit(AstVisitor* visitor) visitor->visit(this); } +bool isLValue(const AstExpr* expr) +{ + return expr->is() || expr->is() || expr->is() || expr->is(); +} + AstName getIdentifier(AstExpr* node) { if (AstExprGlobal* expr = node->as()) @@ -1165,4 +1220,4 @@ Location getLocation(const AstTypeList& typeList) return result; } -} // namespace Luau \ No newline at end of file +} // namespace Luau diff --git a/Ast/src/Cst.cpp b/Ast/src/Cst.cpp new file mode 100644 index 00000000..0d1b8352 --- /dev/null +++ b/Ast/src/Cst.cpp @@ -0,0 +1,200 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Ast.h" +#include "Luau/Cst.h" +#include "Luau/Common.h" + +namespace Luau +{ + +int gCstRttiIndex = 0; + +CstExprConstantNumber::CstExprConstantNumber(const AstArray& value) + : CstNode(CstClassIndex()) + , value(value) +{ +} + +CstExprConstantString::CstExprConstantString(AstArray sourceString, QuoteStyle quoteStyle, unsigned int blockDepth) + : CstNode(CstClassIndex()) + , sourceString(sourceString) + , quoteStyle(quoteStyle) + , blockDepth(blockDepth) +{ + LUAU_ASSERT(blockDepth == 0 || quoteStyle == QuoteStyle::QuotedRaw); +} + +CstExprCall::CstExprCall(std::optional openParens, std::optional closeParens, AstArray commaPositions) + : CstNode(CstClassIndex()) + , openParens(openParens) + , closeParens(closeParens) + , commaPositions(commaPositions) +{ +} + +CstExprIndexExpr::CstExprIndexExpr(Position openBracketPosition, Position closeBracketPosition) + : CstNode(CstClassIndex()) + , openBracketPosition(openBracketPosition) + , closeBracketPosition(closeBracketPosition) +{ +} + +CstExprTable::CstExprTable(const AstArray& items) + : CstNode(CstClassIndex()) + , items(items) +{ +} + +CstExprOp::CstExprOp(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstExprTypeAssertion::CstExprTypeAssertion(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstExprIfElse::CstExprIfElse(Position thenPosition, Position elsePosition, bool isElseIf) + : CstNode(CstClassIndex()) + , thenPosition(thenPosition) + , elsePosition(elsePosition) + , isElseIf(isElseIf) +{ +} + +CstExprInterpString::CstExprInterpString(AstArray> sourceStrings, AstArray stringPositions) + : CstNode(CstClassIndex()) + , sourceStrings(sourceStrings) + , stringPositions(stringPositions) +{ +} + +CstStatDo::CstStatDo(Position endPosition) + : CstNode(CstClassIndex()) + , endPosition(endPosition) +{ +} + +CstStatRepeat::CstStatRepeat(Position untilPosition) + : CstNode(CstClassIndex()) + , untilPosition(untilPosition) +{ +} + +CstStatReturn::CstStatReturn(AstArray commaPositions) + : CstNode(CstClassIndex()) + , commaPositions(commaPositions) +{ +} + +CstStatLocal::CstStatLocal(AstArray varsCommaPositions, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatFor::CstStatFor(Position equalsPosition, Position endCommaPosition, std::optional stepCommaPosition) + : CstNode(CstClassIndex()) + , equalsPosition(equalsPosition) + , endCommaPosition(endCommaPosition) + , stepCommaPosition(stepCommaPosition) +{ +} + +CstStatForIn::CstStatForIn(AstArray varsCommaPositions, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatAssign::CstStatAssign(AstArray varsCommaPositions, Position equalsPosition, AstArray valuesCommaPositions) + : CstNode(CstClassIndex()) + , varsCommaPositions(varsCommaPositions) + , equalsPosition(equalsPosition) + , valuesCommaPositions(valuesCommaPositions) +{ +} + +CstStatCompoundAssign::CstStatCompoundAssign(Position opPosition) + : CstNode(CstClassIndex()) + , opPosition(opPosition) +{ +} + +CstStatLocalFunction::CstStatLocalFunction(Position functionKeywordPosition) + : CstNode(CstClassIndex()) + , functionKeywordPosition(functionKeywordPosition) +{ +} + +CstGenericType::CstGenericType(std::optional defaultEqualsPosition) + : CstNode(CstClassIndex()) + , defaultEqualsPosition(defaultEqualsPosition) +{ +} + +CstGenericTypePack::CstGenericTypePack(Position ellipsisPosition, std::optional defaultEqualsPosition) + : CstNode(CstClassIndex()) + , ellipsisPosition(ellipsisPosition) + , defaultEqualsPosition(defaultEqualsPosition) +{ +} + +CstStatTypeAlias::CstStatTypeAlias( + Position typeKeywordPosition, + Position genericsOpenPosition, + AstArray genericsCommaPositions, + Position genericsClosePosition, + Position equalsPosition +) + : CstNode(CstClassIndex()) + , typeKeywordPosition(typeKeywordPosition) + , genericsOpenPosition(genericsOpenPosition) + , genericsCommaPositions(genericsCommaPositions) + , genericsClosePosition(genericsClosePosition) + , equalsPosition(equalsPosition) +{ +} + +CstTypeReference::CstTypeReference( + std::optional prefixPointPosition, + Position openParametersPosition, + AstArray parametersCommaPositions, + Position closeParametersPosition +) + : CstNode(CstClassIndex()) + , prefixPointPosition(prefixPointPosition) + , openParametersPosition(openParametersPosition) + , parametersCommaPositions(parametersCommaPositions) + , closeParametersPosition(closeParametersPosition) +{ +} + +CstTypeTable::CstTypeTable(AstArray items, bool isArray) + : CstNode(CstClassIndex()) + , items(items) + , isArray(isArray) +{ +} + +CstTypeTypeof::CstTypeTypeof(Position openPosition, Position closePosition) + : CstNode(CstClassIndex()) + , openPosition(openPosition) + , closePosition(closePosition) +{ +} + +CstTypeSingletonString::CstTypeSingletonString(AstArray sourceString, CstExprConstantString::QuoteStyle quoteStyle, unsigned int blockDepth) + : CstNode(CstClassIndex()) + , sourceString(sourceString) + , quoteStyle(quoteStyle) + , blockDepth(blockDepth) +{ + LUAU_ASSERT(quoteStyle != CstExprConstantString::QuotedInterp); +} + +} // namespace Luau diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a5e1d40e..557295e0 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -1,75 +1,19 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Lexer.h" +#include "Luau/Allocator.h" #include "Luau/Common.h" #include "Luau/Confusables.h" #include "Luau/StringUtils.h" #include -LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) +LUAU_FASTFLAGVARIABLE(LexerResumesFromPosition2) +LUAU_FASTFLAGVARIABLE(LexerFixInterpStringStart) namespace Luau { -Allocator::Allocator() - : root(static_cast(operator new(sizeof(Page)))) - , offset(0) -{ - root->next = nullptr; -} - -Allocator::Allocator(Allocator&& rhs) - : root(rhs.root) - , offset(rhs.offset) -{ - rhs.root = nullptr; - rhs.offset = 0; -} - -Allocator::~Allocator() -{ - Page* page = root; - - while (page) - { - Page* next = page->next; - - operator delete(page); - - page = next; - } -} - -void* Allocator::allocate(size_t size) -{ - constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); - - if (root) - { - uintptr_t data = reinterpret_cast(root->data); - uintptr_t result = (data + offset + align - 1) & ~(align - 1); - if (result + size <= data + sizeof(root->data)) - { - offset = result - data + size; - return reinterpret_cast(result); - } - } - - // allocate new page - size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); - void* pageData = operator new(offsetof(Page, data) + pageSize); - - Page* page = static_cast(pageData); - - page->next = root; - - root = page; - offset = size; - - return page->data; -} - Lexeme::Lexeme(const Location& location, Type type) : type(type) , location(location) @@ -362,13 +306,48 @@ static char unescape(char ch) } } -Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names) +unsigned int Lexeme::getBlockDepth() const +{ + LUAU_ASSERT(type == Lexeme::RawString || type == Lexeme::BlockComment); + + // If we have a well-formed string, we are guaranteed to see 2 `]` characters after the end of the string contents + LUAU_ASSERT(*(data + length) == ']'); + unsigned int depth = 0; + do + { + depth++; + } while (*(data + length + depth) != ']'); + + return depth - 1; +} + +Lexeme::QuoteStyle Lexeme::getQuoteStyle() const +{ + LUAU_ASSERT(type == Lexeme::QuotedString); + + // If we have a well-formed string, we are guaranteed to see a closing delimiter after the string + LUAU_ASSERT(data); + + char quote = *(data + length); + if (quote == '\'') + return Lexeme::QuoteStyle::Single; + else if (quote == '"') + return Lexeme::QuoteStyle::Double; + + LUAU_ASSERT(!"Unknown quote style"); + return Lexeme::QuoteStyle::Double; // unreachable, but required due to compiler warning +} + +Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names, Position startPosition) : buffer(buffer) , bufferSize(bufferSize) , offset(0) - , line(0) - , lineOffset(0) - , lexeme(Location(Position(0, 0), 0), Lexeme::Eof) + , line(FFlag::LexerResumesFromPosition2 ? startPosition.line : 0) + , lineOffset(FFlag::LexerResumesFromPosition2 ? 0u - startPosition.column : 0) + , lexeme( + (FFlag::LexerResumesFromPosition2 ? Location(Position(startPosition.line, startPosition.column), 0) : Location(Position(0, 0), 0)), + Lexeme::Eof + ) , names(names) , skipComments(false) , readNames(true) @@ -434,13 +413,11 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; - if (FFlag::LuauLexerLookaheadRemembersBraceType) - { - if (braceStack.size() < currentBraceStackSize) - braceStack.push_back(currentBraceType); - else if (braceStack.size() > currentBraceStackSize) - braceStack.pop_back(); - } + + if (braceStack.size() < currentBraceStackSize) + braceStack.push_back(currentBraceType); + else if (braceStack.size() > currentBraceStackSize) + braceStack.pop_back(); return result; } @@ -466,6 +443,7 @@ char Lexer::peekch(unsigned int lookahead) const return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0; } +LUAU_FORCEINLINE Position Lexer::position() const { return Position(line, offset - lineOffset); @@ -815,7 +793,7 @@ Lexeme Lexer::readNext() return Lexeme(Location(start, 1), '}'); } - return readInterpolatedStringSection(position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); + return readInterpolatedStringSection(FFlag::LexerFixInterpStringStart ? start : position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); } case '=': diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp index c2c66d9f..e96fafb7 100644 --- a/Ast/src/Location.cpp +++ b/Ast/src/Location.cpp @@ -4,42 +4,6 @@ namespace Luau { -bool Position::operator==(const Position& rhs) const -{ - return this->column == rhs.column && this->line == rhs.line; -} - -bool Position::operator!=(const Position& rhs) const -{ - return !(*this == rhs); -} - -bool Position::operator<(const Position& rhs) const -{ - if (line == rhs.line) - return column < rhs.column; - else - return line < rhs.line; -} - -bool Position::operator>(const Position& rhs) const -{ - if (line == rhs.line) - return column > rhs.column; - else - return line > rhs.line; -} - -bool Position::operator<=(const Position& rhs) const -{ - return *this == rhs || *this < rhs; -} - -bool Position::operator>=(const Position& rhs) const -{ - return *this == rhs || *this > rhs; -} - void Position::shift(const Position& start, const Position& oldEnd, const Position& newEnd) { if (*this >= start) @@ -54,16 +18,6 @@ void Position::shift(const Position& start, const Position& oldEnd, const Positi } } -bool Location::operator==(const Location& rhs) const -{ - return this->begin == rhs.begin && this->end == rhs.end; -} - -bool Location::operator!=(const Location& rhs) const -{ - return !(*this == rhs); -} - bool Location::encloses(const Location& l) const { return begin <= l.begin && end >= l.end; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 4b9eddda..a7c81dd9 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -8,6 +8,7 @@ #include #include +#include LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauTypeLengthLimit, 1000) @@ -16,10 +17,17 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) // Warning: If you are introducing new syntax, ensure that it is behind a separate // flag so that we don't break production games by reverting syntax changes. // See docs/SyntaxChanges.md for an explanation. -LUAU_FASTFLAGVARIABLE(LuauSolverV2, false) -LUAU_FASTFLAGVARIABLE(LuauNativeAttribute, false) -LUAU_FASTFLAGVARIABLE(LuauAttributeSyntaxFunExpr, false) -LUAU_FASTFLAGVARIABLE(LuauUserDefinedTypeFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauSolverV2) +LUAU_FASTFLAGVARIABLE(LuauAllowComplexTypesInGenericParams) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForTableTypes) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryForClassNames) +LUAU_FASTFLAGVARIABLE(LuauFixFunctionNameStartPosition) +LUAU_FASTFLAGVARIABLE(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAGVARIABLE(LuauStoreCSTData) +LUAU_FASTFLAGVARIABLE(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAGVARIABLE(LuauAstTypeGroup2) +LUAU_FASTFLAGVARIABLE(ParserNoErrorLimit) +LUAU_FASTFLAGVARIABLE(LuauFixDoBlockEndLocation) namespace Luau { @@ -162,24 +170,47 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n AstStatBlock* root = p.parseChunk(); size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); - return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + return ParseResult{root, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations), std::move(p.cstNodeMap)}; } catch (ParseError& err) { // when catching a fatal error, append it to the list of non-fatal errors and return p.parseErrors.push_back(err); - return ParseResult{nullptr, 0, {}, p.parseErrors}; + return ParseResult{nullptr, 0, {}, p.parseErrors, {}, std::move(p.cstNodeMap)}; + } +} + +ParseExprResult Parser::parseExpr(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) +{ + LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); + + Parser p(buffer, bufferSize, names, allocator, options); + + try + { + AstExpr* expr = p.parseExpr(); + size_t lines = p.lexer.current().location.end.line + (bufferSize > 0 && buffer[bufferSize - 1] != '\n'); + + return ParseExprResult{expr, lines, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations), std::move(p.cstNodeMap)}; + } + catch (ParseError& err) + { + // when catching a fatal error, append it to the list of non-fatal errors and return + p.parseErrors.push_back(err); + + return ParseExprResult{nullptr, 0, {}, p.parseErrors, {}, std::move(p.cstNodeMap)}; } } Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options) : options(options) - , lexer(buffer, bufferSize, names) + , lexer(buffer, bufferSize, names, options.parseFragment ? options.parseFragment->resumePosition : Position(0, 0)) , allocator(allocator) , recursionCounter(0) , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof)) , localMap(AstName()) + , cstNodeMap(nullptr) { Function top; top.vararg = true; @@ -187,9 +218,9 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc functionStack.reserve(8); functionStack.push_back(top); - nameSelf = names.addStatic("self"); - nameNumber = names.addStatic("number"); - nameError = names.addStatic(kParseNameError); + nameSelf = names.getOrAdd("self"); + nameNumber = names.getOrAdd("number"); + nameError = names.getOrAdd(kParseNameError); nameNil = names.getOrAdd("nil"); // nil is a reserved keyword matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); @@ -211,6 +242,12 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc scratchExpr.reserve(16); scratchLocal.reserve(16); scratchBinding.reserve(16); + + if (options.parseFragment) + { + localMap = options.parseFragment->localMap; + localStack = options.parseFragment->localStack; + } } bool Parser::blockFollow(const Lexeme& l) @@ -267,6 +304,10 @@ AstStatBlock* Parser::parseBlockNoScope() { nextLexeme(); stat->hasSemicolon = true; + if (FFlag::LuauExtendStatEndPosWithSemicolon) + { + stat->location.end = lexer.previousLocation().end; + } } body.push_back(stat); @@ -343,12 +384,13 @@ AstStat* Parser::parseStat() AstName ident = getIdentifier(expr); if (ident == "type") - return parseTypeAlias(expr->location, /* exported= */ false); + return parseTypeAlias(expr->location, /* exported= */ false, expr->location.begin); if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") { + Position typeKeywordPosition = lexer.current().location.begin; nextLexeme(); - return parseTypeAlias(expr->location, /* exported= */ true); + return parseTypeAlias(expr->location, /* exported= */ true, typeKeywordPosition); } if (ident == "continue") @@ -470,6 +512,7 @@ AstStat* Parser::parseRepeat() functionStack.back().loopDepth--; + Position untilPosition = lexer.current().location.begin; bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); body->hasEnd = hasUntil; @@ -477,7 +520,17 @@ AstStat* Parser::parseRepeat() restoreLocals(localsBegin); - return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + if (FFlag::LuauStoreCSTData) + { + AstStatRepeat* node = allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(untilPosition); + return node; + } + else + { + return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); + } } // do block end @@ -492,7 +545,13 @@ AstStat* Parser::parseDo() body->location.begin = start.begin; + Location endLocation = lexer.current().location; body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + if (FFlag::LuauFixDoBlockEndLocation && body->hasEnd) + body->location.end = endLocation.end; + + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[body] = allocator.alloc(endLocation.begin); return body; } @@ -533,18 +592,22 @@ AstStat* Parser::parseFor() if (lexer.current().type == '=') { + Position equalsPosition = lexer.current().location.begin; nextLexeme(); AstExpr* from = parseExpr(); + Position endCommaPosition = lexer.current().location.begin; expectAndConsume(',', "index range"); AstExpr* to = parseExpr(); + std::optional stepCommaPosition = std::nullopt; AstExpr* step = nullptr; if (lexer.current().type == ',') { + stepCommaPosition = lexer.current().location.begin; nextLexeme(); step = parseExpr(); @@ -570,25 +633,46 @@ AstStat* Parser::parseFor() bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + if (FFlag::LuauStoreCSTData) + { + AstStatFor* node = allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(equalsPosition, endCommaPosition, stepCommaPosition); + return node; + } + else + { + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location); + } } else { TempVector names(scratchBinding); + TempVector varsCommaPosition(scratchPosition); names.push_back(varname); if (lexer.current().type == ',') { - nextLexeme(); + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + varsCommaPosition.push_back(lexer.current().location.begin); + nextLexeme(); + parseBindingList(names, false, &varsCommaPosition); + } + else + { + nextLexeme(); - parseBindingList(names); + parseBindingList(names); + } } Location inLocation = lexer.current().location; bool hasIn = expectAndConsume(Lexeme::ReservedIn, "for loop"); TempVector values(scratchExpr); - parseExprList(values); + TempVector valuesCommaPositions(scratchPosition); + parseExprList(values, (FFlag::LuauStoreCSTData && options.storeCstData) ? &valuesCommaPositions : nullptr); Lexeme matchDo = lexer.current(); bool hasDo = expectAndConsume(Lexeme::ReservedDo, "for loop"); @@ -613,12 +697,23 @@ AstStat* Parser::parseFor() bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); body->hasEnd = hasEnd; - return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + if (FFlag::LuauStoreCSTData) + { + AstStatForIn* node = + allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(varsCommaPosition), copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location); + } } } // funcname ::= Name {`.' Name} [`:' Name] -AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debugname) +AstExpr* Parser::parseFunctionName(Location start_DEPRECATED, bool& hasself, AstName& debugname) { if (lexer.current().type == Lexeme::Name) debugname = AstName(lexer.current().name); @@ -638,7 +733,14 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // while we could concatenate the name chain, for now let's just write the short name debugname = name.name; - expr = allocator.alloc(Location(start, name.location), expr, name.name, name.location, opPosition, '.'); + expr = allocator.alloc( + Location(FFlag::LuauFixFunctionNameStartPosition ? expr->location : start_DEPRECATED, name.location), + expr, + name.name, + name.location, + opPosition, + '.' + ); // note: while the parser isn't recursive here, we're generating recursive structures of unbounded depth incrementRecursionCounter("function name"); @@ -657,7 +759,14 @@ AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debug // while we could concatenate the name chain, for now let's just write the short name debugname = name.name; - expr = allocator.alloc(Location(start, name.location), expr, name.name, name.location, opPosition, ':'); + expr = allocator.alloc( + Location(FFlag::LuauFixFunctionNameStartPosition ? expr->location : start_DEPRECATED, name.location), + expr, + name.name, + name.location, + opPosition, + ':' + ); hasself = true; } @@ -701,10 +810,6 @@ std::pair Parser::validateAttribute(const char* attributeNa if (found) { type = kAttributeEntries[i].type; - - if (!FFlag::LuauNativeAttribute && type == AstAttr::Type::Native) - found = false; - break; } } @@ -809,6 +914,7 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Lexeme matchFunction = lexer.current(); nextLexeme(); + Position functionKeywordPosition = matchFunction.location.begin; // matchFunction is only used for diagnostics; to make it suitable for detecting missed indentation between // `local function` and `end`, we patch the token to begin at the column where `local` starts if (matchFunction.location.begin.line == start.begin.line) @@ -824,7 +930,17 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Location location{start.begin, body->location.end}; - return allocator.alloc(location, var, body); + if (FFlag::LuauStoreCSTData) + { + AstStatLocalFunction* node = allocator.alloc(location, var, body); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(functionKeywordPosition); + return node; + } + else + { + return allocator.alloc(location, var, body); + } } else { @@ -842,13 +958,18 @@ AstStat* Parser::parseLocal(const AstArray& attributes) matchRecoveryStopOnToken['=']++; TempVector names(scratchBinding); - parseBindingList(names); + TempVector varsCommaPositions(scratchPosition); + if (FFlag::LuauStoreCSTData && options.storeCstData) + parseBindingList(names, false, &varsCommaPositions); + else + parseBindingList(names); matchRecoveryStopOnToken['=']--; TempVector vars(scratchLocal); TempVector values(scratchExpr); + TempVector valuesCommaPositions(scratchPosition); std::optional equalsSignLocation; @@ -858,7 +979,7 @@ AstStat* Parser::parseLocal(const AstArray& attributes) nextLexeme(); - parseExprList(values); + parseExprList(values, (FFlag::LuauStoreCSTData && options.storeCstData) ? &valuesCommaPositions : nullptr); } for (size_t i = 0; i < names.size(); ++i) @@ -866,7 +987,17 @@ AstStat* Parser::parseLocal(const AstArray& attributes) Location end = values.empty() ? lexer.previousLocation() : values.back()->location; - return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + if (FFlag::LuauStoreCSTData) + { + AstStatLocal* node = allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(varsCommaPositions), copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + } } } @@ -878,24 +1009,32 @@ AstStat* Parser::parseReturn() nextLexeme(); TempVector list(scratchExpr); + TempVector commaPositions(scratchPosition); if (!blockFollow(lexer.current()) && lexer.current().type != ';') - parseExprList(list); + parseExprList(list, (FFlag::LuauStoreCSTData && options.storeCstData) ? &commaPositions : nullptr); Location end = list.empty() ? start : list.back()->location; - return allocator.alloc(Location(start, end), copy(list)); + if (FFlag::LuauStoreCSTData) + { + AstStatReturn* node = allocator.alloc(Location(start, end), copy(list)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(commaPositions)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(list)); + } } // type Name [`<' varlist `>'] `=' Type -AstStat* Parser::parseTypeAlias(const Location& start, bool exported) +AstStat* Parser::parseTypeAlias(const Location& start, bool exported, Position typeKeywordPosition) { // parsing a type function - if (FFlag::LuauUserDefinedTypeFunctions) - { - if (lexer.current().type == Lexeme::ReservedFunction) - return parseTypeFunction(start); - } + if (lexer.current().type == Lexeme::ReservedFunction) + return parseTypeFunction(start, exported); // parsing a type alias @@ -907,17 +1046,38 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ true); + Position genericsOpenPosition{0, 0}; + TempVector genericsCommaPositions(scratchPosition); + Position genericsClosePosition{0, 0}; + auto [generics, genericPacks] = FFlag::LuauStoreCSTData && options.storeCstData + ? parseGenericTypeList( + /* withDefaultValues= */ true, &genericsOpenPosition, &genericsCommaPositions, &genericsClosePosition + ) + : parseGenericTypeList(/* withDefaultValues= */ true); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "type alias"); AstType* type = parseType(); - return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); + if (FFlag::LuauStoreCSTData) + { + AstStatTypeAlias* node = + allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc( + typeKeywordPosition, genericsOpenPosition, copy(genericsCommaPositions), genericsClosePosition, equalsPosition + ); + return node; + } + else + { + return allocator.alloc(Location(start, type->location), name->name, name->location, generics, genericPacks, type, exported); + } } // type function Name `(' arglist `)' `=' funcbody `end' -AstStat* Parser::parseTypeFunction(const Location& start) +AstStat* Parser::parseTypeFunction(const Location& start, bool exported) { Lexeme matchFn = lexer.current(); nextLexeme(); @@ -929,11 +1089,16 @@ AstStat* Parser::parseTypeFunction(const Location& start) matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; + size_t oldTypeFunctionDepth = typeFunctionDepth; + typeFunctionDepth = functionStack.size(); + AstExprFunction* body = parseFunctionBody(/* hasself */ false, matchFn, fnName->name, nullptr, AstArray({nullptr, 0})).first; + typeFunctionDepth = oldTypeFunctionDepth; + matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; - return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body); + return allocator.alloc(Location(start, body->location), fnName->name, fnName->location, body, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -945,8 +1110,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; @@ -1104,7 +1269,7 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArraydata, chars->size) < chars->size); + bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) { @@ -1123,24 +1288,49 @@ AstStat* Parser::parseDeclaration(const Location& start, const AstArraylocation, "Cannot have more than one class indexer"); } else { - indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt); + if (FFlag::LuauStoreCSTData) + indexer = parseTableIndexer(AstTableAccess::ReadWrite, std::nullopt).node; + else + indexer = parseTableIndexer_DEPRECATED(AstTableAccess::ReadWrite, std::nullopt); } } else { - Location propStart = lexer.current().location; - Name propName = parseName("property name"); - expectAndConsume(':', "property type annotation"); - AstType* propType = parseType(); - props.push_back(AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())} - ); + if (FFlag::LuauErrorRecoveryForClassNames) + { + Location propStart = lexer.current().location; + std::optional propName = parseNameOpt("property name"); + + if (!propName) + break; + + expectAndConsume(':', "property type annotation"); + AstType* propType = parseType(); + props.push_back( + AstDeclaredClassProp{propName->name, propName->location, propType, false, Location(propStart, lexer.previousLocation())} + ); + } + else + { + Location propStart = lexer.current().location; + Name propName = parseName("property name"); + expectAndConsume(':', "property type annotation"); + AstType* propType = parseType(); + props.push_back( + AstDeclaredClassProp{propName.name, propName.location, propType, false, Location(propStart, lexer.previousLocation())} + ); + } } } @@ -1174,10 +1364,13 @@ AstStat* Parser::parseAssignment(AstExpr* initial) initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); TempVector vars(scratchExpr); + TempVector varsCommaPositions(scratchPosition); vars.push_back(initial); while (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && options.storeCstData) + varsCommaPositions.push_back(lexer.current().location.begin); nextLexeme(); AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); @@ -1188,12 +1381,23 @@ AstStat* Parser::parseAssignment(AstExpr* initial) vars.push_back(expr); } + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "assignment"); TempVector values(scratchExprAux); - parseExprList(values); + TempVector valuesCommaPositions(scratchPosition); + parseExprList(values, FFlag::LuauStoreCSTData && options.storeCstData ? &valuesCommaPositions : nullptr); - return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + if (FFlag::LuauStoreCSTData) + { + AstStatAssign* node = allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + cstNodeMap[node] = allocator.alloc(copy(varsCommaPositions), equalsPosition, copy(valuesCommaPositions)); + return node; + } + else + { + return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); + } } // var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp @@ -1204,11 +1408,22 @@ AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); } + Position opPosition = lexer.current().location.begin; nextLexeme(); AstExpr* value = parseExpr(); - return allocator.alloc(Location(initial->location, value->location), op, initial, value); + if (FFlag::LuauStoreCSTData) + { + AstStatCompoundAssign* node = allocator.alloc(Location(initial->location, value->location), op, initial, value); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(opPosition); + return node; + } + else + { + return allocator.alloc(Location(initial->location, value->location), op, initial, value); + } } std::pair> Parser::prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args) @@ -1243,6 +1458,19 @@ std::pair Parser::parseFunctionBody( MatchLexeme matchParen = lexer.current(); expectAndConsume('(', "function"); + // NOTE: This was added in conjunction with passing `searchForMissing` to + // `expectMatchAndConsume` inside `parseTableType` so that the behavior of + // parsing code like below (note the missing `}`): + // + // function (t: { a: number ) end + // + // ... will still parse as (roughly): + // + // function (t: { a: number }) end + // + if (FFlag::LuauErrorRecoveryForTableTypes) + matchRecoveryStopOnToken[')']++; + TempVector args(scratchBinding); bool vararg = false; @@ -1259,6 +1487,9 @@ std::pair Parser::parseFunctionBody( expectMatchAndConsume(')', matchParen, true); + if (FFlag::LuauErrorRecoveryForTableTypes) + matchRecoveryStopOnToken[')']--; + std::optional typelist = parseOptionalReturnType(); AstLocal* funLocal = nullptr; @@ -1308,12 +1539,14 @@ std::pair Parser::parseFunctionBody( } // explist ::= {exp `,'} exp -void Parser::parseExprList(TempVector& result) +void Parser::parseExprList(TempVector& result, TempVector* commaPositions) { result.push_back(parseExpr()); while (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); if (lexer.current().type == ')') @@ -1340,7 +1573,7 @@ Parser::Binding Parser::parseBinding() } // bindinglist ::= (binding | `...') [`,' bindinglist] -std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3) +std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3, TempVector* commaPositions) { while (true) { @@ -1363,6 +1596,8 @@ std::tuple Parser::parseBindingList(TempVectorpush_back(lexer.current().location.begin); nextLexeme(); } @@ -1497,15 +1732,31 @@ std::pair Parser::parseReturnType() if (lexer.current().type != Lexeme::SkinnyArrow && resultNames.empty()) { // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. - if (result.size() == 1) + if (FFlag::LuauAstTypeGroup2) { - AstType* returnType = parseTypeSuffix(result[0], innerBegin); + if (result.size() == 1 && varargAnnotation == nullptr) + { + AstType* returnType = parseTypeSuffix(allocator.alloc(location, result[0]), begin.location); - // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack - // type to successfully parse. We need the span of the whole annotation. - Position endPos = result.size() == 1 ? location.end : returnType->location.end; + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; - return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } + } + else + { + if (result.size() == 1) + { + AstType* returnType = parseTypeSuffix(result[0], innerBegin); + + // If parseType parses nothing, then returnType->location.end only points at the last non-type-pack + // type to successfully parse. We need the span of the whole annotation. + Position endPos = result.size() == 1 ? location.end : returnType->location.end; + + return {Location{location.begin, endPos}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } } return {location, AstTypeList{copy(result), varargAnnotation}}; @@ -1516,8 +1767,61 @@ std::pair Parser::parseReturnType() return {Location{location, tail->location}, AstTypeList{copy(&tail, 1), varargAnnotation}}; } +std::pair Parser::extractStringDetails() +{ + LUAU_ASSERT(FFlag::LuauStoreCSTData); + + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + + switch (lexer.current().type) + { + case Lexeme::QuotedString: + style = + lexer.current().getQuoteStyle() == Lexeme::QuoteStyle::Double ? CstExprConstantString::QuotedDouble : CstExprConstantString::QuotedSingle; + break; + case Lexeme::InterpStringSimple: + style = CstExprConstantString::QuotedInterp; + break; + case Lexeme::RawString: + { + style = CstExprConstantString::QuotedRaw; + blockDepth = lexer.current().getBlockDepth(); + break; + } + default: + LUAU_ASSERT(false && "Invalid string type"); + } + + return {style, blockDepth}; +} + // TableIndexer ::= `[' Type `]' `:' Type -AstTableIndexer* Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) +Parser::TableIndexerResult Parser::parseTableIndexer(AstTableAccess access, std::optional accessLocation) +{ + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + + AstType* index = parseType(); + + Position indexerClosePosition = lexer.current().location.begin; + expectMatchAndConsume(']', begin); + + Position colonPosition = lexer.current().location.begin; + expectAndConsume(':', "table field"); + + AstType* result = parseType(); + + return { + allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location), access, accessLocation}), + begin.location.begin, + indexerClosePosition, + colonPosition, + }; +} + +// Remove with FFlagLuauStoreCSTData +AstTableIndexer* Parser::parseTableIndexer_DEPRECATED(AstTableAccess access, std::optional accessLocation) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -1542,6 +1846,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) incrementRecursionCounter("type annotation"); TempVector props(scratchTableTypeProps); + TempVector cstItems(scratchCstTableTypeProps); AstTableIndexer* indexer = nullptr; Location start = lexer.current().location; @@ -1549,6 +1854,8 @@ AstType* Parser::parseTableType(bool inDeclarationContext) MatchLexeme matchBrace = lexer.current(); expectAndConsume('{', "table type"); + bool isArray = false; + while (lexer.current().type != '}') { AstTableAccess access = AstTableAccess::ReadWrite; @@ -1574,18 +1881,39 @@ AstType* Parser::parseTableType(bool inDeclarationContext) { const Lexeme begin = lexer.current(); nextLexeme(); // [ - std::optional> chars = parseCharArray(); + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + if (FFlag::LuauStoreCSTData && options.storeCstData) + std::tie(style, blockDepth) = extractStringDetails(); + + AstArray sourceString; + std::optional> chars = parseCharArray(options.storeCstData ? &sourceString : nullptr); + + Position indexerClosePosition = lexer.current().location.begin; expectMatchAndConsume(']', begin); + Position colonPosition = lexer.current().location.begin; expectAndConsume(':', "table field"); AstType* type = parseType(); // since AstName contains a char*, it can't contain null - bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + bool containsNull = chars && (memchr(chars->data, 0, chars->size) != nullptr); if (chars && !containsNull) + { props.push_back(AstTableProp{AstName(chars->data), begin.location, type, access, accessLocation}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::StringProperty, + begin.location.begin, + indexerClosePosition, + colonPosition, + tableSeparator(), + lexer.current().location.begin, + allocator.alloc(sourceString, style, blockDepth) + }); + } else report(begin.location, "String literal contains malformed escape sequence or \\0"); } @@ -1595,14 +1923,35 @@ AstType* Parser::parseTableType(bool inDeclarationContext) { // maybe we don't need to parse the entire badIndexer... // however, we either have { or [ to lint, not the entire table type or the bad indexer. - AstTableIndexer* badIndexer = parseTableIndexer(access, accessLocation); + AstTableIndexer* badIndexer; + if (FFlag::LuauStoreCSTData) + badIndexer = parseTableIndexer(access, accessLocation).node; + else + badIndexer = parseTableIndexer_DEPRECATED(access, accessLocation); // we lose all additional indexer expressions from the AST after error recovery here report(badIndexer->location, "Cannot have more than one table indexer"); } else { - indexer = parseTableIndexer(access, accessLocation); + if (FFlag::LuauStoreCSTData) + { + auto tableIndexerResult = parseTableIndexer(access, accessLocation); + indexer = tableIndexerResult.node; + if (options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::Indexer, + tableIndexerResult.indexerOpenPosition, + tableIndexerResult.indexerClosePosition, + tableIndexerResult.colonPosition, + tableSeparator(), + lexer.current().location.begin, + }); + } + else + { + indexer = parseTableIndexer_DEPRECATED(access, accessLocation); + } } } else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) @@ -1610,6 +1959,7 @@ AstType* Parser::parseTableType(bool inDeclarationContext) AstType* type = parseType(); // array-like table type: {T} desugars into {[number]: T} + isArray = true; AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber, std::nullopt, type->location); indexer = allocator.alloc(AstTableIndexer{index, type, type->location, access, accessLocation}); @@ -1622,11 +1972,21 @@ AstType* Parser::parseTableType(bool inDeclarationContext) if (!name) break; + Position colonPosition = lexer.current().location.begin; expectAndConsume(':', "table field"); AstType* type = parseType(inDeclarationContext); props.push_back(AstTableProp{name->name, name->location, type, access, accessLocation}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back(CstTypeTable::Item{ + CstTypeTable::Item::Kind::Property, + Position{0, 0}, + Position{0, 0}, + colonPosition, + tableSeparator(), + lexer.current().location.begin + }); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -1642,10 +2002,20 @@ AstType* Parser::parseTableType(bool inDeclarationContext) Location end = lexer.current().location; - if (!expectMatchAndConsume('}', matchBrace)) + if (!expectMatchAndConsume('}', matchBrace, /* searchForMissing = */ FFlag::LuauErrorRecoveryForTableTypes)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(props), indexer); + if (FFlag::LuauStoreCSTData) + { + AstTypeTable* node = allocator.alloc(Location(start, end), copy(props), indexer); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(cstItems), isArray); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(props), indexer); + } } // ReturnType ::= Type | `(' TypeList `)' @@ -1673,6 +2043,7 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray if (lexer.current().type != ')') varargAnnotation = parseTypeList(params, names); + Location closeArgsLocation = lexer.current().location; expectMatchAndConsume(')', parameterStart, true); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; @@ -1690,7 +2061,12 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray if (allowPack) return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; else - return {params[0], {}}; + { + if (FFlag::LuauAstTypeGroup2) + return {allocator.alloc(Location(parameterStart.location, closeArgsLocation), params[0]), {}}; + else + return {params[0], {}}; + } } if (!forceFunctionType && !returnTypeIntroducer && allowPack) @@ -1704,8 +2080,8 @@ AstTypeOrPack Parser::parseFunctionType(bool allowPack, const AstArray AstType* Parser::parseFunctionTypeTail( const Lexeme& begin, const AstArray& attributes, - AstArray generics, - AstArray genericPacks, + AstArray generics, + AstArray genericPacks, AstArray params, AstArray> paramNames, AstTypePack* varargAnnotation @@ -1738,6 +2114,11 @@ AstType* Parser::parseFunctionTypeTail( ); } +static bool isTypeFollow(Lexeme::Type c) +{ + return c == '|' || c == '?' || c == '&'; +} + // Type ::= // nil | // Name[`.' Name] [`<' namelist `>'] | @@ -1807,8 +2188,16 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) ParseError::raise(parts.back()->location, "Exceeded allowed type length; simplify your type annotation to make the code compile"); } - if (parts.size() == 1) - return parts[0]; + if (FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) + { + if (parts.size() == 1 && !isUnion && !isIntersection) + return parts[0]; + } + else + { + if (parts.size() == 1) + return parts[0]; + } if (isUnion && isIntersection) { @@ -1831,7 +2220,7 @@ AstType* Parser::parseTypeSuffix(AstType* type, const Location& begin) ParseError::raise(begin, "Composite type was not an intersection or union."); } -AstTypeOrPack Parser::parseTypeOrPack() +AstTypeOrPack Parser::parseSimpleTypeOrPack() { unsigned int oldRecursionCount = recursionCounter; // recursion counter is incremented in parseSimpleType @@ -1912,13 +2301,35 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { - if (std::optional> value = parseCharArray()) + if (FFlag::LuauStoreCSTData) { - AstArray svalue = *value; - return {allocator.alloc(start, svalue)}; + CstExprConstantString::QuoteStyle style; + unsigned int blockDepth = 0; + if (options.storeCstData) + std::tie(style, blockDepth) = extractStringDetails(); + + AstArray originalString; + if (std::optional> value = parseCharArray(options.storeCstData ? &originalString : nullptr)) + { + AstArray svalue = *value; + auto node = allocator.alloc(start, svalue); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(originalString, style, blockDepth); + return {node}; + } + else + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; } else - return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(start, svalue)}; + } + else + return {reportTypeError(start, {}, "String literal contains malformed escape sequence")}; + } } else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) { @@ -1934,17 +2345,30 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) else if (lexer.current().type == Lexeme::Name) { std::optional prefix; + std::optional prefixPointPosition; std::optional prefixLocation; Name name = parseName("type name"); if (lexer.current().type == '.') { - Position pointPosition = lexer.current().location.begin; - nextLexeme(); + if (FFlag::LuauStoreCSTData) + { + prefixPointPosition = lexer.current().location.begin; + nextLexeme(); - prefix = name.name; - prefixLocation = name.location; - name = parseIndexName("field name", pointPosition); + prefix = name.name; + prefixLocation = name.location; + name = parseIndexName("field name", *prefixPointPosition); + } + else + { + Position pointPosition = lexer.current().location.begin; + nextLexeme(); + + prefix = name.name; + prefixLocation = name.location; + name = parseIndexName("field name", pointPosition); + } } else if (lexer.current().type == Lexeme::Dot3) { @@ -1962,23 +2386,53 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack, bool inDeclarationContext) expectMatchAndConsume(')', typeofBegin); - return {allocator.alloc(Location(start, end), expr), {}}; + if (FFlag::LuauStoreCSTData) + { + AstTypeTypeof* node = allocator.alloc(Location(start, end), expr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(typeofBegin.location.begin, end.begin); + return {node, {}}; + } + else + { + return {allocator.alloc(Location(start, end), expr), {}}; + } } bool hasParameters = false; AstArray parameters{}; + Position parametersOpeningPosition{0, 0}; + TempVector parametersCommaPositions(scratchPosition); + Position parametersClosingPosition{0, 0}; if (lexer.current().type == '<') { hasParameters = true; - parameters = parseTypeParams(); + if (FFlag::LuauStoreCSTData && options.storeCstData) + parameters = parseTypeParams(¶metersOpeningPosition, ¶metersCommaPositions, ¶metersClosingPosition); + else + parameters = parseTypeParams(); } Location end = lexer.previousLocation(); - return { - allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), {} - }; + if (FFlag::LuauStoreCSTData) + { + AstTypeReference* node = + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc( + prefixPointPosition, parametersOpeningPosition, copy(parametersCommaPositions), parametersClosingPosition + ); + return {node, {}}; + } + else + { + return { + allocator.alloc(Location(start, end), prefix, name.name, prefixLocation, name.location, hasParameters, parameters), + {} + }; + } } else if (lexer.current().type == '{') { @@ -2175,7 +2629,8 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); return AstExprBinary::Or; } - else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit) + else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && + binaryPriority[AstExprBinary::CompareNe].left > limit) { nextLexeme(); report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); @@ -2228,11 +2683,14 @@ AstExpr* Parser::parseExpr(unsigned int limit) if (uop) { + Position opPosition = lexer.current().location.begin; nextLexeme(); AstExpr* subexpr = parseExpr(unaryPriority); expr = allocator.alloc(Location(start, subexpr->location), *uop, subexpr); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(opPosition); } else { @@ -2247,12 +2705,15 @@ AstExpr* Parser::parseExpr(unsigned int limit) while (op && binaryPriority[*op].left > limit) { + Position opPosition = lexer.current().location.begin; nextLexeme(); // read sub-expression with higher priority AstExpr* next = parseExpr(binaryPriority[*op].right); expr = allocator.alloc(Location(start, next->location), *op, expr, next); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(opPosition); op = parseBinaryOp(lexer.current()); if (!op) @@ -2281,6 +2742,9 @@ AstExpr* Parser::parseNameExpr(const char* context) { AstLocal* local = *value; + if (local->functionDepth < typeFunctionDepth) + return reportExprError(lexer.current().location, {}, "Type function cannot reference outer local '%s'", local->name.value); + return allocator.alloc(name->location, local, local->functionDepth != functionStack.size() - 1); } @@ -2349,11 +2813,14 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* index = parseExpr(); + Position closeBracketPosition = lexer.current().location.begin; Position end = lexer.current().location.end; expectMatchAndConsume(']', matchBracket); expr = allocator.alloc(Location(start, end), expr, index); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[expr] = allocator.alloc(matchBracket.position, closeBracketPosition); } else if (lexer.current().type == ':') { @@ -2402,9 +2869,18 @@ AstExpr* Parser::parseAssertionExpr() if (lexer.current().type == Lexeme::DoubleColon) { + CstExprTypeAssertion* cstNode = nullptr; + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + Position opPosition = lexer.current().location.begin; + cstNode = allocator.alloc(opPosition); + } nextLexeme(); AstType* annotation = parseType(); - return allocator.alloc(Location(start, annotation->location), expr, annotation); + AstExprTypeAssertion* node = allocator.alloc(Location(start, annotation->location), expr, annotation); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstNodeMap[node] = cstNode; + return node; } else return expr; @@ -2479,7 +2955,7 @@ AstExpr* Parser::parseSimpleExpr() AstArray attributes{nullptr, 0}; - if (FFlag::LuauAttributeSyntaxFunExpr && lexer.current().type == Lexeme::Attribute) + if (lexer.current().type == Lexeme::Attribute) { attributes = parseAttributes(); @@ -2520,7 +2996,8 @@ AstExpr* Parser::parseSimpleExpr() { return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || + lexer.current().type == Lexeme::InterpStringSimple) { return parseString(); } @@ -2580,16 +3057,27 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) nextLexeme(); TempVector args(scratchExpr); + TempVector commaPositions(scratchPosition); if (lexer.current().type != ')') - parseExprList(args); + parseExprList(args, (FFlag::LuauStoreCSTData && options.storeCstData) ? &commaPositions : nullptr); Location end = lexer.current().location; Position argEnd = end.end; expectMatchAndConsume(')', matchParen); - return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(matchParen.position, lexer.previousLocation().begin, copy(commaPositions)); + return node; + } + else + { + return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + } } else if (lexer.current().type == '{') { @@ -2597,14 +3085,35 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self) AstExpr* expr = parseTableConstructor(); Position argEnd = lexer.previousLocation().end; - return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = + allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); + return node; + } + else + { + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + } } else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { Location argLocation = lexer.current().location; AstExpr* expr = parseString(); - return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + if (FFlag::LuauStoreCSTData) + { + AstExprCall* node = allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt, std::nullopt, AstArray{nullptr, 0}); + return node; + } + else + { + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + } } else { @@ -2638,6 +3147,17 @@ LUAU_NOINLINE void Parser::reportAmbiguousCallError() ); } +std::optional Parser::tableSeparator() +{ + LUAU_ASSERT(FFlag::LuauStoreCSTData); + if (lexer.current().type == ',') + return CstExprTable::Comma; + else if (lexer.current().type == ';') + return CstExprTable::Semicolon; + else + return std::nullopt; +} + // tableconstructor ::= `{' [fieldlist] `}' // fieldlist ::= field {fieldsep field} [fieldsep] // field ::= `[' exp `]' `=' exp | Name `=' exp | exp @@ -2645,6 +3165,7 @@ LUAU_NOINLINE void Parser::reportAmbiguousCallError() AstExpr* Parser::parseTableConstructor() { TempVector items(scratchItem); + TempVector cstItems(scratchCstItem); Location start = lexer.current().location; @@ -2658,23 +3179,29 @@ AstExpr* Parser::parseTableConstructor() if (lexer.current().type == '[') { + Position indexerOpenPosition = lexer.current().location.begin; MatchLexeme matchLocationBracket = lexer.current(); nextLexeme(); AstExpr* key = parseExpr(); + Position indexerClosePosition = lexer.current().location.begin; expectMatchAndConsume(']', matchLocationBracket); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "table field"); AstExpr* value = parseExpr(); items.push_back({AstExprTable::Item::General, key, value}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({indexerOpenPosition, indexerClosePosition, equalsPosition, tableSeparator(), lexer.current().location.begin}); } else if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == '=') { Name name = parseName("table field"); + Position equalsPosition = lexer.current().location.begin; expectAndConsume('=', "table field"); AstArray nameString; @@ -2688,12 +3215,16 @@ AstExpr* Parser::parseTableConstructor() func->debugname = name.name; items.push_back({AstExprTable::Item::Record, key, value}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({std::nullopt, std::nullopt, equalsPosition, tableSeparator(), lexer.current().location.begin}); } else { AstExpr* expr = parseExpr(); items.push_back({AstExprTable::Item::List, nullptr, expr}); + if (FFlag::LuauStoreCSTData && options.storeCstData) + cstItems.push_back({std::nullopt, std::nullopt, std::nullopt, tableSeparator(), lexer.current().location.begin}); } if (lexer.current().type == ',' || lexer.current().type == ';') @@ -2715,7 +3246,17 @@ AstExpr* Parser::parseTableConstructor() if (!expectMatchAndConsume('}', matchBrace)) end = lexer.previousLocation(); - return allocator.alloc(Location(start, end), copy(items)); + if (FFlag::LuauStoreCSTData) + { + AstExprTable* node = allocator.alloc(Location(start, end), copy(items)); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(cstItems)); + return node; + } + else + { + return allocator.alloc(Location(start, end), copy(items)); + } } AstExpr* Parser::parseIfElseExpr() @@ -2727,11 +3268,14 @@ AstExpr* Parser::parseIfElseExpr() AstExpr* condition = parseExpr(); + Position thenPosition = lexer.current().location.begin; bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if then else expression"); AstExpr* trueExpr = parseExpr(); AstExpr* falseExpr = nullptr; + Position elsePosition = lexer.current().location.begin; + bool isElseIf = false; if (lexer.current().type == Lexeme::ReservedElseif) { unsigned int oldRecursionCount = recursionCounter; @@ -2739,6 +3283,8 @@ AstExpr* Parser::parseIfElseExpr() hasElse = true; falseExpr = parseIfElseExpr(); recursionCounter = oldRecursionCount; + if (FFlag::LuauStoreCSTData) + isElseIf = true; } else { @@ -2748,7 +3294,17 @@ AstExpr* Parser::parseIfElseExpr() Location end = falseExpr->location; - return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + if (FFlag::LuauStoreCSTData) + { + AstExprIfElse* node = allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(thenPosition, elsePosition, isElseIf); + return node; + } + else + { + return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); + } } // Name @@ -2801,14 +3357,21 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) +std::pair, AstArray> Parser::parseGenericTypeList( + bool withDefaultValues, + Position* openPosition, + TempVector* commaPositions, + Position* closePosition +) { - TempVector names{scratchGenericTypes}; - TempVector namePacks{scratchGenericTypePacks}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); + if (FFlag::LuauStoreCSTData && openPosition) + *openPosition = begin.location.begin; nextLexeme(); bool seenPack = false; @@ -2822,6 +3385,7 @@ std::pair, AstArray> Parser::parseG { seenPack = true; + Position ellipsisPosition = lexer.current().location.begin; if (lexer.current().type != Lexeme::Dot3) report(lexer.current().location, "Generic types come before generic type packs"); else @@ -2830,22 +3394,43 @@ std::pair, AstArray> Parser::parseG if (withDefaultValues && lexer.current().type == '=') { seenDefault = true; + Position equalsPosition = lexer.current().location.begin; nextLexeme(); if (shouldParseTypePack(lexer)) { AstTypePack* typePack = parseTypePack(); - namePacks.push_back({name, nameLocation, typePack}); + if (FFlag::LuauStoreCSTData) + { + AstGenericTypePack* node = allocator.alloc(nameLocation, name, typePack); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(ellipsisPosition, equalsPosition); + namePacks.push_back(node); + } + else + { + namePacks.push_back(allocator.alloc(nameLocation, name, typePack)); + } } else { - auto [type, typePack] = parseTypeOrPack(); + auto [type, typePack] = parseSimpleTypeOrPack(); if (type) report(type->location, "Expected type pack after '=', got type"); - namePacks.push_back({name, nameLocation, typePack}); + if (FFlag::LuauStoreCSTData) + { + AstGenericTypePack* node = allocator.alloc(nameLocation, name, typePack); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(ellipsisPosition, equalsPosition); + namePacks.push_back(node); + } + else + { + namePacks.push_back(allocator.alloc(nameLocation, name, typePack)); + } } } else @@ -2853,7 +3438,17 @@ std::pair, AstArray> Parser::parseG if (seenDefault) report(lexer.current().location, "Expected default type pack after type pack name"); - namePacks.push_back({name, nameLocation, nullptr}); + if (FFlag::LuauStoreCSTData) + { + AstGenericTypePack* node = allocator.alloc(nameLocation, name, nullptr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(ellipsisPosition, std::nullopt); + namePacks.push_back(node); + } + else + { + namePacks.push_back(allocator.alloc(nameLocation, name, nullptr)); + } } } else @@ -2861,23 +3456,46 @@ std::pair, AstArray> Parser::parseG if (withDefaultValues && lexer.current().type == '=') { seenDefault = true; + Position equalsPosition = lexer.current().location.begin; nextLexeme(); AstType* defaultType = parseType(); - names.push_back({name, nameLocation, defaultType}); + if (FFlag::LuauStoreCSTData) + { + AstGenericType* node = allocator.alloc(nameLocation, name, defaultType); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(equalsPosition); + names.push_back(node); + } + else + { + names.push_back(allocator.alloc(nameLocation, name, defaultType)); + } } else { if (seenDefault) report(lexer.current().location, "Expected default type after type name"); - names.push_back({name, nameLocation, nullptr}); + if (FFlag::LuauStoreCSTData) + { + AstGenericType* node = allocator.alloc(nameLocation, name, nullptr); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(std::nullopt); + names.push_back(node); + } + else + { + names.push_back(allocator.alloc(nameLocation, name, nullptr)); + } } } if (lexer.current().type == ',') { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); if (lexer.current().type == '>') @@ -2890,21 +3508,25 @@ std::pair, AstArray> Parser::parseG break; } + if (FFlag::LuauStoreCSTData && closePosition) + *closePosition = lexer.current().location.begin; expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams(Position* openingPosition, TempVector* commaPositions, Position* closingPosition) { TempVector parameters{scratchTypeOrPack}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); + if (FFlag::LuauStoreCSTData && openingPosition) + *openingPosition = begin.location.begin; nextLexeme(); while (true) @@ -2912,17 +3534,83 @@ AstArray Parser::parseTypeParams() if (shouldParseTypePack(lexer)) { AstTypePack* typePack = parseTypePack(); - parameters.push_back({{}, typePack}); } else if (lexer.current().type == '(') { - auto [type, typePack] = parseTypeOrPack(); + if (FFlag::LuauAllowComplexTypesInGenericParams) + { + Location begin = lexer.current().location; + AstType* type = nullptr; + AstTypePack* typePack = nullptr; + Lexeme::Type c = lexer.current().type; - if (typePack) - parameters.push_back({{}, typePack}); + if (c != '|' && c != '&') + { + auto typeOrTypePack = parseSimpleType(/* allowPack */ true, /* inDeclarationContext */ false); + type = typeOrTypePack.type; + typePack = typeOrTypePack.typePack; + } + + // Consider the following type: + // + // X<(T)> + // + // Is this a type pack or a parenthesized type? The + // assumption will be a type pack, as that's what allows one + // to express either a singular type pack or a potential + // complex type. + + if (typePack) + { + auto explicitTypePack = typePack->as(); + if (explicitTypePack && explicitTypePack->typeList.tailType == nullptr && explicitTypePack->typeList.types.size == 1 && + isTypeFollow(lexer.current().type)) + { + // If we parsed an explicit type pack with a single + // type in it (something of the form `(T)`), and + // the next lexeme is one that follows a type + // (&, |, ?), then assume that this was actually a + // parenthesized type. + if (FFlag::LuauAstTypeGroup2) + { + auto parenthesizedType = explicitTypePack->typeList.types.data[0]; + parameters.push_back( + {parseTypeSuffix(allocator.alloc(parenthesizedType->location, parenthesizedType), begin), {}} + ); + } + else + parameters.push_back({parseTypeSuffix(explicitTypePack->typeList.types.data[0], begin), {}}); + } + else + { + // Otherwise, it's a type pack. + parameters.push_back({{}, typePack}); + } + } + else + { + // There's two cases in which `typePack` will be null: + // - We try to parse a simple type or a type pack, and + // we get a simple type: there's no ambiguity and + // we attempt to parse a complex type. + // - The next lexeme was a `|` or `&` indicating a + // union or intersection type with a leading + // separator. We just fall right into + // `parseTypeSuffix`, which allows its first + // argument to be `nullptr` + parameters.push_back({parseTypeSuffix(type, begin), {}}); + } + } else - parameters.push_back({type, {}}); + { + auto [type, typePack] = parseSimpleTypeOrPack(); + + if (typePack) + parameters.push_back({{}, typePack}); + else + parameters.push_back({type, {}}); + } } else if (lexer.current().type == '>' && parameters.empty()) { @@ -2934,18 +3622,24 @@ AstArray Parser::parseTypeParams() } if (lexer.current().type == ',') + { + if (FFlag::LuauStoreCSTData && commaPositions) + commaPositions->push_back(lexer.current().location.begin); nextLexeme(); + } else break; } + if (FFlag::LuauStoreCSTData && closingPosition) + *closingPosition = lexer.current().location.begin; expectMatchAndConsume('>', begin); } return copy(parameters); } -std::optional> Parser::parseCharArray() +std::optional> Parser::parseCharArray(AstArray* originalString) { LUAU_ASSERT( lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || @@ -2953,6 +3647,11 @@ std::optional> Parser::parseCharArray() ); scratchData.assign(lexer.current().data, lexer.current().getLength()); + if (FFlag::LuauStoreCSTData) + { + if (originalString) + *originalString = copy(scratchData); + } if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { @@ -2975,15 +3674,53 @@ std::optional> Parser::parseCharArray() AstExpr* Parser::parseString() { Location location = lexer.current().location; - if (std::optional> value = parseCharArray()) - return allocator.alloc(location, *value); + + AstExprConstantString::QuoteStyle style; + switch (lexer.current().type) + { + case Lexeme::QuotedString: + case Lexeme::InterpStringSimple: + style = AstExprConstantString::QuotedSimple; + break; + case Lexeme::RawString: + style = AstExprConstantString::QuotedRaw; + break; + default: + LUAU_ASSERT(false && "Invalid string type"); + } + + if (FFlag::LuauStoreCSTData) + { + CstExprConstantString::QuoteStyle fullStyle; + unsigned int blockDepth; + if (options.storeCstData) + std::tie(fullStyle, blockDepth) = extractStringDetails(); + + AstArray originalString; + if (std::optional> value = parseCharArray(options.storeCstData ? &originalString : nullptr)) + { + AstExprConstantString* node = allocator.alloc(location, *value, style); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(originalString, fullStyle, blockDepth); + return node; + } + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } else - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + { + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value, style); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } } AstExpr* Parser::parseInterpString() { TempVector> strings(scratchString); + TempVector> sourceStrings(scratchString2); + TempVector stringPositions(scratchPosition); TempVector expressions(scratchExpr); Location startLocation = lexer.current().location; @@ -3001,6 +3738,12 @@ AstExpr* Parser::parseInterpString() scratchData.assign(currentLexeme.data, currentLexeme.getLength()); + if (FFlag::LuauStoreCSTData && options.storeCstData) + { + sourceStrings.push_back(copy(scratchData)); + stringPositions.push_back(currentLexeme.location.begin); + } + if (!Lexer::fixupQuotedString(scratchData)) { nextLexeme(); @@ -3065,7 +3808,15 @@ AstExpr* Parser::parseInterpString() AstArray> stringsArray = copy(strings); AstArray expressionsArray = copy(expressions); - return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); + if (FFlag::LuauStoreCSTData) + { + AstExprInterpString* node = allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(copy(sourceStrings), copy(stringPositions)); + return node; + } + else + return allocator.alloc(Location{startLocation, endLocation}, stringsArray, expressionsArray); } AstExpr* Parser::parseNumber() @@ -3073,6 +3824,9 @@ AstExpr* Parser::parseNumber() Location start = lexer.current().location; scratchData.assign(lexer.current().data, lexer.current().getLength()); + AstArray sourceData; + if (FFlag::LuauStoreCSTData && options.storeCstData) + sourceData = copy(scratchData); // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al if (scratchData.find('_') != std::string::npos) @@ -3087,7 +3841,17 @@ AstExpr* Parser::parseNumber() if (result == ConstantNumberParseResult::Malformed) return reportExprError(start, {}, "Malformed number"); - return allocator.alloc(start, value, result); + if (FFlag::LuauStoreCSTData) + { + AstExprConstantNumber* node = allocator.alloc(start, value, result); + if (options.storeCstData) + cstNodeMap[node] = allocator.alloc(sourceData); + return node; + } + else + { + return allocator.alloc(start, value, result); + } } AstLocal* Parser::pushLocal(const Binding& binding) @@ -3364,7 +4128,7 @@ void Parser::report(const Location& location, const char* format, va_list args) parseErrors.emplace_back(location, message); - if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit)) + if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit) && (!FFlag::ParserNoErrorLimit || !options.noErrorLimit)) ParseError::raise(location, "Reached error limit (%d)", int(FInt::LuauParseErrorLimit)); } diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e8be59eb..24bc707f 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -3,6 +3,7 @@ #include "Luau/StringUtils.h" +#include #include #include @@ -25,7 +26,7 @@ #include -LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) +LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing) namespace Luau { namespace TimeTrace diff --git a/CLI/Require.cpp b/CLI/Require.cpp deleted file mode 100644 index b6753e96..00000000 --- a/CLI/Require.cpp +++ /dev/null @@ -1,306 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Require.h" - -#include "FileUtils.h" -#include "Luau/Common.h" - -#include -#include -#include - -RequireResolver::RequireResolver(lua_State* L, std::string path) - : pathToResolve(std::move(path)) - , L(L) -{ - lua_Debug ar; - lua_getinfo(L, 1, "s", &ar); - sourceChunkname = ar.source; - - if (!isRequireAllowed(sourceChunkname)) - luaL_errorL(L, "require is not supported in this context"); - - if (isAbsolutePath(pathToResolve)) - luaL_argerrorL(L, 1, "cannot require an absolute path"); - - std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/'); - - substituteAliasIfPresent(pathToResolve); -} - -[[nodiscard]] RequireResolver::ResolvedRequire RequireResolver::resolveRequire(lua_State* L, std::string path) -{ - RequireResolver resolver(L, std::move(path)); - ModuleStatus status = resolver.findModule(); - if (status != ModuleStatus::FileRead) - return ResolvedRequire{status}; - else - return ResolvedRequire{status, std::move(resolver.chunkname), std::move(resolver.absolutePath), std::move(resolver.sourceCode)}; -} - -RequireResolver::ModuleStatus RequireResolver::findModule() -{ - resolveAndStoreDefaultPaths(); - - // Put _MODULES table on stack for checking and saving to the cache - luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); - - RequireResolver::ModuleStatus moduleStatus = findModuleImpl(); - - if (moduleStatus != RequireResolver::ModuleStatus::NotFound) - return moduleStatus; - - if (!shouldSearchPathsArray()) - return moduleStatus; - - if (!isConfigFullyResolved) - parseNextConfig(); - - // Index-based iteration because std::iterator may be invalidated if config.paths is reallocated - for (size_t i = 0; i < config.paths.size(); ++i) - { - // "placeholder" acts as a requiring file in the relevant directory - std::optional absolutePathOpt = resolvePath(pathToResolve, joinPaths(config.paths[i], "placeholder")); - - if (!absolutePathOpt) - luaL_errorL(L, "error requiring module"); - - chunkname = *absolutePathOpt; - absolutePath = *absolutePathOpt; - - moduleStatus = findModuleImpl(); - - if (moduleStatus != RequireResolver::ModuleStatus::NotFound) - return moduleStatus; - - // Before finishing the loop, parse more config files if there are any - if (i == config.paths.size() - 1 && !isConfigFullyResolved) - parseNextConfig(); // could reallocate config.paths when paths are parsed and added - } - - return RequireResolver::ModuleStatus::NotFound; -} - -RequireResolver::ModuleStatus RequireResolver::findModuleImpl() -{ - static const std::array possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"}; - - size_t unsuffixedAbsolutePathSize = absolutePath.size(); - - for (const char* possibleSuffix : possibleSuffixes) - { - absolutePath += possibleSuffix; - - // Check cache for module - lua_getfield(L, -1, absolutePath.c_str()); - if (!lua_isnil(L, -1)) - { - return ModuleStatus::Cached; - } - lua_pop(L, 1); - - // Try to read the matching file - std::optional source = readFile(absolutePath); - if (source) - { - chunkname = "=" + chunkname + possibleSuffix; - sourceCode = *source; - return ModuleStatus::FileRead; - } - - absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix - } - - return ModuleStatus::NotFound; -} - -bool RequireResolver::isRequireAllowed(std::string_view sourceChunkname) -{ - LUAU_ASSERT(!sourceChunkname.empty()); - return (sourceChunkname[0] == '=' || sourceChunkname[0] == '@'); -} - -bool RequireResolver::shouldSearchPathsArray() -{ - return !isAbsolutePath(pathToResolve) && !isExplicitlyRelative(pathToResolve); -} - -void RequireResolver::resolveAndStoreDefaultPaths() -{ - if (!isAbsolutePath(pathToResolve)) - { - std::string chunknameContext = getRequiringContextRelative(); - std::optional absolutePathContext = getRequiringContextAbsolute(); - - if (!absolutePathContext) - luaL_errorL(L, "error requiring module"); - - // resolvePath automatically sanitizes/normalizes the paths - std::optional chunknameOpt = resolvePath(pathToResolve, chunknameContext); - std::optional absolutePathOpt = resolvePath(pathToResolve, *absolutePathContext); - - if (!chunknameOpt || !absolutePathOpt) - luaL_errorL(L, "error requiring module"); - - chunkname = std::move(*chunknameOpt); - absolutePath = std::move(*absolutePathOpt); - } - else - { - // Here we must explicitly sanitize, as the path is taken as is - std::optional sanitizedPath = normalizePath(pathToResolve); - if (!sanitizedPath) - luaL_errorL(L, "error requiring module"); - - chunkname = *sanitizedPath; - absolutePath = std::move(*sanitizedPath); - } -} - -std::optional RequireResolver::getRequiringContextAbsolute() -{ - std::string requiringFile; - if (isAbsolutePath(sourceChunkname.substr(1))) - { - // We already have an absolute path for the requiring file - requiringFile = sourceChunkname.substr(1); - } - else - { - // Requiring file's stored path is relative to the CWD, must make absolute - std::optional cwd = getCurrentWorkingDirectory(); - if (!cwd) - return std::nullopt; - - if (sourceChunkname.substr(1) == "stdin") - { - // Require statement is being executed from REPL input prompt - // The requiring context is the pseudo-file "stdin" in the CWD - requiringFile = joinPaths(*cwd, "stdin"); - } - else - { - // Require statement is being executed in a file, must resolve relative to CWD - std::optional requiringFileOpt = resolvePath(sourceChunkname.substr(1), joinPaths(*cwd, "stdin")); - if (!requiringFileOpt) - return std::nullopt; - - requiringFile = *requiringFileOpt; - } - } - std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/'); - return requiringFile; -} - -std::string RequireResolver::getRequiringContextRelative() -{ - std::string baseFilePath; - if (sourceChunkname.substr(1) != "stdin") - baseFilePath = sourceChunkname.substr(1); - - return baseFilePath; -} - -void RequireResolver::substituteAliasIfPresent(std::string& path) -{ - if (path.size() < 1 || path[0] != '@') - return; - - // To ignore the '@' alias prefix when processing the alias - const size_t aliasStartPos = 1; - - // If a directory separator was found, the length of the alias is the - // distance between the start of the alias and the separator. Otherwise, - // the whole string after the alias symbol is the alias. - size_t aliasLen = path.find_first_of("\\/"); - if (aliasLen != std::string::npos) - aliasLen -= aliasStartPos; - - const std::string potentialAlias = path.substr(aliasStartPos, aliasLen); - - // Not worth searching when potentialAlias cannot be an alias - if (!Luau::isValidAlias(potentialAlias)) - luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str()); - - std::optional alias = getAlias(potentialAlias); - if (alias) - { - path = *alias + path.substr(potentialAlias.size() + 1); - } - else - { - luaL_errorL(L, "@%s is not a valid alias", potentialAlias.c_str()); - } -} - -std::optional RequireResolver::getAlias(std::string alias) -{ - std::transform( - alias.begin(), - alias.end(), - alias.begin(), - [](unsigned char c) - { - return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; - } - ); - while (!config.aliases.count(alias) && !isConfigFullyResolved) - { - parseNextConfig(); - } - if (!config.aliases.count(alias) && isConfigFullyResolved) - return std::nullopt; // could not find alias - - return resolvePath(config.aliases[alias], joinPaths(lastSearchedDir, Luau::kConfigName)); -} - -void RequireResolver::parseNextConfig() -{ - if (isConfigFullyResolved) - return; // no config files left to parse - - std::optional directory; - if (lastSearchedDir.empty()) - { - std::optional requiringFile = getRequiringContextAbsolute(); - if (!requiringFile) - luaL_errorL(L, "error requiring module"); - - directory = getParentPath(*requiringFile); - } - else - directory = getParentPath(lastSearchedDir); - - if (directory) - { - lastSearchedDir = *directory; - parseConfigInDirectory(*directory); - } - else - isConfigFullyResolved = true; -} - -void RequireResolver::parseConfigInDirectory(const std::string& directory) -{ - std::string configPath = joinPaths(directory, Luau::kConfigName); - - size_t numPaths = config.paths.size(); - - if (std::optional contents = readFile(configPath)) - { - std::optional error = Luau::parseConfig(*contents, config); - if (error) - luaL_errorL(L, "error parsing %s (%s)", configPath.c_str(), (*error).c_str()); - } - - // Resolve any newly obtained relative paths in "paths" in relation to configPath - for (auto it = config.paths.begin() + numPaths; it != config.paths.end(); ++it) - { - if (!isAbsolutePath(*it)) - { - if (std::optional resolvedPath = resolvePath(*it, configPath)) - *it = std::move(*resolvedPath); - else - luaL_errorL(L, "error requiring module"); - } - } -} diff --git a/CLI/Require.h b/CLI/Require.h deleted file mode 100644 index ae96834f..00000000 --- a/CLI/Require.h +++ /dev/null @@ -1,62 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#pragma once - -#include "lua.h" -#include "lualib.h" - -#include "Luau/Config.h" - -#include -#include - -class RequireResolver -{ -public: - std::string chunkname; - std::string absolutePath; - std::string sourceCode; - - enum class ModuleStatus - { - Cached, - FileRead, - NotFound - }; - - struct ResolvedRequire - { - ModuleStatus status; - std::string chunkName; - std::string absolutePath; - std::string sourceCode; - }; - - [[nodiscard]] ResolvedRequire static resolveRequire(lua_State* L, std::string path); - -private: - std::string pathToResolve; - std::string_view sourceChunkname; - - RequireResolver(lua_State* L, std::string path); - - ModuleStatus findModule(); - lua_State* L; - Luau::Config config; - std::string lastSearchedDir; - bool isConfigFullyResolved = false; - - bool isRequireAllowed(std::string_view sourceChunkname); - bool shouldSearchPathsArray(); - - void resolveAndStoreDefaultPaths(); - ModuleStatus findModuleImpl(); - - std::optional getRequiringContextAbsolute(); - std::string getRequiringContextRelative(); - - void substituteAliasIfPresent(std::string& path); - std::optional getAlias(std::string alias); - - void parseNextConfig(); - void parseConfigInDirectory(const std::string& path); -}; diff --git a/CLI/Coverage.h b/CLI/include/Luau/Coverage.h similarity index 100% rename from CLI/Coverage.h rename to CLI/include/Luau/Coverage.h diff --git a/CLI/FileUtils.h b/CLI/include/Luau/FileUtils.h similarity index 68% rename from CLI/FileUtils.h rename to CLI/include/Luau/FileUtils.h index 2004a2eb..80e36378 100644 --- a/CLI/FileUtils.h +++ b/CLI/include/Luau/FileUtils.h @@ -10,18 +10,20 @@ std::optional getCurrentWorkingDirectory(); std::string normalizePath(std::string_view path); -std::string resolvePath(std::string_view relativePath, std::string_view baseFilePath); +std::optional resolvePath(std::string_view relativePath, std::string_view baseFilePath); std::optional readFile(const std::string& name); std::optional readStdin(); +bool hasFileExtension(std::string_view name, const std::vector& extensions); + bool isAbsolutePath(std::string_view path); -bool isExplicitlyRelative(std::string_view path); +bool isFile(const std::string& path); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); std::vector splitPath(std::string_view path); -std::string joinPaths(const std::string& lhs, const std::string& rhs); -std::optional getParentPath(const std::string& path); +std::string joinPaths(std::string_view lhs, std::string_view rhs); +std::optional getParentPath(std::string_view path); std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Flags.h b/CLI/include/Luau/Flags.h similarity index 100% rename from CLI/Flags.h rename to CLI/include/Luau/Flags.h diff --git a/CLI/Profiler.h b/CLI/include/Luau/Profiler.h similarity index 100% rename from CLI/Profiler.h rename to CLI/include/Luau/Profiler.h diff --git a/CLI/Repl.h b/CLI/include/Luau/Repl.h similarity index 100% rename from CLI/Repl.h rename to CLI/include/Luau/Repl.h diff --git a/CLI/include/Luau/Require.h b/CLI/include/Luau/Require.h new file mode 100644 index 00000000..e4fc019a --- /dev/null +++ b/CLI/include/Luau/Require.h @@ -0,0 +1,84 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Config.h" + +#include +#include +#include + +class RequireResolver +{ +public: + enum class ModuleStatus + { + Cached, + FileRead, + ErrorReported + }; + + struct ResolvedRequire + { + ModuleStatus status; + std::string identifier; + std::string absolutePath; + std::string sourceCode; + }; + + struct RequireContext + { + virtual ~RequireContext() = default; + virtual std::string getPath() = 0; + virtual bool isRequireAllowed() = 0; + virtual bool isStdin() = 0; + virtual std::string createNewIdentifer(const std::string& path) = 0; + }; + + struct CacheManager + { + virtual ~CacheManager() = default; + virtual bool isCached(const std::string& path) + { + return false; + } + }; + + struct ErrorHandler + { + virtual ~ErrorHandler() = default; + virtual void reportError(const std::string message) {} + }; + + RequireResolver(std::string pathToResolve, RequireContext& requireContext, CacheManager& cacheManager, ErrorHandler& errorHandler); + + [[nodiscard]] ResolvedRequire resolveRequire(std::function completionCallback = nullptr); + +private: + std::string pathToResolve; + + RequireContext& requireContext; + CacheManager& cacheManager; + ErrorHandler& errorHandler; + + ResolvedRequire resolvedRequire; + bool isRequireResolved = false; + + Luau::Config config; + std::string lastSearchedDir; + bool isConfigFullyResolved = false; + + [[nodiscard]] bool initialize(); + + ModuleStatus findModule(); + ModuleStatus findModuleImpl(); + + [[nodiscard]] bool resolveAndStoreDefaultPaths(); + std::optional getRequiringContextAbsolute(); + std::string getRequiringContextRelative(); + + [[nodiscard]] bool substituteAliasIfPresent(std::string& path); + std::optional getAlias(std::string alias); + + [[nodiscard]] bool parseNextConfig(); + [[nodiscard]] bool parseConfigInDirectory(const std::string& directory); +}; \ No newline at end of file diff --git a/CLI/Analyze.cpp b/CLI/src/Analyze.cpp similarity index 88% rename from CLI/Analyze.cpp rename to CLI/src/Analyze.cpp index be1f23f0..e10a2c2e 100644 --- a/CLI/Analyze.cpp +++ b/CLI/src/Analyze.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Config.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" @@ -6,8 +7,9 @@ #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" -#include "FileUtils.h" -#include "Flags.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" +#include "Luau/Require.h" #include #include @@ -169,14 +171,17 @@ struct CliFileResolver : Luau::FileResolver { if (Luau::AstExprConstantString* expr = node->as()) { - Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; - if (!readFile(name)) - { - // fall back to .lua if a module with .luau doesn't exist - name = std::string(expr->value.data, expr->value.size) + ".lua"; - } + std::string path{expr->value.data, expr->value.size}; - return {{name}}; + AnalysisRequireContext requireContext{context->name}; + AnalysisCacheManager cacheManager; + AnalysisErrorHandler errorHandler; + + RequireResolver resolver(path, requireContext, cacheManager, errorHandler); + RequireResolver::ResolvedRequire resolvedRequire = resolver.resolveRequire(); + + if (resolvedRequire.status == RequireResolver::ModuleStatus::FileRead) + return {{resolvedRequire.identifier}}; } return std::nullopt; @@ -188,6 +193,48 @@ struct CliFileResolver : Luau::FileResolver return "stdin"; return name; } + +private: + struct AnalysisRequireContext : RequireResolver::RequireContext + { + explicit AnalysisRequireContext(std::string path) + : path(std::move(path)) + { + } + + std::string getPath() override + { + return path; + } + + bool isRequireAllowed() override + { + return true; + } + + bool isStdin() override + { + return path == "-"; + } + + std::string createNewIdentifer(const std::string& path) override + { + return path; + } + + private: + std::string path; + }; + + struct AnalysisCacheManager : public RequireResolver::CacheManager + { + AnalysisCacheManager() = default; + }; + + struct AnalysisErrorHandler : RequireResolver::ErrorHandler + { + AnalysisErrorHandler() = default; + }; }; struct CliConfigResolver : Luau::ConfigResolver @@ -224,7 +271,14 @@ struct CliConfigResolver : Luau::ConfigResolver if (std::optional contents = readFile(configPath)) { - std::optional error = Luau::parseConfig(*contents, result); + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = configPath; + aliasOpts.overwriteAliases = true; + + Luau::ConfigOptions opts; + opts.aliasOptions = std::move(aliasOpts); + + std::optional error = Luau::parseConfig(*contents, result, opts); if (error) configErrors.push_back({configPath, *error}); } diff --git a/CLI/Ast.cpp b/CLI/src/Ast.cpp similarity index 98% rename from CLI/Ast.cpp rename to CLI/src/Ast.cpp index b5a922aa..5341d889 100644 --- a/CLI/Ast.cpp +++ b/CLI/src/Ast.cpp @@ -8,7 +8,7 @@ #include "Luau/ParseOptions.h" #include "Luau/ToString.h" -#include "FileUtils.h" +#include "Luau/FileUtils.h" static void displayHelp(const char* argv0) { diff --git a/CLI/Bytecode.cpp b/CLI/src/Bytecode.cpp similarity index 99% rename from CLI/Bytecode.cpp rename to CLI/src/Bytecode.cpp index 2da9570b..dc8e4833 100644 --- a/CLI/Bytecode.cpp +++ b/CLI/src/Bytecode.cpp @@ -7,8 +7,8 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/Parser.h" #include "Luau/BytecodeSummary.h" -#include "FileUtils.h" -#include "Flags.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" #include diff --git a/CLI/Compile.cpp b/CLI/src/Compile.cpp similarity index 99% rename from CLI/Compile.cpp rename to CLI/src/Compile.cpp index 7d95387c..6f41b42d 100644 --- a/CLI/Compile.cpp +++ b/CLI/src/Compile.cpp @@ -8,8 +8,8 @@ #include "Luau/Parser.h" #include "Luau/TimeTrace.h" -#include "FileUtils.h" -#include "Flags.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" #include @@ -341,7 +341,8 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Remarks); bcb.setDumpSource(*source); } - else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || format == CompileFormat::CodegenVerbose) + else if (format == CompileFormat::Codegen || format == CompileFormat::CodegenAsm || format == CompileFormat::CodegenIr || + format == CompileFormat::CodegenVerbose) { bcb.setDumpFlags( Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals | diff --git a/CLI/Coverage.cpp b/CLI/src/Coverage.cpp similarity index 98% rename from CLI/Coverage.cpp rename to CLI/src/Coverage.cpp index a509ab89..7330d492 100644 --- a/CLI/Coverage.cpp +++ b/CLI/src/Coverage.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Coverage.h" +#include "Luau/Coverage.h" #include "lua.h" diff --git a/CLI/FileUtils.cpp b/CLI/src/FileUtils.cpp similarity index 70% rename from CLI/FileUtils.cpp rename to CLI/src/FileUtils.cpp index daa7c295..d54d94e0 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/src/FileUtils.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "FileUtils.h" +#include "Luau/FileUtils.h" #include "Luau/Common.h" @@ -20,6 +20,7 @@ #endif #include +#include #ifdef _WIN32 static std::wstring fromUtf8(const std::string& path) @@ -57,12 +58,6 @@ bool isAbsolutePath(std::string_view path) #endif } -bool isExplicitlyRelative(std::string_view path) -{ - return (path == ".") || (path == "..") || (path.size() >= 2 && path[0] == '.' && path[1] == '/') || - (path.size() >= 3 && path[0] == '.' && path[1] == '.' && path[2] == '/'); -} - std::optional getCurrentWorkingDirectory() { // 2^17 - derived from the Windows path length limit @@ -96,95 +91,86 @@ std::optional getCurrentWorkingDirectory() return std::nullopt; } -// Returns the normal/canonical form of a path (e.g. "../subfolder/../module.luau" -> "../module.luau") std::string normalizePath(std::string_view path) { - return resolvePath(path, ""); -} + const std::vector components = splitPath(path); + std::vector normalizedComponents; -// Takes a path that is relative to the file at baseFilePath and returns the path explicitly rebased onto baseFilePath. -// For absolute paths, baseFilePath will be ignored, and this function will resolve the path to a canonical path: -// (e.g. "/Users/.././Users/johndoe" -> "/Users/johndoe"). -std::string resolvePath(std::string_view path, std::string_view baseFilePath) -{ - std::vector pathComponents; - std::vector baseFilePathComponents; + const bool isAbsolute = isAbsolutePath(path); - // Dependent on whether the final resolved path is absolute or relative - // - if relative (when path and baseFilePath are both relative), resolvedPathPrefix remains empty - // - if absolute (if either path or baseFilePath are absolute), resolvedPathPrefix is "C:\", "/", etc. - std::string resolvedPathPrefix; - - if (isAbsolutePath(path)) - { - // path is absolute, we use path's prefix and ignore baseFilePath - size_t afterPrefix = path.find_first_of("\\/") + 1; - resolvedPathPrefix = path.substr(0, afterPrefix); - pathComponents = splitPath(path.substr(afterPrefix)); - } - else - { - pathComponents = splitPath(path); - if (isAbsolutePath(baseFilePath)) - { - // path is relative and baseFilePath is absolute, we use baseFilePath's prefix - size_t afterPrefix = baseFilePath.find_first_of("\\/") + 1; - resolvedPathPrefix = baseFilePath.substr(0, afterPrefix); - baseFilePathComponents = splitPath(baseFilePath.substr(afterPrefix)); - } - else - { - // path and baseFilePath are both relative, we do not set a prefix (resolved path will be relative) - baseFilePathComponents = splitPath(baseFilePath); - } - } - - // Remove filename from components - if (!baseFilePathComponents.empty()) - baseFilePathComponents.pop_back(); - - // Resolve the path by applying pathComponents to baseFilePathComponents - int numPrependedParents = 0; - for (std::string_view component : pathComponents) + // 1. Normalize path components + const size_t startIndex = isAbsolute ? 1 : 0; + for (size_t i = startIndex; i < components.size(); i++) { + std::string_view component = components[i]; if (component == "..") { - if (baseFilePathComponents.empty()) + if (normalizedComponents.empty()) { - if (resolvedPathPrefix.empty()) // only when final resolved path will be relative - numPrependedParents++; // "../" will later be added to the beginning of the resolved path + if (!isAbsolute) + { + normalizedComponents.emplace_back(".."); + } } - else if (baseFilePathComponents.back() != "..") + else if (normalizedComponents.back() == "..") { - baseFilePathComponents.pop_back(); // Resolve cases like "folder/subfolder/../../file" to "file" + normalizedComponents.emplace_back(".."); + } + else + { + normalizedComponents.pop_back(); } } - else if (component != "." && !component.empty()) + else if (!component.empty() && component != ".") { - baseFilePathComponents.push_back(component); + normalizedComponents.emplace_back(component); } } - // Join baseFilePathComponents to form the resolved path - std::string resolvedPath = resolvedPathPrefix; - // Only when resolvedPath will be relative - for (int i = 0; i < numPrependedParents; i++) - { - resolvedPath += "../"; - } - for (auto iter = baseFilePathComponents.begin(); iter != baseFilePathComponents.end(); ++iter) - { - if (iter != baseFilePathComponents.begin()) - resolvedPath += "/"; + std::string normalizedPath; - resolvedPath += *iter; - } - if (resolvedPath.size() > resolvedPathPrefix.size() && resolvedPath.back() == '/') + // 2. Add correct prefix to formatted path + if (isAbsolute) { - // Remove trailing '/' if present - resolvedPath.pop_back(); + normalizedPath += components[0]; + normalizedPath += "/"; } - return resolvedPath; + else if (normalizedComponents.empty() || normalizedComponents[0] != "..") + { + normalizedPath += "./"; + } + + // 3. Join path components to form the normalized path + for (auto iter = normalizedComponents.begin(); iter != normalizedComponents.end(); ++iter) + { + if (iter != normalizedComponents.begin()) + normalizedPath += "/"; + + normalizedPath += *iter; + } + if (normalizedPath.size() >= 2 && normalizedPath[normalizedPath.size() - 1] == '.' && normalizedPath[normalizedPath.size() - 2] == '.') + normalizedPath += "/"; + + return normalizedPath; +} + +std::optional resolvePath(std::string_view path, std::string_view baseFilePath) +{ + std::optional baseFilePathParent = getParentPath(baseFilePath); + if (!baseFilePathParent) + return std::nullopt; + + return normalizePath(joinPaths(*baseFilePathParent, path)); +} + +bool hasFileExtension(std::string_view name, const std::vector& extensions) +{ + for (const std::string& extension : extensions) + { + if (name.size() >= extension.size() && name.substr(name.size() - extension.size()) == extension) + return true; + } + return false; } std::optional readFile(const std::string& name) @@ -353,6 +339,20 @@ bool traverseDirectory(const std::string& path, const std::function splitPath(std::string_view path) return components; } -std::string joinPaths(const std::string& lhs, const std::string& rhs) +std::string joinPaths(std::string_view lhs, std::string_view rhs) { - std::string result = lhs; + std::string result = std::string(lhs); if (!result.empty() && result.back() != '/' && result.back() != '\\') result += '/'; result += rhs; return result; } -std::optional getParentPath(const std::string& path) +std::optional getParentPath(std::string_view path) { if (path == "" || path == "." || path == "/") return std::nullopt; @@ -410,7 +410,7 @@ std::optional getParentPath(const std::string& path) return "/"; if (slash != std::string::npos) - return path.substr(0, slash); + return std::string(path.substr(0, slash)); return ""; } @@ -440,10 +440,12 @@ std::vector getSourceFiles(int argc, char** argv) if (argv[i][0] == '-' && argv[i][1] != '\0') continue; - if (isDirectory(argv[i])) + std::string normalized = normalizePath(argv[i]); + + if (isDirectory(normalized)) { traverseDirectory( - argv[i], + normalized, [&](const std::string& name) { std::string ext = getExtension(name); @@ -455,7 +457,7 @@ std::vector getSourceFiles(int argc, char** argv) } else { - files.push_back(argv[i]); + files.push_back(normalized); } } diff --git a/CLI/Flags.cpp b/CLI/src/Flags.cpp similarity index 100% rename from CLI/Flags.cpp rename to CLI/src/Flags.cpp diff --git a/CLI/Profiler.cpp b/CLI/src/Profiler.cpp similarity index 100% rename from CLI/Profiler.cpp rename to CLI/src/Profiler.cpp diff --git a/CLI/Reduce.cpp b/CLI/src/Reduce.cpp similarity index 99% rename from CLI/Reduce.cpp rename to CLI/src/Reduce.cpp index 7f8c459c..e66d80dc 100644 --- a/CLI/Reduce.cpp +++ b/CLI/src/Reduce.cpp @@ -5,7 +5,7 @@ #include "Luau/Parser.h" #include "Luau/Transpiler.h" -#include "FileUtils.h" +#include "Luau/FileUtils.h" #include #include diff --git a/CLI/Repl.cpp b/CLI/src/Repl.cpp similarity index 88% rename from CLI/Repl.cpp rename to CLI/src/Repl.cpp index b8e9d814..3e3ae182 100644 --- a/CLI/Repl.cpp +++ b/CLI/src/Repl.cpp @@ -1,5 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Repl.h" +#include "Luau/Repl.h" #include "Luau/Common.h" #include "lua.h" @@ -10,15 +10,17 @@ #include "Luau/Parser.h" #include "Luau/TimeTrace.h" -#include "Coverage.h" -#include "FileUtils.h" -#include "Flags.h" -#include "Profiler.h" -#include "Require.h" +#include "Luau/Coverage.h" +#include "Luau/FileUtils.h" +#include "Luau/Flags.h" +#include "Luau/Profiler.h" +#include "Luau/Require.h" #include "isocline.h" #include +#include +#include #ifdef _WIN32 #include @@ -119,16 +121,113 @@ static int finishrequire(lua_State* L) return 1; } +struct RuntimeRequireContext : public RequireResolver::RequireContext +{ + // In the context of the REPL, source is the calling context's chunkname. + // + // These chunknames have certain prefixes that indicate context. These + // are used when displaying debug information (see luaO_chunkid). + // + // Generally, the '@' prefix is used for filepaths, and the '=' prefix is + // used for custom chunknames, such as =stdin. + explicit RuntimeRequireContext(std::string source) + : source(std::move(source)) + { + } + + std::string getPath() override + { + return source.substr(1); + } + + bool isRequireAllowed() override + { + return isStdin() || (!source.empty() && source[0] == '@'); + } + + bool isStdin() override + { + return source == "=stdin"; + } + + std::string createNewIdentifer(const std::string& path) override + { + return "@" + path; + } + +private: + std::string source; +}; + +struct RuntimeCacheManager : public RequireResolver::CacheManager +{ + explicit RuntimeCacheManager(lua_State* L) + : L(L) + { + } + + bool isCached(const std::string& path) override + { + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + lua_getfield(L, -1, path.c_str()); + bool cached = !lua_isnil(L, -1); + lua_pop(L, 2); + + if (cached) + cacheKey = path; + + return cached; + } + + std::string cacheKey; + +private: + lua_State* L; +}; + +struct RuntimeErrorHandler : RequireResolver::ErrorHandler +{ + explicit RuntimeErrorHandler(lua_State* L) + : L(L) + { + } + + void reportError(const std::string message) override + { + luaL_errorL(L, "%s", message.c_str()); + } + +private: + lua_State* L; +}; + static int lua_require(lua_State* L) { std::string name = luaL_checkstring(L, 1); - RequireResolver::ResolvedRequire resolvedRequire = RequireResolver::resolveRequire(L, std::move(name)); + RequireResolver::ResolvedRequire resolvedRequire; + { + lua_Debug ar; + lua_getinfo(L, 1, "s", &ar); + + RuntimeRequireContext requireContext{ar.source}; + RuntimeCacheManager cacheManager{L}; + RuntimeErrorHandler errorHandler{L}; + + RequireResolver resolver(std::move(name), requireContext, cacheManager, errorHandler); + + resolvedRequire = resolver.resolveRequire( + [L, &cacheKey = cacheManager.cacheKey](const RequireResolver::ModuleStatus status) + { + lua_getfield(L, LUA_REGISTRYINDEX, "_MODULES"); + if (status == RequireResolver::ModuleStatus::Cached) + lua_getfield(L, -1, cacheKey.c_str()); + } + ); + } if (resolvedRequire.status == RequireResolver::ModuleStatus::Cached) return finishrequire(L); - else if (resolvedRequire.status == RequireResolver::ModuleStatus::NotFound) - luaL_errorL(L, "error requiring module"); // module needs to run in a new thread, isolated from the rest // note: we create ML on main thread so that it doesn't inherit environment of L @@ -141,7 +240,7 @@ static int lua_require(lua_State* L) // now we can compile & run module on the new thread std::string bytecode = Luau::compile(resolvedRequire.sourceCode, copts()); - if (luau_load(ML, resolvedRequire.chunkName.c_str(), bytecode.data(), bytecode.size(), 0) == 0) + if (luau_load(ML, resolvedRequire.identifier.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { if (codegen) { @@ -613,7 +712,7 @@ static bool runFile(const char* name, lua_State* GL, bool repl) // new thread needs to have the globals sandboxed luaL_sandboxthread(L); - std::string chunkname = "=" + std::string(name); + std::string chunkname = "@" + std::string(name); std::string bytecode = Luau::compile(*source, copts()); int status = 0; @@ -692,8 +791,6 @@ int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; - setLuauFlagsDefault(); - #ifdef _WIN32 SetConsoleOutputCP(CP_UTF8); #endif diff --git a/CLI/ReplEntry.cpp b/CLI/src/ReplEntry.cpp similarity index 71% rename from CLI/ReplEntry.cpp rename to CLI/src/ReplEntry.cpp index 8543e3f7..a69e4a37 100644 --- a/CLI/ReplEntry.cpp +++ b/CLI/src/ReplEntry.cpp @@ -1,7 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Repl.h" +#include "Luau/Repl.h" +#include "Luau/Flags.h" int main(int argc, char** argv) { + setLuauFlagsDefault(); + return replMain(argc, argv); } diff --git a/CLI/src/Require.cpp b/CLI/src/Require.cpp new file mode 100644 index 00000000..1039f85c --- /dev/null +++ b/CLI/src/Require.cpp @@ -0,0 +1,313 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Require.h" + +#include "Luau/FileUtils.h" +#include "Luau/Common.h" +#include "Luau/Config.h" + +#include +#include +#include + +static constexpr char kRequireErrorGeneric[] = "error requiring module"; + +RequireResolver::RequireResolver(std::string path, RequireContext& requireContext, CacheManager& cacheManager, ErrorHandler& errorHandler) + : pathToResolve(std::move(path)) + , requireContext(requireContext) + , cacheManager(cacheManager) + , errorHandler(errorHandler) +{ +} + +RequireResolver::ResolvedRequire RequireResolver::resolveRequire(std::function completionCallback) +{ + if (isRequireResolved) + { + errorHandler.reportError("require statement has already been resolved"); + return ResolvedRequire{ModuleStatus::ErrorReported}; + } + + if (!initialize()) + return ResolvedRequire{ModuleStatus::ErrorReported}; + + resolvedRequire.status = findModule(); + + if (completionCallback) + completionCallback(resolvedRequire.status); + + isRequireResolved = true; + return resolvedRequire; +} + +static bool hasValidPrefix(std::string_view path) +{ + return path.compare(0, 2, "./") == 0 || path.compare(0, 3, "../") == 0 || path.compare(0, 1, "@") == 0; +} + +static bool isPathAmbiguous(const std::string& path) +{ + bool found = false; + for (const char* suffix : {".luau", ".lua"}) + { + if (isFile(path + suffix)) + { + if (found) + return true; + else + found = true; + } + } + if (isDirectory(path) && found) + return true; + + return false; +} + +bool RequireResolver::initialize() +{ + if (!requireContext.isRequireAllowed()) + { + errorHandler.reportError("require is not supported in this context"); + return false; + } + + if (isAbsolutePath(pathToResolve)) + { + errorHandler.reportError("cannot require an absolute path"); + return false; + } + + std::replace(pathToResolve.begin(), pathToResolve.end(), '\\', '/'); + + if (!hasValidPrefix(pathToResolve)) + { + errorHandler.reportError("require path must start with a valid prefix: ./, ../, or @"); + return false; + } + + return substituteAliasIfPresent(pathToResolve); +} + +RequireResolver::ModuleStatus RequireResolver::findModule() +{ + if (!resolveAndStoreDefaultPaths()) + return ModuleStatus::ErrorReported; + + if (isPathAmbiguous(resolvedRequire.absolutePath)) + { + errorHandler.reportError("require path could not be resolved to a unique file"); + return ModuleStatus::ErrorReported; + } + + static constexpr std::array possibleSuffixes = {".luau", ".lua", "/init.luau", "/init.lua"}; + size_t unsuffixedAbsolutePathSize = resolvedRequire.absolutePath.size(); + + for (const char* possibleSuffix : possibleSuffixes) + { + resolvedRequire.absolutePath += possibleSuffix; + + if (cacheManager.isCached(resolvedRequire.absolutePath)) + return ModuleStatus::Cached; + + // Try to read the matching file + if (std::optional source = readFile(resolvedRequire.absolutePath)) + { + resolvedRequire.identifier = requireContext.createNewIdentifer(resolvedRequire.identifier + possibleSuffix); + resolvedRequire.sourceCode = *source; + return ModuleStatus::FileRead; + } + + resolvedRequire.absolutePath.resize(unsuffixedAbsolutePathSize); // truncate to remove suffix + } + + if (hasFileExtension(resolvedRequire.absolutePath, {".luau", ".lua"}) && isFile(resolvedRequire.absolutePath)) + { + errorHandler.reportError("error requiring module: consider removing the file extension"); + return ModuleStatus::ErrorReported; + } + + errorHandler.reportError(kRequireErrorGeneric); + return ModuleStatus::ErrorReported; +} + +bool RequireResolver::resolveAndStoreDefaultPaths() +{ + if (!isAbsolutePath(pathToResolve)) + { + std::string identifierContext = getRequiringContextRelative(); + std::optional absolutePathContext = getRequiringContextAbsolute(); + + if (!absolutePathContext) + return false; + + // resolvePath automatically sanitizes/normalizes the paths + std::optional identifier = resolvePath(pathToResolve, identifierContext); + std::optional absolutePath = resolvePath(pathToResolve, *absolutePathContext); + + if (!identifier || !absolutePath) + { + errorHandler.reportError("could not resolve require path"); + return false; + } + + resolvedRequire.identifier = std::move(*identifier); + resolvedRequire.absolutePath = std::move(*absolutePath); + } + else + { + // Here we must explicitly sanitize, as the path is taken as is + std::string sanitizedPath = normalizePath(pathToResolve); + resolvedRequire.identifier = sanitizedPath; + resolvedRequire.absolutePath = std::move(sanitizedPath); + } + return true; +} + +std::optional RequireResolver::getRequiringContextAbsolute() +{ + std::string requiringFile; + if (isAbsolutePath(requireContext.getPath())) + { + // We already have an absolute path for the requiring file + requiringFile = requireContext.getPath(); + } + else + { + // Requiring file's stored path is relative to the CWD, must make absolute + std::optional cwd = getCurrentWorkingDirectory(); + if (!cwd) + { + errorHandler.reportError("could not determine current working directory"); + return std::nullopt; + } + + if (requireContext.isStdin()) + { + // Require statement is being executed from REPL input prompt + // The requiring context is the pseudo-file "stdin" in the CWD + requiringFile = joinPaths(*cwd, "stdin"); + } + else + { + // Require statement is being executed in a file, must resolve relative to CWD + requiringFile = normalizePath(joinPaths(*cwd, requireContext.getPath())); + } + } + std::replace(requiringFile.begin(), requiringFile.end(), '\\', '/'); + return requiringFile; +} + +std::string RequireResolver::getRequiringContextRelative() +{ + return requireContext.isStdin() ? "./" : requireContext.getPath(); +} + +bool RequireResolver::substituteAliasIfPresent(std::string& path) +{ + if (path.size() < 1 || path[0] != '@') + return true; + + // To ignore the '@' alias prefix when processing the alias + const size_t aliasStartPos = 1; + + // If a directory separator was found, the length of the alias is the + // distance between the start of the alias and the separator. Otherwise, + // the whole string after the alias symbol is the alias. + size_t aliasLen = path.find_first_of("\\/"); + if (aliasLen != std::string::npos) + aliasLen -= aliasStartPos; + + const std::string potentialAlias = path.substr(aliasStartPos, aliasLen); + + // Not worth searching when potentialAlias cannot be an alias + if (!Luau::isValidAlias(potentialAlias)) + { + errorHandler.reportError("@" + potentialAlias + " is not a valid alias"); + return false; + } + + if (std::optional alias = getAlias(potentialAlias)) + { + path = *alias + path.substr(potentialAlias.size() + 1); + return true; + } + + errorHandler.reportError("@" + potentialAlias + " is not a valid alias"); + return false; +} + +std::optional RequireResolver::getAlias(std::string alias) +{ + std::transform( + alias.begin(), + alias.end(), + alias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + while (!config.aliases.contains(alias) && !isConfigFullyResolved) + { + if (!parseNextConfig()) + return std::nullopt; // error parsing config + } + if (!config.aliases.contains(alias) && isConfigFullyResolved) + return std::nullopt; // could not find alias + + const Luau::Config::AliasInfo& aliasInfo = config.aliases[alias]; + return resolvePath(aliasInfo.value, aliasInfo.configLocation); +} + +bool RequireResolver::parseNextConfig() +{ + if (isConfigFullyResolved) + return true; // no config files left to parse + + std::optional directory; + if (lastSearchedDir.empty()) + { + std::optional requiringFile = getRequiringContextAbsolute(); + if (!requiringFile) + return false; + + directory = getParentPath(*requiringFile); + } + else + directory = getParentPath(lastSearchedDir); + + if (directory) + { + lastSearchedDir = *directory; + if (!parseConfigInDirectory(*directory)) + return false; + } + else + isConfigFullyResolved = true; + + return true; +} + +bool RequireResolver::parseConfigInDirectory(const std::string& directory) +{ + std::string configPath = joinPaths(directory, Luau::kConfigName); + + Luau::ConfigOptions::AliasOptions aliasOpts; + aliasOpts.configLocation = configPath; + aliasOpts.overwriteAliases = false; + + Luau::ConfigOptions opts; + opts.aliasOptions = std::move(aliasOpts); + + if (std::optional contents = readFile(configPath)) + { + std::optional error = Luau::parseConfig(*contents, config, opts); + if (error) + { + errorHandler.reportError("error parsing " + configPath + "(" + *error + ")"); + return false; + } + } + + return true; +} diff --git a/CLI/Web.cpp b/CLI/src/Web.cpp similarity index 100% rename from CLI/Web.cpp rename to CLI/src/Web.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b18cd5c9..5286fd9f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,11 +68,12 @@ include(Sources.cmake) target_include_directories(Luau.Common INTERFACE Common/include) target_compile_features(Luau.CLI.lib PUBLIC cxx_std_17) -target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common) +target_include_directories(Luau.CLI.lib PUBLIC CLI/include) +target_link_libraries(Luau.CLI.lib PRIVATE Luau.Common Luau.Config) target_compile_features(Luau.Ast PUBLIC cxx_std_17) target_include_directories(Luau.Ast PUBLIC Ast/include) -target_link_libraries(Luau.Ast PUBLIC Luau.Common Luau.CLI.lib) +target_link_libraries(Luau.Ast PUBLIC Luau.Common) target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) @@ -85,6 +86,7 @@ target_link_libraries(Luau.Config PUBLIC Luau.Ast) target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.EqSat Luau.Config) +target_link_libraries(Luau.Analysis PRIVATE Luau.Compiler Luau.VM) target_compile_features(Luau.EqSat PUBLIC cxx_std_17) target_include_directories(Luau.EqSat PUBLIC EqSat/include) @@ -123,6 +125,8 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") # Some gcc versions treat var in `if (type var = val)` as unused # Some gcc versions treat variables used in constexpr if blocks as unused list(APPEND LUAU_OPTIONS -Wno-unused) + # some gcc versions warn maybe uninitialized on optional members on structs + list(APPEND LUAU_OPTIONS -Wno-maybe-uninitialized) endif() # Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere @@ -276,7 +280,7 @@ foreach(LIB Luau.Ast Luau.Compiler Luau.Config Luau.Analysis Luau.EqSat Luau.Cod if(LIB MATCHES "CodeGen|VM" AND DEPENDS MATCHES "Ast|Analysis|Config|Compiler") message(FATAL_ERROR ${LIB} " is a runtime component but it depends on one of the offline components") endif() - if(LIB MATCHES "Ast|Analysis|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") + if(LIB MATCHES "Ast|EqSat|Compiler" AND DEPENDS MATCHES "CodeGen|VM") message(FATAL_ERROR ${LIB} " is an offline component but it depends on one of the runtime components") endif() if(LIB MATCHES "Ast|Compiler" AND DEPENDS MATCHES "Analysis|Config") diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 26579740..d9ba7b0a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,7 +8,7 @@ Some questions help improve the language, implementation or documentation by ins ## Documentation -A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau-lang.org. +A [separate site repository](https://github.com/luau-lang/site) hosts the language documentation, which is accessible on https://luau.org. Changes to this documentation that improve clarity, fix grammatical issues, explain aspects that haven't been explained before and the like are warmly welcomed. Please feel free to [create a pull request](https://help.github.com/articles/about-pull-requests/) to improve our documentation. Note that at this point the documentation is English-only. @@ -23,12 +23,13 @@ Of course, feel free to [create a pull request](https://help.github.com/articles ## Features If you're thinking of adding a new feature to the language, library, analysis tools, etc., please *don't* start by submitting a pull request. -Luau team has internal priorities and a roadmap that may or may not align with specific features, so before starting to work on a feature please submit an issue describing the missing feature that you'd like to add. +The Luau team has internal priorities and a roadmap that may or may not align with specific features, so before starting to work on a feature, please submit an issue describing the missing feature that you'd like to add. -For features that result in observable change of language syntax or semantics, you'd need to [create an RFC](https://github.com/luau-lang/rfcs/blob/master/README.md) to make sure that the feature is needed and well designed. +For features that result in an observable change to the language's syntax or semantics, you'll need to [create an RFC](https://github.com/luau-lang/rfcs/blob/master/README.md) to make sure that the feature is needed and well-designed. -Finally, please note that Luau tries to carry a minimal feature set. All features must be evaluated not just for the benefits that they provide, but also for the downsides/costs in terms of language simplicity, maintainability, cross-feature interaction etc. +Finally, please note that Luau tries to carry a minimal feature set. All features must be evaluated not just for the benefits that they provide, but also for the downsides/costs in terms of language simplicity, maintainability, cross-feature interaction, etc. As such, feature requests may not be accepted even if a comprehensive RFC is written - don't expect Luau to gain a feature just because another programming language has it. +We generally apply a standard similar to the C\# team's famous [Minus 100 Points](https://learn.microsoft.com/en-us/archive/blogs/ericgu/minus-100-points). ## Code style @@ -48,9 +49,9 @@ When making code changes please try to make sure they are covered by an existing ## Performance -One of the central feature of Luau is performance; our runtime in particular is heavily optimized for high performance and low memory consumption, and code is generally carefully tuned to result in close to optimal assembly for x64 and AArch64 architectures. The analysis code is not optimized to the same level of detail, but performance is still very important to make sure that we can support interactive IDE features. +One of the central features of Luau is performance; our runtime in particular is heavily optimized for high performance and low memory consumption, and code is generally carefully tuned to result in close-to-optimal assembly for x64 and AArch64 architectures. The analysis code is not optimized to the same level of detail, but performance is still very important to make sure that we can support interactive IDE features. -As such, it's important to make sure that the changes, including bug fixes, improve or at least do not regress performance. For VM this can be validated by running `bench.py` script from `bench` folder on two binaries built in Release mode, before and after the changes, although note that our benchmark coverage is not complete and in some cases additional performance testing will be necessary to determine if the change can be merged. +As such, it's important to make sure that the changes, including bug fixes, improve (or at least do not regress) performance. For the VM, this can be validated by running `bench/bench.py` on two binaries built in Release mode, before and after the changes. Note that our benchmark coverage is not complete, and in some cases, additional performance testing will be necessary to determine if the change can be merged. ## Feature flags diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index a4d857a4..9d337942 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -138,6 +138,7 @@ public: void fneg(RegisterA64 dst, RegisterA64 src); void fsqrt(RegisterA64 dst, RegisterA64 src); void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2); + void faddp(RegisterA64 dst, RegisterA64 src); // Vector component manipulation void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index); diff --git a/CodeGen/include/Luau/AssemblyBuilderX64.h b/CodeGen/include/Luau/AssemblyBuilderX64.h index c52d95c5..ca5fa7a9 100644 --- a/CodeGen/include/Luau/AssemblyBuilderX64.h +++ b/CodeGen/include/Luau/AssemblyBuilderX64.h @@ -160,6 +160,7 @@ public: void vmaxsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); + void vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2); void vblendvpd(RegisterX64 dst, RegisterX64 src1, OperandX64 mask, RegisterX64 src3); @@ -167,6 +168,8 @@ public: void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle); void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset); + void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask); + // Run final checks bool finalize(); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index dcc1de85..db1774d8 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 0cf9d9a5..2e689fe2 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -1,7 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include +#include "Luau/CodeGenCommon.h" +#include "Luau/CodeGenOptions.h" +#include "Luau/LoweringStats.h" + #include #include #include @@ -12,25 +15,11 @@ struct lua_State; -#if defined(__x86_64__) || defined(_M_X64) -#define CODEGEN_TARGET_X64 -#elif defined(__aarch64__) || defined(_M_ARM64) -#define CODEGEN_TARGET_A64 -#endif - namespace Luau { namespace CodeGen { -enum CodeGenFlags -{ - // Only run native codegen for modules that have been marked with --!native - CodeGen_OnlyNativeModules = 1 << 0, - // Run native codegen for functions that the compiler considers not profitable - CodeGen_ColdFunctions = 1 << 1, -}; - // These enum values can be reported through telemetry. // To ensure consistency, changes should be additive. enum class CodeGenCompilationResult @@ -72,106 +61,6 @@ struct CompilationResult } }; -struct IrBuilder; -struct IrOp; - -using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); -using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); -using HostVectorNamecallHandler = - bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); - -enum class HostMetamethod -{ - Add, - Sub, - Mul, - Div, - Idiv, - Mod, - Pow, - Minus, - Equal, - LessThan, - LessEqual, - Length, - Concat, -}; - -using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); -using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); -using HostUserdataAccessHandler = - bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); -using HostUserdataMetamethodHandler = - bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); -using HostUserdataNamecallHandler = bool (*)( - IrBuilder& builder, - uint8_t type, - const char* member, - size_t memberLength, - int argResReg, - int sourceReg, - int params, - int results, - int pcpos -); - -struct HostIrHooks -{ - // Suggest result type of a vector field access - HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr; - - // Suggest result type of a vector function namecall - HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr; - - // Handle vector value field access - // 'sourceReg' is guaranteed to be a vector - // Guards should take a VM exit to 'pcpos' - HostVectorAccessHandler vectorAccess = nullptr; - - // Handle namecall performed on a vector value - // 'sourceReg' (self argument) is guaranteed to be a vector - // All other arguments can be of any type - // Guards should take a VM exit to 'pcpos' - HostVectorNamecallHandler vectorNamecall = nullptr; - - // Suggest result type of a userdata field access - HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; - - // Suggest result type of a metamethod call - HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; - - // Suggest result type of a userdata namecall - HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; - - // Handle userdata value field access - // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked - // Write to 'resultReg' might invalidate 'sourceReg' - // Guards should take a VM exit to 'pcpos' - HostUserdataAccessHandler userdataAccess = nullptr; - - // Handle metamethod operation on a userdata value - // 'lhs' and 'rhs' operands can be VM registers of constants - // Operand types have to be checked and userdata operand tags have to be checked - // Write to 'resultReg' might invalidate source operands - // Guards should take a VM exit to 'pcpos' - HostUserdataMetamethodHandler userdataMetamethod = nullptr; - - // Handle namecall performed on a userdata value - // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked - // All other arguments can be of any type - // Guards should take a VM exit to 'pcpos' - HostUserdataNamecallHandler userdataNamecall = nullptr; -}; - -struct CompilationOptions -{ - unsigned int flags = 0; - HostIrHooks hooks; - - // null-terminated array of userdata types names that might have custom lowering - const char* const* userdataTypes = nullptr; -}; - struct CompilationStats { size_t bytecodeSizeBytes = 0; @@ -184,8 +73,6 @@ struct CompilationStats uint32_t functionsBound = 0; }; -using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize); - bool isSupported(); class SharedCodeGenContext; @@ -249,153 +136,6 @@ CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, unsig CompilationResult compile(lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); CompilationResult compile(const ModuleId& moduleId, lua_State* L, int idx, const CompilationOptions& options, CompilationStats* stats = nullptr); -using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); - -// Output "#" before IR blocks and instructions -enum class IncludeIrPrefix -{ - No, - Yes -}; - -// Output user count and last use information of blocks and instructions -enum class IncludeUseInfo -{ - No, - Yes -}; - -// Output CFG informations like block predecessors, successors and etc -enum class IncludeCfgInfo -{ - No, - Yes -}; - -// Output VM register live in/out information for blocks -enum class IncludeRegFlowInfo -{ - No, - Yes -}; - -struct AssemblyOptions -{ - enum Target - { - Host, - A64, - A64_NoFeatures, - X64_Windows, - X64_SystemV, - }; - - Target target = Host; - - CompilationOptions compilationOptions; - - bool outputBinary = false; - - bool includeAssembly = false; - bool includeIr = false; - bool includeOutlinedCode = false; - bool includeIrTypes = false; - - IncludeIrPrefix includeIrPrefix = IncludeIrPrefix::Yes; - IncludeUseInfo includeUseInfo = IncludeUseInfo::Yes; - IncludeCfgInfo includeCfgInfo = IncludeCfgInfo::Yes; - IncludeRegFlowInfo includeRegFlowInfo = IncludeRegFlowInfo::Yes; - - // Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id - AnnotatorFn annotator = nullptr; - void* annotatorContext = nullptr; -}; - -struct BlockLinearizationStats -{ - unsigned int constPropInstructionCount = 0; - double timeSeconds = 0.0; - - BlockLinearizationStats& operator+=(const BlockLinearizationStats& that) - { - this->constPropInstructionCount += that.constPropInstructionCount; - this->timeSeconds += that.timeSeconds; - - return *this; - } - - BlockLinearizationStats operator+(const BlockLinearizationStats& other) const - { - BlockLinearizationStats result(*this); - result += other; - return result; - } -}; - -enum FunctionStatsFlags -{ - // Enable stats collection per function - FunctionStats_Enable = 1 << 0, - // Compute function bytecode summary - FunctionStats_BytecodeSummary = 1 << 1, -}; - -struct FunctionStats -{ - std::string name; - int line = -1; - unsigned bcodeCount = 0; - unsigned irCount = 0; - unsigned asmCount = 0; - unsigned asmSize = 0; - std::vector> bytecodeSummary; -}; - -struct LoweringStats -{ - unsigned totalFunctions = 0; - unsigned skippedFunctions = 0; - int spillsToSlot = 0; - int spillsToRestore = 0; - unsigned maxSpillSlotsUsed = 0; - unsigned blocksPreOpt = 0; - unsigned blocksPostOpt = 0; - unsigned maxBlockInstructions = 0; - - int regAllocErrors = 0; - int loweringErrors = 0; - - BlockLinearizationStats blockLinearizationStats; - - unsigned functionStatsFlags = 0; - std::vector functions; - - LoweringStats operator+(const LoweringStats& other) const - { - LoweringStats result(*this); - result += other; - return result; - } - - LoweringStats& operator+=(const LoweringStats& that) - { - this->totalFunctions += that.totalFunctions; - this->skippedFunctions += that.skippedFunctions; - this->spillsToSlot += that.spillsToSlot; - this->spillsToRestore += that.spillsToRestore; - this->maxSpillSlotsUsed = std::max(this->maxSpillSlotsUsed, that.maxSpillSlotsUsed); - this->blocksPreOpt += that.blocksPreOpt; - this->blocksPostOpt += that.blocksPostOpt; - this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions); - this->regAllocErrors += that.regAllocErrors; - this->loweringErrors += that.loweringErrors; - this->blockLinearizationStats += that.blockLinearizationStats; - if (this->functionStatsFlags & FunctionStats_Enable) - this->functions.insert(this->functions.end(), that.functions.begin(), that.functions.end()); - return *this; - } -}; - // Generates assembly for target function and all inner functions std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}, LoweringStats* stats = nullptr); diff --git a/CodeGen/include/Luau/CodeGenCommon.h b/CodeGen/include/Luau/CodeGenCommon.h index 84090423..a9d1761c 100644 --- a/CodeGen/include/Luau/CodeGenCommon.h +++ b/CodeGen/include/Luau/CodeGenCommon.h @@ -10,3 +10,9 @@ #else #define CODEGEN_ASSERT(expr) (void)sizeof(!!(expr)) #endif + +#if defined(__x86_64__) || defined(_M_X64) +#define CODEGEN_TARGET_X64 +#elif defined(__aarch64__) || defined(_M_ARM64) +#define CODEGEN_TARGET_A64 +#endif diff --git a/CodeGen/include/Luau/CodeGenOptions.h b/CodeGen/include/Luau/CodeGenOptions.h new file mode 100644 index 00000000..de95efa6 --- /dev/null +++ b/CodeGen/include/Luau/CodeGenOptions.h @@ -0,0 +1,188 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +enum CodeGenFlags +{ + // Only run native codegen for modules that have been marked with --!native + CodeGen_OnlyNativeModules = 1 << 0, + // Run native codegen for functions that the compiler considers not profitable + CodeGen_ColdFunctions = 1 << 1, +}; + +using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize); + +struct IrBuilder; +struct IrOp; + +using HostVectorOperationBytecodeType = uint8_t (*)(const char* member, size_t memberLength); +using HostVectorAccessHandler = bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostVectorNamecallHandler = + bool (*)(IrBuilder& builder, const char* member, size_t memberLength, int argResReg, int sourceReg, int params, int results, int pcpos); + +enum class HostMetamethod +{ + Add, + Sub, + Mul, + Div, + Idiv, + Mod, + Pow, + Minus, + Equal, + LessThan, + LessEqual, + Length, + Concat, +}; + +using HostUserdataOperationBytecodeType = uint8_t (*)(uint8_t type, const char* member, size_t memberLength); +using HostUserdataMetamethodBytecodeType = uint8_t (*)(uint8_t lhsTy, uint8_t rhsTy, HostMetamethod method); +using HostUserdataAccessHandler = + bool (*)(IrBuilder& builder, uint8_t type, const char* member, size_t memberLength, int resultReg, int sourceReg, int pcpos); +using HostUserdataMetamethodHandler = + bool (*)(IrBuilder& builder, uint8_t lhsTy, uint8_t rhsTy, int resultReg, IrOp lhs, IrOp rhs, HostMetamethod method, int pcpos); +using HostUserdataNamecallHandler = bool (*)( + IrBuilder& builder, + uint8_t type, + const char* member, + size_t memberLength, + int argResReg, + int sourceReg, + int params, + int results, + int pcpos +); + +struct HostIrHooks +{ + // Suggest result type of a vector field access + HostVectorOperationBytecodeType vectorAccessBytecodeType = nullptr; + + // Suggest result type of a vector function namecall + HostVectorOperationBytecodeType vectorNamecallBytecodeType = nullptr; + + // Handle vector value field access + // 'sourceReg' is guaranteed to be a vector + // Guards should take a VM exit to 'pcpos' + HostVectorAccessHandler vectorAccess = nullptr; + + // Handle namecall performed on a vector value + // 'sourceReg' (self argument) is guaranteed to be a vector + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostVectorNamecallHandler vectorNamecall = nullptr; + + // Suggest result type of a userdata field access + HostUserdataOperationBytecodeType userdataAccessBytecodeType = nullptr; + + // Suggest result type of a metamethod call + HostUserdataMetamethodBytecodeType userdataMetamethodBytecodeType = nullptr; + + // Suggest result type of a userdata namecall + HostUserdataOperationBytecodeType userdataNamecallBytecodeType = nullptr; + + // Handle userdata value field access + // 'sourceReg' is guaranteed to be a userdata, but tag has to be checked + // Write to 'resultReg' might invalidate 'sourceReg' + // Guards should take a VM exit to 'pcpos' + HostUserdataAccessHandler userdataAccess = nullptr; + + // Handle metamethod operation on a userdata value + // 'lhs' and 'rhs' operands can be VM registers of constants + // Operand types have to be checked and userdata operand tags have to be checked + // Write to 'resultReg' might invalidate source operands + // Guards should take a VM exit to 'pcpos' + HostUserdataMetamethodHandler userdataMetamethod = nullptr; + + // Handle namecall performed on a userdata value + // 'sourceReg' (self argument) is guaranteed to be a userdata, but tag has to be checked + // All other arguments can be of any type + // Guards should take a VM exit to 'pcpos' + HostUserdataNamecallHandler userdataNamecall = nullptr; +}; + +struct CompilationOptions +{ + unsigned int flags = 0; + HostIrHooks hooks; + + // null-terminated array of userdata types names that might have custom lowering + const char* const* userdataTypes = nullptr; +}; + + +using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); + +// Output "#" before IR blocks and instructions +enum class IncludeIrPrefix +{ + No, + Yes +}; + +// Output user count and last use information of blocks and instructions +enum class IncludeUseInfo +{ + No, + Yes +}; + +// Output CFG informations like block predecessors, successors and etc +enum class IncludeCfgInfo +{ + No, + Yes +}; + +// Output VM register live in/out information for blocks +enum class IncludeRegFlowInfo +{ + No, + Yes +}; + +struct AssemblyOptions +{ + enum Target + { + Host, + A64, + A64_NoFeatures, + X64_Windows, + X64_SystemV, + }; + + Target target = Host; + + CompilationOptions compilationOptions; + + bool outputBinary = false; + + bool includeAssembly = false; + bool includeIr = false; + bool includeOutlinedCode = false; + bool includeIrTypes = false; + + IncludeIrPrefix includeIrPrefix = IncludeIrPrefix::Yes; + IncludeUseInfo includeUseInfo = IncludeUseInfo::Yes; + IncludeCfgInfo includeCfgInfo = IncludeCfgInfo::Yes; + IncludeRegFlowInfo includeRegFlowInfo = IncludeRegFlowInfo::Yes; + + // Optional annotator function can be provided to describe each instruction, it takes function id and sequential instruction id + AnnotatorFn annotator = nullptr; + void* annotatorContext = nullptr; +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index ae406bbc..38519f95 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -20,6 +20,8 @@ namespace Luau namespace CodeGen { +struct LoweringStats; + // IR extensions to LuauBuiltinFunction enum (these only exist inside IR, and start from 256 to avoid collisions) enum { @@ -67,18 +69,18 @@ enum class IrCmd : uint8_t LOAD_ENV, // Get pointer (TValue) to table array at index - // A: pointer (Table) + // A: pointer (LuaTable) // B: int GET_ARR_ADDR, // Get pointer (LuaNode) to table node element at the active cached slot index - // A: pointer (Table) + // A: pointer (LuaTable) // B: unsigned int (pcpos) // C: Kn GET_SLOT_NODE_ADDR, // Get pointer (LuaNode) to table node element at the main position of the specified key hash - // A: pointer (Table) + // A: pointer (LuaTable) // B: unsigned int (hash) GET_HASH_NODE_ADDR, @@ -114,10 +116,12 @@ enum class IrCmd : uint8_t STORE_INT, // Store a vector into TValue + // When optional 'E' tag is present, it is written out to the TValue as well // A: Rn // B: double (x) // C: double (y) // D: double (z) + // E: tag (optional) STORE_VECTOR, // Store a TValue into memory @@ -183,6 +187,11 @@ enum class IrCmd : uint8_t // A: double SIGN_NUM, + // Select B if C == D, otherwise select A + // A, B: double (endpoints) + // C, D: double (condition arguments) + SELECT_NUM, + // Add/Sub/Mul/Div/Idiv two vectors // A, B: TValue ADD_VEC, @@ -194,6 +203,10 @@ enum class IrCmd : uint8_t // A: TValue UNM_VEC, + // Compute dot product between two vectors + // A, B: TValue + DOT_VEC, + // Compute Luau 'not' operation on destructured TValue // A: tag // B: int (value) @@ -262,7 +275,7 @@ enum class IrCmd : uint8_t JUMP_SLOT_MATCH, // Get table length - // A: pointer (Table) + // A: pointer (LuaTable) TABLE_LEN, // Get string length @@ -275,11 +288,11 @@ enum class IrCmd : uint8_t NEW_TABLE, // Duplicate a table - // A: pointer (Table) + // A: pointer (LuaTable) DUP_TABLE, // Insert an integer key into a table and return the pointer to inserted value (TValue) - // A: pointer (Table) + // A: pointer (LuaTable) // B: int (key) TABLE_SETNUM, @@ -419,13 +432,13 @@ enum class IrCmd : uint8_t CHECK_TRUTHY, // Guard against readonly table - // A: pointer (Table) + // A: pointer (LuaTable) // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_READONLY, // Guard against table having a metatable - // A: pointer (Table) + // A: pointer (LuaTable) // B: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure CHECK_NO_METATABLE, @@ -436,7 +449,7 @@ enum class IrCmd : uint8_t CHECK_SAFE_ENV, // Guard against index overflowing the table array size - // A: pointer (Table) + // A: pointer (LuaTable) // B: int (index) // C: block/vmexit/undef // When undef is specified instead of a block, execution is aborted on check failure @@ -492,11 +505,11 @@ enum class IrCmd : uint8_t BARRIER_OBJ, // Handle GC write barrier (backwards) for a write into a table - // A: pointer (Table) + // A: pointer (LuaTable) BARRIER_TABLE_BACK, // Handle GC write barrier (forward) for a write into a table - // A: pointer (Table) + // A: pointer (LuaTable) // B: Rn (TValue that was written to the object) // C: tag/undef (tag of the value that was written) BARRIER_TABLE_FORWARD, @@ -1038,6 +1051,8 @@ struct IrFunction CfgInfo cfg; + LoweringStats* stats = nullptr; + IrBlock& blockOp(IrOp op) { CODEGEN_ASSERT(op.kind == IrOpKind::Block); diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 9364f461..27a9feb4 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -2,11 +2,13 @@ #pragma once #include "Luau/IrData.h" -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include #include +struct Proto; + namespace Luau { namespace CodeGen @@ -23,6 +25,7 @@ struct IrToStringContext const std::vector& blocks; const std::vector& constants; const CfgInfo& cfg; + Proto* proto = nullptr; }; void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 8d48780f..1afa1a34 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -5,6 +5,8 @@ #include "Luau/Common.h" #include "Luau/IrData.h" +LUAU_FASTFLAG(LuauVectorLibNativeDot); + namespace Luau { namespace CodeGen @@ -172,10 +174,15 @@ inline bool hasResult(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: case IrCmd::DIV_VEC: + case IrCmd::DOT_VEC: + if (cmd == IrCmd::DOT_VEC) + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + LUAU_FALLTHROUGH; case IrCmd::UNM_VEC: case IrCmd::NOT_ANY: case IrCmd::CMP_ANY: diff --git a/CodeGen/include/Luau/LoweringStats.h b/CodeGen/include/Luau/LoweringStats.h new file mode 100644 index 00000000..532a5270 --- /dev/null +++ b/CodeGen/include/Luau/LoweringStats.h @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include + +namespace Luau +{ +namespace CodeGen +{ + +struct BlockLinearizationStats +{ + unsigned int constPropInstructionCount = 0; + double timeSeconds = 0.0; + + BlockLinearizationStats& operator+=(const BlockLinearizationStats& that) + { + this->constPropInstructionCount += that.constPropInstructionCount; + this->timeSeconds += that.timeSeconds; + + return *this; + } + + BlockLinearizationStats operator+(const BlockLinearizationStats& other) const + { + BlockLinearizationStats result(*this); + result += other; + return result; + } +}; + +enum FunctionStatsFlags +{ + // Enable stats collection per function + FunctionStats_Enable = 1 << 0, + // Compute function bytecode summary + FunctionStats_BytecodeSummary = 1 << 1, +}; + +struct FunctionStats +{ + std::string name; + int line = -1; + unsigned bcodeCount = 0; + unsigned irCount = 0; + unsigned asmCount = 0; + unsigned asmSize = 0; + std::vector> bytecodeSummary; +}; + +struct LoweringStats +{ + unsigned totalFunctions = 0; + unsigned skippedFunctions = 0; + int spillsToSlot = 0; + int spillsToRestore = 0; + unsigned maxSpillSlotsUsed = 0; + unsigned blocksPreOpt = 0; + unsigned blocksPostOpt = 0; + unsigned maxBlockInstructions = 0; + + int regAllocErrors = 0; + int loweringErrors = 0; + + BlockLinearizationStats blockLinearizationStats; + + unsigned functionStatsFlags = 0; + std::vector functions; + + LoweringStats operator+(const LoweringStats& other) const + { + LoweringStats result(*this); + result += other; + return result; + } + + LoweringStats& operator+=(const LoweringStats& that) + { + this->totalFunctions += that.totalFunctions; + this->skippedFunctions += that.skippedFunctions; + this->spillsToSlot += that.spillsToSlot; + this->spillsToRestore += that.spillsToRestore; + this->maxSpillSlotsUsed = std::max(this->maxSpillSlotsUsed, that.maxSpillSlotsUsed); + this->blocksPreOpt += that.blocksPreOpt; + this->blocksPostOpt += that.blocksPostOpt; + this->maxBlockInstructions = std::max(this->maxBlockInstructions, that.maxBlockInstructions); + + this->regAllocErrors += that.regAllocErrors; + this->loweringErrors += that.loweringErrors; + + this->blockLinearizationStats += that.blockLinearizationStats; + + if (this->functionStatsFlags & FunctionStats_Enable) + this->functions.insert(this->functions.end(), that.functions.begin(), that.functions.end()); + + return *this; + } +}; + +} // namespace CodeGen +} // namespace Luau diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index b98a21f2..9e17d3fd 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -7,6 +7,8 @@ #include #include +LUAU_FASTFLAG(LuauVectorLibNativeDot); + namespace Luau { namespace CodeGen @@ -586,6 +588,15 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src) placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000); } +void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src) +{ + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s); + CODEGEN_ASSERT(dst.kind == src.kind); + + placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12)); +} + void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2) { if (dst.kind == KindA64::d) diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 73c40679..1fb1b671 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAG(LuauVectorLibNativeDot); + namespace Luau { namespace CodeGen @@ -925,6 +927,11 @@ void AssemblyBuilderX64::vminsd(OperandX64 dst, OperandX64 src1, OperandX64 src2 placeAvx("vminsd", dst, src1, src2, 0x5d, false, AVX_0F, AVX_F2); } +void AssemblyBuilderX64::vcmpeqsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) +{ + placeAvx("vcmpeqsd", dst, src1, src2, 0x00, 0xc2, false, AVX_0F, AVX_F2); +} + void AssemblyBuilderX64::vcmpltsd(OperandX64 dst, OperandX64 src1, OperandX64 src2) { placeAvx("vcmpltsd", dst, src1, src2, 0x01, 0xc2, false, AVX_0F, AVX_F2); @@ -946,6 +953,12 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66); } +void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask) +{ + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66); +} + bool AssemblyBuilderX64::finalize() { code.resize(codePos - code.data()); diff --git a/CodeGen/src/BytecodeAnalysis.cpp b/CodeGen/src/BytecodeAnalysis.cpp index 8d2efebe..b859b111 100644 --- a/CodeGen/src/BytecodeAnalysis.cpp +++ b/CodeGen/src/BytecodeAnalysis.cpp @@ -2,7 +2,7 @@ #include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeUtils.h" -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" @@ -235,7 +235,7 @@ static uint8_t getBytecodeConstantTag(Proto* proto, unsigned ki) return LBC_TYPE_ANY; } -static void applyBuiltinCall(int bfid, BytecodeTypes& types) +static void applyBuiltinCall(LuauBuiltinFunction bfid, BytecodeTypes& types) { switch (bfid) { @@ -515,6 +515,46 @@ static void applyBuiltinCall(int bfid, BytecodeTypes& types) types.a = LBC_TYPE_TABLE; types.b = LBC_TYPE_TABLE; break; + case LBF_VECTOR_MAGNITUDE: + types.result = LBC_TYPE_NUMBER; + types.a = LBC_TYPE_VECTOR; + break; + case LBF_VECTOR_NORMALIZE: + types.result = LBC_TYPE_VECTOR; + types.a = LBC_TYPE_VECTOR; + break; + case LBF_VECTOR_CROSS: + types.result = LBC_TYPE_VECTOR; + types.a = LBC_TYPE_VECTOR; + types.b = LBC_TYPE_VECTOR; + break; + case LBF_VECTOR_DOT: + types.result = LBC_TYPE_NUMBER; + types.a = LBC_TYPE_VECTOR; + types.b = LBC_TYPE_VECTOR; + break; + case LBF_VECTOR_FLOOR: + case LBF_VECTOR_CEIL: + case LBF_VECTOR_ABS: + case LBF_VECTOR_SIGN: + case LBF_VECTOR_CLAMP: + types.result = LBC_TYPE_VECTOR; + types.a = LBC_TYPE_VECTOR; + types.b = LBC_TYPE_VECTOR; + break; + case LBF_VECTOR_MIN: + case LBF_VECTOR_MAX: + types.result = LBC_TYPE_VECTOR; + types.a = LBC_TYPE_VECTOR; + types.b = LBC_TYPE_VECTOR; + types.c = LBC_TYPE_VECTOR; // We can mark optional arguments + break; + case LBF_MATH_LERP: + types.result = LBC_TYPE_NUMBER; + types.a = LBC_TYPE_NUMBER; + types.b = LBC_TYPE_NUMBER; + types.c = LBC_TYPE_NUMBER; + break; } } @@ -808,7 +848,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -839,7 +880,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) { regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); } @@ -861,7 +903,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -883,7 +926,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -914,7 +958,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) { regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); } @@ -936,7 +981,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.a == LBC_TYPE_NUMBER && bcType.b == LBC_TYPE_NUMBER) regTags[ra] = LBC_TYPE_NUMBER; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -957,7 +1003,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) regTags[ra] = LBC_TYPE_NUMBER; else if (bcType.a == LBC_TYPE_VECTOR && bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); bcType.result = regTags[ra]; @@ -986,7 +1033,8 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) if (bcType.b == LBC_TYPE_NUMBER || bcType.b == LBC_TYPE_VECTOR) regTags[ra] = LBC_TYPE_VECTOR; } - else if (hostHooks.userdataMetamethodBytecodeType && (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) + else if (hostHooks.userdataMetamethodBytecodeType && + (isCustomUserdataBytecodeType(bcType.a) || isCustomUserdataBytecodeType(bcType.b))) { regTags[ra] = hostHooks.userdataMetamethodBytecodeType(bcType.a, bcType.b, opcodeToHostMetamethod(op)); } @@ -1052,7 +1100,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[ra + 1] = bcType.a; regTags[ra + 2] = bcType.b; regTags[ra + 3] = bcType.c; @@ -1071,7 +1119,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[ra] = bcType.result; @@ -1088,7 +1136,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[int(pc[1])] = bcType.b; @@ -1107,7 +1155,7 @@ void analyzeBytecodeTypes(IrFunction& function, const HostIrHooks& hostHooks) CODEGEN_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); int ra = LUAU_INSN_A(call); - applyBuiltinCall(bfid, bcType); + applyBuiltinCall(LuauBuiltinFunction(bfid), bcType); regTags[LUAU_INSN_B(*pc)] = bcType.a; regTags[aux & 0xff] = bcType.b; diff --git a/CodeGen/src/BytecodeSummary.cpp b/CodeGen/src/BytecodeSummary.cpp index d0d71504..d179dcc5 100644 --- a/CodeGen/src/BytecodeSummary.cpp +++ b/CodeGen/src/BytecodeSummary.cpp @@ -8,8 +8,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauNativeAttribute) - namespace Luau { namespace CodeGen @@ -58,10 +56,7 @@ std::vector summarizeBytecode(lua_State* L, int idx, un Proto* root = clvalue(func)->l.p; std::vector protos; - if (FFlag::LuauNativeAttribute) - gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION); - else - gatherFunctions_DEPRECATED(protos, root, CodeGen_ColdFunctions); + gatherFunctions(protos, root, CodeGen_ColdFunctions, root->flags & LPF_NATIVE_FUNCTION); std::vector summaries; summaries.reserve(protos.size()); diff --git a/CodeGen/src/CodeBlockUnwind.cpp b/CodeGen/src/CodeBlockUnwind.cpp index cb2d693a..d21dd14b 100644 --- a/CodeGen/src/CodeBlockUnwind.cpp +++ b/CodeGen/src/CodeBlockUnwind.cpp @@ -2,6 +2,7 @@ #include "Luau/CodeBlockUnwind.h" #include "Luau/CodeAllocator.h" +#include "Luau/CodeGenCommon.h" #include "Luau/UnwindBuilder.h" #include @@ -19,9 +20,21 @@ #elif defined(__linux__) || defined(__APPLE__) -// Defined in unwind.h which may not be easily discoverable on various platforms -extern "C" void __register_frame(const void*) __attribute__((weak)); -extern "C" void __deregister_frame(const void*) __attribute__((weak)); +// __register_frame and __deregister_frame are defined in libgcc or libc++ +// (depending on how it's built). We want to declare them as weak symbols +// so that if they're provided by a shared library, we'll use them, and if +// not, we'll disable some c++ exception handling support. However, if they're +// declared as weak and the definitions are linked in a static library +// that's not linked with whole-archive, then the symbols will technically be defined here, +// and the linker won't look for the strong ones in the library. +#ifndef LUAU_ENABLE_REGISTER_FRAME +#define REGISTER_FRAME_WEAK __attribute__((weak)) +#else +#define REGISTER_FRAME_WEAK +#endif + +extern "C" void __register_frame(const void*) REGISTER_FRAME_WEAK; +extern "C" void __deregister_frame(const void*) REGISTER_FRAME_WEAK; extern "C" void __unw_add_dynamic_fde() __attribute__((weak)); #endif @@ -120,7 +133,7 @@ void* createBlockUnwindInfo(void* context, uint8_t* block, size_t blockSize, siz #endif #elif defined(__linux__) || defined(__APPLE__) - if (!__register_frame) + if (!&__register_frame) return nullptr; visitFdeEntries(unwindData, __register_frame); @@ -149,7 +162,7 @@ void destroyBlockUnwindInfo(void* context, void* unwindData) #endif #elif defined(__linux__) || defined(__APPLE__) - if (!__deregister_frame) + if (!&__deregister_frame) { CODEGEN_ASSERT(!"Cannot deregister unwind information"); return; diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 667f5726..01e87d3d 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -3,7 +3,7 @@ #include "CodeGenLower.h" -#include "Luau/Common.h" +#include "Luau/CodeGenCommon.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" #include "Luau/IrBuilder.h" @@ -41,9 +41,9 @@ #endif #endif -LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt, false) -LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize, false) -LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenNoOpt) +LUAU_FASTFLAGVARIABLE(DebugCodegenOptSize) +LUAU_FASTFLAGVARIABLE(DebugCodegenSkipNumbering) // Per-module IR instruction count limit LUAU_FASTINTVARIABLE(CodegenHeuristicsInstructionLimit, 1'048'576) // 1 M @@ -166,7 +166,7 @@ bool isSupported() if (sizeof(LuaNode) != 32) return false; - // Windows CRT uses stack unwinding in longjmp so we have to use unwind data; on other platforms, it's only necessary for C++ EH. + // Windows CRT uses stack unwinding in longjmp so we have to use unwind data; on other platforms, it's only necessary for C++ EH. #if defined(_WIN32) if (!isUnwindSupported()) return false; diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index c423a1ce..6bbdc473 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/CodeGen.h" #include "Luau/BytecodeAnalysis.h" #include "Luau/BytecodeUtils.h" #include "Luau/BytecodeSummary.h" @@ -12,8 +11,6 @@ #include "lapi.h" -LUAU_FASTFLAG(LuauNativeAttribute) - namespace Luau { namespace CodeGen @@ -155,10 +152,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A } std::vector protos; - if (FFlag::LuauNativeAttribute) - gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION); - else - gatherFunctions_DEPRECATED(protos, root, options.compilationOptions.flags); + gatherFunctions(protos, root, options.compilationOptions.flags, root->flags & LPF_NATIVE_FUNCTION); protos.erase( std::remove_if( diff --git a/CodeGen/src/CodeGenContext.cpp b/CodeGen/src/CodeGenContext.cpp index a463091d..82dfa17e 100644 --- a/CodeGen/src/CodeGenContext.cpp +++ b/CodeGen/src/CodeGenContext.cpp @@ -5,6 +5,7 @@ #include "CodeGenLower.h" #include "CodeGenX64.h" +#include "Luau/CodeGenCommon.h" #include "Luau/CodeBlockUnwind.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" @@ -14,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) -LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { @@ -346,13 +346,13 @@ void SharedCodeGenContextDeleter::operator()(const SharedCodeGenContext* codeGen return static_cast(L->global->ecb.context); } -static void onCloseState(lua_State* L) noexcept +static void onCloseState(lua_State* L) { getCodeGenContext(L)->onCloseState(); L->global->ecb = lua_ExecutionCallbacks{}; } -static void onDestroyFunction(lua_State* L, Proto* proto) noexcept +static void onDestroyFunction(lua_State* L, Proto* proto) { getCodeGenContext(L)->onDestroyFunction(proto->execdata); proto->execdata = nullptr; @@ -510,10 +510,7 @@ template return CompilationResult{CodeGenCompilationResult::CodeGenNotInitialized}; std::vector protos; - if (FFlag::LuauNativeAttribute) - gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION); - else - gatherFunctions_DEPRECATED(protos, root, options.flags); + gatherFunctions(protos, root, options.flags, root->flags & LPF_NATIVE_FUNCTION); // Skip protos that have been compiled during previous invocations of CodeGen::compile protos.erase( diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index 9e77844b..c2117a12 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -7,6 +7,7 @@ #include "Luau/IrBuilder.h" #include "Luau/IrDump.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeDeadStore.h" #include "Luau/OptimizeFinalX64.h" @@ -27,31 +28,12 @@ LUAU_FASTFLAG(DebugCodegenSkipNumbering) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) LUAU_FASTINT(CodegenHeuristicsBlockLimit) LUAU_FASTINT(CodegenHeuristicsBlockInstructionLimit) -LUAU_FASTFLAG(LuauNativeAttribute) namespace Luau { namespace CodeGen { -inline void gatherFunctions_DEPRECATED(std::vector& results, Proto* proto, unsigned int flags) -{ - if (results.size() <= size_t(proto->bytecodeid)) - results.resize(proto->bytecodeid + 1); - - // Skip protos that we've already compiled in this run: this happens because at -O2, inlined functions get their protos reused - if (results[proto->bytecodeid]) - return; - - // Only compile cold functions if requested - if ((proto->flags & LPF_NATIVE_COLD) == 0 || (flags & CodeGen_ColdFunctions) != 0) - results[proto->bytecodeid] = proto; - - // Recursively traverse child protos even if we aren't compiling this one - for (int i = 0; i < proto->sizep; i++) - gatherFunctions_DEPRECATED(results, proto->p[i], flags); -} - inline void gatherFunctionsHelper( std::vector& results, Proto* proto, @@ -82,7 +64,6 @@ inline void gatherFunctionsHelper( inline void gatherFunctions(std::vector& results, Proto* root, const unsigned int flags, const bool hasNativeFunctions = false) { - LUAU_ASSERT(FFlag::LuauNativeAttribute); gatherFunctionsHelper(results, root, flags, hasNativeFunctions, true); } @@ -121,7 +102,7 @@ inline bool lowerImpl( bool outputEnabled = options.includeAssembly || options.includeIr; - IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg, function.proto}; // We use this to skip outlined fallback blocks from IR/asm text output size_t textSize = build.text.length(); @@ -318,6 +299,8 @@ inline bool lowerFunction( CodeGenCompilationResult& codeGenCompilationResult ) { + ir.function.stats = stats; + killUnusedBlocks(ir.function); unsigned preOptBlockCount = 0; diff --git a/CodeGen/src/CodeGenUtils.cpp b/CodeGen/src/CodeGenUtils.cpp index ad231e76..26451eea 100644 --- a/CodeGen/src/CodeGenUtils.cpp +++ b/CodeGen/src/CodeGenUtils.cpp @@ -18,6 +18,8 @@ #include +LUAU_DYNAMIC_FASTFLAG(LuauPopIncompleteCi) + // All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT // This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, // and restores the stack pointer after in case stack gets reallocated @@ -61,7 +63,7 @@ namespace Luau namespace CodeGen { -bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) +bool forgLoopTableIter(lua_State* L, LuaTable* h, int index, TValue* ra) { int sizearray = h->sizearray; @@ -104,7 +106,7 @@ bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra) return false; } -bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra) +bool forgLoopNodeIter(lua_State* L, LuaTable* h, int index, TValue* ra) { int sizearray = h->sizearray; int sizenode = 1 << h->lsizenode; @@ -191,7 +193,14 @@ Closure* callProlog(lua_State* L, TValue* ra, StkId argtop, int nresults) // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } return ccl; } @@ -224,11 +233,12 @@ Udata* newUserdata(lua_State* L, size_t s, int tag) { Udata* u = luaU_newudata(L, s, tag); - if (Table* h = L->global->udatamt[tag]) + if (LuaTable* h = L->global->udatamt[tag]) { - u->metatable = h; + // currently, we always allocate unmarked objects, so forward barrier can be skipped + LUAU_ASSERT(!isblack(obj2gco(u))); - luaC_objbarrier(L, u, h); + u->metatable = h; } return u; @@ -260,7 +270,14 @@ Closure* callFallback(lua_State* L, StkId ra, StkId argtop, int nresults) // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } LUAU_ASSERT(ci->top <= L->stack_last); @@ -328,7 +345,7 @@ const Instruction* executeGETGLOBAL(lua_State* L, const Instruction* pc, StkId b LUAU_ASSERT(ttisstring(kv)); // fast-path should already have been checked, so we skip checking for it here - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; // slow-path, may invoke Lua calls via __index metamethod @@ -351,7 +368,7 @@ const Instruction* executeSETGLOBAL(lua_State* L, const Instruction* pc, StkId b LUAU_ASSERT(ttisstring(kv)); // fast-path should already have been checked, so we skip checking for it here - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; // slow-path, may invoke Lua calls via __newindex metamethod @@ -377,7 +394,7 @@ const Instruction* executeGETTABLEKS(lua_State* L, const Instruction* pc, StkId // fast-path: built-in table if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); // we ignore the fast path that checks for the cached slot since IrTranslation already checks for it. @@ -489,7 +506,7 @@ const Instruction* executeSETTABLEKS(lua_State* L, const Instruction* pc, StkId // fast-path: built-in table if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); // we ignore the fast path that checks for the cached slot since IrTranslation already checks for it. @@ -574,7 +591,7 @@ const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId ba } else { - Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + LuaTable* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; const TValue* tmi = 0; // fast-path: metatable with __namecall @@ -588,7 +605,7 @@ const Instruction* executeNAMECALL(lua_State* L, const Instruction* pc, StkId ba } else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) { - Table* h = hvalue(tmi); + LuaTable* h = hvalue(tmi); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -645,7 +662,7 @@ const Instruction* executeSETLIST(lua_State* L, const Instruction* pc, StkId bas L->top = L->ci->top; } - Table* h = hvalue(ra); + LuaTable* h = hvalue(ra); // TODO: we really don't need this anymore if (!ttistable(ra)) @@ -680,7 +697,7 @@ const Instruction* executeFORGPREP(lua_State* L, const Instruction* pc, StkId ba } else { - Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + LuaTable* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(LuaTable*, NULL); if (const TValue* fn = fasttm(L, mt, TM_ITER)) { diff --git a/CodeGen/src/CodeGenUtils.h b/CodeGen/src/CodeGenUtils.h index 15d4c95d..1003a6f3 100644 --- a/CodeGen/src/CodeGenUtils.h +++ b/CodeGen/src/CodeGenUtils.h @@ -8,8 +8,8 @@ namespace Luau namespace CodeGen { -bool forgLoopTableIter(lua_State* L, Table* h, int index, TValue* ra); -bool forgLoopNodeIter(lua_State* L, Table* h, int index, TValue* ra); +bool forgLoopTableIter(lua_State* L, LuaTable* h, int index, TValue* ra); +bool forgLoopNodeIter(lua_State* L, LuaTable* h, int index, TValue* ra); bool forgLoopNonTableFallback(lua_State* L, int insnA, int aux); void forgPrepXnextFallback(lua_State* L, TValue* ra, int pc); diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index 79562b88..36b5130e 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -120,12 +120,12 @@ void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, Regist CODEGEN_ASSERT(tmp != node); CODEGEN_ASSERT(table != node); - build.mov(node, qword[table + offsetof(Table, node)]); + build.mov(node, qword[table + offsetof(LuaTable, node)]); // compute cached slot build.mov(tmp, sCode); build.movzx(dwordReg(tmp), byte[tmp + pcpos * sizeof(Instruction) + kOffsetOfInstructionC]); - build.and_(byteReg(tmp), byte[table + offsetof(Table, nodemask8)]); + build.and_(byteReg(tmp), byte[table + offsetof(LuaTable, nodemask8)]); // LuaNode* n = &h->node[slot]; build.shl(dwordReg(tmp), kLuaNodeSizeLog2); @@ -282,7 +282,7 @@ void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, Regist IrCallWrapperX64 callWrap(regs, build); callWrap.addArgument(SizeX64::qword, rState); callWrap.addArgument(SizeX64::qword, table, tableOp); - callWrap.addArgument(SizeX64::qword, addr[table + offsetof(Table, gclist)]); + callWrap.addArgument(SizeX64::qword, addr[table + offsetof(LuaTable, gclist)]); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaC_barrierback)]); } diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index ae3d1308..207f7f56 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -292,7 +292,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int Label skipResize; // Resize if h->sizearray < last - build.cmp(dword[table + offsetof(Table, sizearray)], last); + build.cmp(dword[table + offsetof(LuaTable, sizearray)], last); build.jcc(ConditionX64::NotBelow, skipResize); // Argument setup reordered to avoid conflicts @@ -309,7 +309,7 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int RegisterX64 arrayDst = rdx; RegisterX64 offset = rcx; - build.mov(arrayDst, qword[table + offsetof(Table, array)]); + build.mov(arrayDst, qword[table + offsetof(LuaTable, array)]); const int kUnrollSetListLimit = 4; @@ -380,7 +380,7 @@ void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep // &array[index] build.mov(dwordReg(elemPtr), dwordReg(index)); build.shl(dwordReg(elemPtr), kTValueSizeLog2); - build.add(elemPtr, qword[table + offsetof(Table, array)]); + build.add(elemPtr, qword[table + offsetof(LuaTable, array)]); // Clear extra variables since we might have more than two for (int i = 2; i < aux; ++i) @@ -391,7 +391,7 @@ void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRep // First we advance index through the array portion // while (unsigned(index) < unsigned(sizearray)) Label arrayLoop = build.setLabel(); - build.cmp(dwordReg(index), dword[table + offsetof(Table, sizearray)]); + build.cmp(dwordReg(index), dword[table + offsetof(LuaTable, sizearray)]); build.jcc(ConditionX64::NotBelow, skipArray); // If element is nil, we increment the index; if it's not, we still need 'index + 1' inside diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 0d2f9bd3..0d4b0a1f 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -684,7 +684,7 @@ void computeCfgDominanceTreeChildren(IrFunction& function) info.domChildrenOffsets[domParent]++; } - // Convert counds to offsets using prefix sum + // Convert counts to offsets using prefix sum uint32_t total = 0; for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 2846db54..dcc9d879 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -4,9 +4,13 @@ #include "Luau/IrUtils.h" #include "lua.h" +#include "lobject.h" +#include "lstate.h" #include +LUAU_FASTFLAG(LuauVectorLibNativeDot); + namespace Luau { namespace CodeGen @@ -17,6 +21,7 @@ static const char* textForCondition[] = static_assert(sizeof(textForCondition) / sizeof(textForCondition[0]) == size_t(IrCondition::Count), "all conditions have to be covered"); const int kDetailsAlignColumn = 60; +const unsigned kMaxStringConstantPrintLength = 16; LUAU_PRINTF_ATTR(2, 3) static void append(std::string& result, const char* fmt, ...) @@ -37,6 +42,17 @@ static void padToDetailColumn(std::string& result, size_t lineStart) result.append(pad, ' '); } +static bool isPrintableStringConstant(const char* str, size_t len) +{ + for (size_t i = 0; i < len; ++i) + { + if (unsigned(str[i]) < ' ') + return false; + } + + return true; +} + static const char* getTagName(uint8_t tag) { switch (tag) @@ -153,6 +169,8 @@ const char* getCmdName(IrCmd cmd) return "ABS_NUM"; case IrCmd::SIGN_NUM: return "SIGN_NUM"; + case IrCmd::SELECT_NUM: + return "SELECT_NUM"; case IrCmd::ADD_VEC: return "ADD_VEC"; case IrCmd::SUB_VEC: @@ -163,6 +181,9 @@ const char* getCmdName(IrCmd cmd) return "DIV_VEC"; case IrCmd::UNM_VEC: return "UNM_VEC"; + case IrCmd::DOT_VEC: + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + return "DOT_VEC"; case IrCmd::NOT_ANY: return "NOT_ANY"; case IrCmd::CMP_ANY: @@ -426,6 +447,53 @@ void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) append(ctx.result, "%s_%u", getBlockKindName(block.kind), index); } +static void appendVmConstant(std::string& result, Proto* proto, int index) +{ + TValue constant = proto->k[index]; + + if (constant.tt == LUA_TNIL) + { + append(result, "nil"); + } + else if (constant.tt == LUA_TBOOLEAN) + { + append(result, constant.value.b != 0 ? "true" : "false"); + } + else if (constant.tt == LUA_TNUMBER) + { + if (constant.value.n != constant.value.n) + append(result, "nan"); + else + append(result, "%.17g", constant.value.n); + } + else if (constant.tt == LUA_TSTRING) + { + TString* str = gco2ts(constant.value.gc); + const char* data = getstr(str); + + if (isPrintableStringConstant(data, str->len)) + { + if (str->len < kMaxStringConstantPrintLength) + append(result, "'%.*s'", int(str->len), data); + else + append(result, "'%.*s'...", int(kMaxStringConstantPrintLength), data); + } + } + else if (constant.tt == LUA_TVECTOR) + { + const float* v = constant.value.v; + +#if LUA_VECTOR_SIZE == 4 + if (v[3] != 0) + append(result, "%.9g, %.9g, %.9g, %.9g", v[0], v[1], v[2], v[3]); + else + append(result, "%.9g, %.9g, %.9g", v[0], v[1], v[2]); +#else + append(result, "%.9g, %.9g, %.9g", v[0], v[1], v[2]); +#endif + } +} + void toString(IrToStringContext& ctx, IrOp op) { switch (op.kind) @@ -453,6 +521,14 @@ void toString(IrToStringContext& ctx, IrOp op) break; case IrOpKind::VmConst: append(ctx.result, "K%d", vmConstOp(op)); + + if (ctx.proto) + { + append(ctx.result, " ("); + appendVmConstant(ctx.result, ctx.proto, vmConstOp(op)); + append(ctx.result, ")"); + } + break; case IrOpKind::VmUpvalue: append(ctx.result, "U%d", vmUpvalueOp(op)); @@ -765,7 +841,7 @@ void toStringDetailed( std::string toString(const IrFunction& function, IncludeUseInfo includeUseInfo) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; for (size_t i = 0; i < function.blocks.size(); i++) { @@ -872,7 +948,7 @@ static void appendBlocks(IrToStringContext& ctx, const IrFunction& function, boo std::string toDot(const IrFunction& function, bool includeInst) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; append(ctx.result, "digraph CFG {\n"); append(ctx.result, "node[shape=record]\n"); @@ -919,7 +995,7 @@ std::string toDot(const IrFunction& function, bool includeInst) std::string toDotCfg(const IrFunction& function) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; append(ctx.result, "digraph CFG {\n"); append(ctx.result, "node[shape=record]\n"); @@ -942,7 +1018,7 @@ std::string toDotCfg(const IrFunction& function) std::string toDotDjGraph(const IrFunction& function) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg, function.proto}; append(ctx.result, "digraph CFG {\n"); diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index cd73bcbb..086b91ed 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -4,6 +4,7 @@ #include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "EmitCommonA64.h" #include "NativeState.h" @@ -11,7 +12,7 @@ #include "lstate.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauCodegenArmNumToVecFix, false) +LUAU_FASTFLAG(LuauVectorLibNativeDot) namespace Luau { @@ -328,7 +329,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) case IrCmd::GET_ARR_ADDR: { inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, array))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaTable, array))); if (inst.b.kind == IrOpKind::Inst) { @@ -374,11 +375,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) // C field can be shifted as long as it's at the most significant byte of the instruction word CODEGEN_ASSERT(kOffsetOfInstructionC == 3); - build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, nodemask8))); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(LuaTable, nodemask8))); build.and_(temp2, temp2, temp1w, -24); // note: this may clobber inst.a, so it's important that we don't use it after this - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaTable, node))); build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero) break; } @@ -391,13 +392,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) // hash & ((1 << lsizenode) - 1) == hash & ~(-1 << lsizenode) build.mov(temp1, -1); - build.ldrb(temp2, mem(regOp(inst.a), offsetof(Table, lsizenode))); + build.ldrb(temp2, mem(regOp(inst.a), offsetof(LuaTable, lsizenode))); build.lsl(temp1, temp1, temp2); build.mov(temp2, uintOp(inst.b)); build.bic(temp2, temp2, temp1); // note: this may clobber inst.a, so it's important that we don't use it after this - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaTable, node))); build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero) break; } @@ -497,6 +498,13 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.str(temp4, AddressA64(addr.base, addr.data + 4)); build.fcvt(temp4, temp3); build.str(temp4, AddressA64(addr.base, addr.data + 8)); + + if (inst.e.kind != IrOpKind::None) + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, tagOp(inst.e)); + build.str(temp, tempAddr(inst.a, offsetof(TValue, tt))); + } break; } case IrCmd::STORE_TVALUE: @@ -695,6 +703,19 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fcsel(inst.regA64, temp1, inst.regA64, getConditionFP(IrCondition::Less)); break; } + case IrCmd::SELECT_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b, inst.c, inst.d}); + + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + RegisterA64 temp3 = tempDouble(inst.c); + RegisterA64 temp4 = tempDouble(inst.d); + + build.fcmp(temp3, temp4); + build.fcsel(inst.regA64, temp2, temp1, getConditionFP(IrCondition::Equal)); + break; + } case IrCmd::ADD_VEC: { inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b}); @@ -730,6 +751,23 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.fneg(inst.regA64, regOp(inst.a)); break; } + case IrCmd::DOT_VEC: + { + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + + inst.regA64 = regs.allocReg(KindA64::d, index); + + RegisterA64 temp = regs.allocTemp(KindA64::q); + RegisterA64 temps = castReg(KindA64::s, temp); + RegisterA64 regs = castReg(KindA64::s, inst.regA64); + + build.fmul(temp, regOp(inst.a), regOp(inst.b)); + build.faddp(regs, temps); // x+y + build.dup_4s(temp, temp, 2); + build.fadd(regs, regs, temps); // +z + build.fcvt(inst.regA64, regs); + break; + } case IrCmd::NOT_ANY: { inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); @@ -1035,10 +1073,10 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::w); - build.ldr(temp1, mem(regOp(inst.a), offsetof(Table, metatable))); + build.ldr(temp1, mem(regOp(inst.a), offsetof(LuaTable, metatable))); build.cbz(temp1, labelOp(inst.c)); // no metatable - build.ldrb(temp2, mem(temp1, offsetof(Table, tmcache))); + build.ldrb(temp2, mem(temp1, offsetof(LuaTable, tmcache))); build.tst(temp2, 1 << intOp(inst.b)); // can't use tbz/tbnz because their jump offsets are too short build.b(ConditionA64::NotEqual, labelOp(inst.c)); // Equal = Zero after tst; tmcache caches *absence* of metamethods @@ -1121,7 +1159,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) else { RegisterA64 tempd = tempDouble(inst.a); - RegisterA64 temps = FFlag::LuauCodegenArmNumToVecFix ? regs.allocTemp(KindA64::s) : castReg(KindA64::s, tempd); + RegisterA64 temps = regs.allocTemp(KindA64::s); build.fcvt(temps, tempd); build.dup_4s(inst.regA64, castReg(KindA64::q, temps), 0); @@ -1475,7 +1513,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::w); - build.ldrb(temp, mem(regOp(inst.a), offsetof(Table, readonly))); + build.ldrb(temp, mem(regOp(inst.a), offsetof(LuaTable, readonly))); build.cbnz(temp, getTargetLabel(inst.b, fresh)); finalizeTargetLabel(inst.b, fresh); break; @@ -1484,7 +1522,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { Label fresh; // used when guard aborts execution or jumps to a VM exit RegisterA64 temp = regs.allocTemp(KindA64::x); - build.ldr(temp, mem(regOp(inst.a), offsetof(Table, metatable))); + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaTable, metatable))); build.cbnz(temp, getTargetLabel(inst.b, fresh)); finalizeTargetLabel(inst.b, fresh); break; @@ -1495,7 +1533,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) RegisterA64 temp = regs.allocTemp(KindA64::x); RegisterA64 tempw = castReg(KindA64::w, temp); build.ldr(temp, mem(rClosure, offsetof(Closure, env))); - build.ldrb(tempw, mem(temp, offsetof(Table, safeenv))); + build.ldrb(tempw, mem(temp, offsetof(LuaTable, safeenv))); build.cbz(tempw, getTargetLabel(inst.a, fresh)); finalizeTargetLabel(inst.a, fresh); break; @@ -1506,7 +1544,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) Label& fail = getTargetLabel(inst.c, fresh); RegisterA64 temp = regs.allocTemp(KindA64::w); - build.ldr(temp, mem(regOp(inst.a), offsetof(Table, sizearray))); + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaTable, sizearray))); if (inst.b.kind == IrOpKind::Inst) { @@ -1733,7 +1771,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) size_t spills = regs.spill(build, index, {reg}); build.mov(x1, reg); build.mov(x0, rState); - build.add(x2, x1, uint16_t(offsetof(Table, gclist))); + build.add(x2, x1, uint16_t(offsetof(LuaTable, gclist))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierback))); build.blr(x3); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index d06cef13..373f4f59 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -4,6 +4,7 @@ #include "Luau/DenseHash.h" #include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "Luau/IrCallWrapperX64.h" @@ -15,6 +16,8 @@ #include "lstate.h" #include "lgc.h" +LUAU_FASTFLAG(LuauVectorLibNativeDot) + namespace Luau { namespace CodeGen @@ -155,13 +158,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.mov(dwordReg(inst.regX64), regOp(inst.b)); build.shl(dwordReg(inst.regX64), kTValueSizeLog2); - build.add(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); + build.add(inst.regX64, qword[regOp(inst.a) + offsetof(LuaTable, array)]); } else if (inst.b.kind == IrOpKind::Constant) { inst.regX64 = regs.allocRegOrReuse(SizeX64::qword, index, {inst.a}); - build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, array)]); + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(LuaTable, array)]); if (intOp(inst.b) != 0) build.lea(inst.regX64, addr[inst.regX64 + intOp(inst.b) * sizeof(TValue)]); @@ -189,9 +192,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::qword}; - build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(Table, node)]); + build.mov(inst.regX64, qword[regOp(inst.a) + offsetof(LuaTable, node)]); build.mov(dwordReg(tmp.reg), 1); - build.mov(byteReg(shiftTmp.reg), byte[regOp(inst.a) + offsetof(Table, lsizenode)]); + build.mov(byteReg(shiftTmp.reg), byte[regOp(inst.a) + offsetof(LuaTable, lsizenode)]); build.shl(dwordReg(tmp.reg), byteReg(shiftTmp.reg)); build.dec(dwordReg(tmp.reg)); build.and_(dwordReg(tmp.reg), uintOp(inst.b)); @@ -295,6 +298,9 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 0), inst.b); storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 1), inst.c); storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); + + if (inst.e.kind != IrOpKind::None) + build.mov(luauRegTag(vmRegOp(inst.a)), tagOp(inst.e)); break; case IrCmd::STORE_TVALUE: { @@ -616,6 +622,29 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vblendvpd(inst.regX64, tmp1.reg, build.f64x2(1, 1), inst.regX64); break; } + case IrCmd::SELECT_NUM: + { + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.c, inst.d}); // can't reuse b if a is a memory operand + + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + if (inst.c.kind == IrOpKind::Inst) + build.vcmpeqsd(tmp.reg, regOp(inst.c), memRegDoubleOp(inst.d)); + else + { + build.vmovsd(tmp.reg, memRegDoubleOp(inst.c)); + build.vcmpeqsd(tmp.reg, tmp.reg, memRegDoubleOp(inst.d)); + } + + if (inst.a.kind == IrOpKind::Inst) + build.vblendvpd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b), tmp.reg); + else + { + build.vmovsd(inst.regX64, memRegDoubleOp(inst.a)); + build.vblendvpd(inst.regX64, inst.regX64, memRegDoubleOp(inst.b), tmp.reg); + } + break; + } case IrCmd::ADD_VEC: { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); @@ -675,6 +704,22 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0)); break; } + case IrCmd::DOT_VEC: + { + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); + + ScopedRegX64 tmp1{regs}; + ScopedRegX64 tmp2{regs}; + + RegisterX64 tmpa = vecOp(inst.a, tmp1); + RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2); + + build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float + build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64); + break; + } case IrCmd::NOT_ANY: { // TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target @@ -907,13 +952,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { ScopedRegX64 tmp{regs, SizeX64::qword}; - build.mov(tmp.reg, qword[regOp(inst.a) + offsetof(Table, metatable)]); + build.mov(tmp.reg, qword[regOp(inst.a) + offsetof(LuaTable, metatable)]); regs.freeLastUseReg(function.instOp(inst.a), index); // Release before the call if it's the last use build.test(tmp.reg, tmp.reg); build.jcc(ConditionX64::Zero, labelOp(inst.c)); // No metatable - build.test(byte[tmp.reg + offsetof(Table, tmcache)], 1 << intOp(inst.b)); + build.test(byte[tmp.reg + offsetof(LuaTable, tmcache)], 1 << intOp(inst.b)); build.jcc(ConditionX64::NotZero, labelOp(inst.c)); // No tag method ScopedRegX64 tmp2{regs, SizeX64::qword}; @@ -1273,11 +1318,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) break; } case IrCmd::CHECK_READONLY: - build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); + build.cmp(byte[regOp(inst.a) + offsetof(LuaTable, readonly)], 0); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next); break; case IrCmd::CHECK_NO_METATABLE: - build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0); + build.cmp(qword[regOp(inst.a) + offsetof(LuaTable, metatable)], 0); jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next); break; case IrCmd::CHECK_SAFE_ENV: @@ -1286,16 +1331,16 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) build.mov(tmp.reg, sClosure); build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); - build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); + build.cmp(byte[tmp.reg + offsetof(LuaTable, safeenv)], 0); jumpOrAbortOnUndef(ConditionX64::Equal, inst.a, next); break; } case IrCmd::CHECK_ARRAY_SIZE: if (inst.b.kind == IrOpKind::Inst) - build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], regOp(inst.b)); + build.cmp(dword[regOp(inst.a) + offsetof(LuaTable, sizearray)], regOp(inst.b)); else if (inst.b.kind == IrOpKind::Constant) - build.cmp(dword[regOp(inst.a) + offsetof(Table, sizearray)], intOp(inst.b)); + build.cmp(dword[regOp(inst.a) + offsetof(LuaTable, sizearray)], intOp(inst.b)); else CODEGEN_ASSERT(!"Unsupported instruction form"); diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 4471aaa5..15a306c9 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -2,15 +2,15 @@ #include "IrRegAllocA64.h" #include "Luau/AssemblyBuilderA64.h" -#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "BitUtils.h" #include "EmitCommonA64.h" #include -LUAU_FASTFLAGVARIABLE(DebugCodegenChaosA64, false) +LUAU_FASTFLAGVARIABLE(DebugCodegenChaosA64) namespace Luau { diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index d647484b..64625868 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,8 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" -#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" +#include "Luau/LoweringStats.h" #include "EmitCommonX64.h" diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 52efaef1..ec72b692 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -13,6 +13,8 @@ static const int kMinMaxUnrolledParams = 5; static const int kBit32BinaryOpUnrolledParams = 5; +LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot); + namespace Luau { namespace CodeGen @@ -281,6 +283,40 @@ static BuiltinImplResult translateBuiltinMathClamp( return {BuiltinImplType::UsesFallback, 1}; } +static BuiltinImplResult translateBuiltinMathLerp( + IrBuilder& build, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + IrOp fallback, + int pcpos +) +{ + if (nparams < 3 || nresults > 1) + return {BuiltinImplType::None, -1}; + + builtinCheckDouble(build, build.vmReg(arg), pcpos); + builtinCheckDouble(build, args, pcpos); + builtinCheckDouble(build, arg3, pcpos); + + IrOp a = builtinLoadDouble(build, build.vmReg(arg)); + IrOp b = builtinLoadDouble(build, args); + IrOp t = builtinLoadDouble(build, arg3); + + IrOp l = build.inst(IrCmd::ADD_NUM, a, build.inst(IrCmd::MUL_NUM, build.inst(IrCmd::SUB_NUM, b, a), t)); + IrOp r = build.inst(IrCmd::SELECT_NUM, l, b, t, build.constDouble(1.0)); // select on t==1.0 + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), r); + + if (ra != arg) + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::Full, 1}; +} + static BuiltinImplResult translateBuiltinMathUnary(IrBuilder& build, IrCmd cmd, int nparams, int ra, int arg, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) @@ -885,6 +921,327 @@ static BuiltinImplResult translateBuiltinBufferWrite( return {BuiltinImplType::Full, 0}; } +static BuiltinImplResult translateBuiltinVectorMagnitude( + IrBuilder& build, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + int pcpos +) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 1 || nresults > 1 || arg1.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp sum; + + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); + + sum = build.inst(IrCmd::DOT_VEC, a, a); + } + else + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + } + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), mag); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::Full, 1}; +} + +static BuiltinImplResult translateBuiltinVectorNormalize( + IrBuilder& build, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + int pcpos +) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 1 || nresults > 1 || arg1.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); + IrOp sum = build.inst(IrCmd::DOT_VEC, a, a); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv); + + IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec); + + result = build.inst(IrCmd::TAG_VECTOR, result); + + build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result); + } + else + { + IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x); + IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y); + IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z); + + IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2); + + IrOp mag = build.inst(IrCmd::SQRT_NUM, sum); + IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag); + + IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv); + IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv); + IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + } + + return {BuiltinImplType::Full, 1}; +} + +static BuiltinImplResult translateBuiltinVectorCross(IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 2 || nresults > 1 || arg1.kind == IrOpKind::Constant || args.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4)); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); + + IrOp y1z2 = build.inst(IrCmd::MUL_NUM, y1, z2); + IrOp z1y2 = build.inst(IrCmd::MUL_NUM, z1, y2); + IrOp xr = build.inst(IrCmd::SUB_NUM, y1z2, z1y2); + + IrOp z1x2 = build.inst(IrCmd::MUL_NUM, z1, x2); + IrOp x1z2 = build.inst(IrCmd::MUL_NUM, x1, z2); + IrOp yr = build.inst(IrCmd::SUB_NUM, z1x2, x1z2); + + IrOp x1y2 = build.inst(IrCmd::MUL_NUM, x1, y2); + IrOp y1x2 = build.inst(IrCmd::MUL_NUM, y1, x2); + IrOp zr = build.inst(IrCmd::SUB_NUM, x1y2, y1x2); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::Full, 1}; +} + +static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams, int ra, int arg, IrOp args, IrOp arg3, int nresults, int pcpos) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 2 || nresults > 1 || arg1.kind == IrOpKind::Constant || args.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp sum; + + if (FFlag::LuauVectorLibNativeDot) + { + IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0)); + IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0)); + + sum = build.inst(IrCmd::DOT_VEC, a, b); + } + else + { + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); + IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2); + + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4)); + IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2); + + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); + IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2); + + sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz); + } + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); + + return {BuiltinImplType::Full, 1}; +} + +static BuiltinImplResult translateBuiltinVectorMap1( + IrBuilder& build, + IrCmd cmd, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + int pcpos +) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 1 || nresults > 1 || arg1.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + + IrOp xr = build.inst(cmd, x1); + IrOp yr = build.inst(cmd, y1); + IrOp zr = build.inst(cmd, z1); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::Full, 1}; +} + +static BuiltinImplResult translateBuiltinVectorClamp( + IrBuilder& build, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + IrOp fallback, + int pcpos +) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 3 || nresults > 1 || arg1.kind == IrOpKind::Constant || args.kind == IrOpKind::Constant || arg3.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); + build.loadAndCheckTag(arg3, LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp block1 = build.block(IrBlockKind::Internal); + IrOp block2 = build.block(IrBlockKind::Internal); + IrOp block3 = build.block(IrBlockKind::Internal); + + IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp xmin = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); + IrOp xmax = build.inst(IrCmd::LOAD_FLOAT, arg3, build.constInt(0)); + + build.inst(IrCmd::JUMP_CMP_NUM, xmin, xmax, build.cond(IrCondition::NotLessEqual), fallback, block1); + + build.beginBlock(block1); + + IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp ymin = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4)); + IrOp ymax = build.inst(IrCmd::LOAD_FLOAT, arg3, build.constInt(4)); + + build.inst(IrCmd::JUMP_CMP_NUM, ymin, ymax, build.cond(IrCondition::NotLessEqual), fallback, block2); + + build.beginBlock(block2); + + IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + IrOp zmin = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); + IrOp zmax = build.inst(IrCmd::LOAD_FLOAT, arg3, build.constInt(8)); + + build.inst(IrCmd::JUMP_CMP_NUM, zmin, zmax, build.cond(IrCondition::NotLessEqual), fallback, block3); + + build.beginBlock(block3); + + IrOp xtemp = build.inst(IrCmd::MAX_NUM, xmin, x); + IrOp xclamped = build.inst(IrCmd::MIN_NUM, xmax, xtemp); + + IrOp ytemp = build.inst(IrCmd::MAX_NUM, ymin, y); + IrOp yclamped = build.inst(IrCmd::MIN_NUM, ymax, ytemp); + + IrOp ztemp = build.inst(IrCmd::MAX_NUM, zmin, z); + IrOp zclamped = build.inst(IrCmd::MIN_NUM, zmax, ztemp); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xclamped, yclamped, zclamped); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::UsesFallback, 1}; +} + +static BuiltinImplResult translateBuiltinVectorMap2( + IrBuilder& build, + IrCmd cmd, + int nparams, + int ra, + int arg, + IrOp args, + IrOp arg3, + int nresults, + int pcpos +) +{ + IrOp arg1 = build.vmReg(arg); + + if (nparams != 2 || nresults > 1 || arg1.kind == IrOpKind::Constant || args.kind == IrOpKind::Constant) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos)); + build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos)); + + IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0)); + IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4)); + IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8)); + + IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0)); + IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4)); + IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8)); + + IrOp xr = build.inst(cmd, x1, x2); + IrOp yr = build.inst(cmd, y1, y2); + IrOp zr = build.inst(cmd, z1, z2); + + build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr); + build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR)); + + return {BuiltinImplType::Full, 1}; +} + + BuiltinImplResult translateBuiltin( IrBuilder& build, int bfid, @@ -1018,6 +1375,30 @@ BuiltinImplResult translateBuiltin( return translateBuiltinBufferRead(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_READF64, 8, IrCmd::NOP); case LBF_BUFFER_WRITEF64: return translateBuiltinBufferWrite(build, nparams, ra, arg, args, arg3, nresults, pcpos, IrCmd::BUFFER_WRITEF64, 8, IrCmd::NOP); + case LBF_VECTOR_MAGNITUDE: + return translateBuiltinVectorMagnitude(build, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_NORMALIZE: + return translateBuiltinVectorNormalize(build, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_CROSS: + return translateBuiltinVectorCross(build, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_DOT: + return translateBuiltinVectorDot(build, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_FLOOR: + return translateBuiltinVectorMap1(build, IrCmd::FLOOR_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_CEIL: + return translateBuiltinVectorMap1(build, IrCmd::CEIL_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_ABS: + return translateBuiltinVectorMap1(build, IrCmd::ABS_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_SIGN: + return translateBuiltinVectorMap1(build, IrCmd::SIGN_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_CLAMP: + return translateBuiltinVectorClamp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); + case LBF_VECTOR_MIN: + return translateBuiltinVectorMap2(build, IrCmd::MIN_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_VECTOR_MAX: + return translateBuiltinVectorMap2(build, IrCmd::MAX_NUM, nparams, ra, arg, args, arg3, nresults, pcpos); + case LBF_MATH_LERP: + return translateBuiltinMathLerp(build, nparams, ra, arg, args, arg3, nresults, fallback, pcpos); default: return {BuiltinImplType::None, -1}; } diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 62829766..d15d57e2 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -3,7 +3,7 @@ #include "Luau/Bytecode.h" #include "Luau/BytecodeUtils.h" -#include "Luau/CodeGen.h" +#include "Luau/CodeGenOptions.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index ebf4c34b..54902435 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrUtils.h" +#include "Luau/CodeGenOptions.h" #include "Luau/IrBuilder.h" #include "BitUtils.h" @@ -9,9 +10,14 @@ #include "lua.h" #include "lnumutils.h" +#include +#include + #include #include +LUAU_FASTFLAG(LuauVectorLibNativeDot); + namespace Luau { namespace CodeGen @@ -68,6 +74,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: return IrValueKind::Double; case IrCmd::ADD_VEC: case IrCmd::SUB_VEC: @@ -75,6 +82,9 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::DIV_VEC: case IrCmd::UNM_VEC: return IrValueKind::Tvalue; + case IrCmd::DOT_VEC: + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + return IrValueKind::Double; case IrCmd::NOT_ANY: case IrCmd::CMP_ANY: return IrValueKind::Int; @@ -651,6 +661,15 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 substitute(function, inst, build.constDouble(v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0)); } break; + case IrCmd::SELECT_NUM: + if (inst.c.kind == IrOpKind::Constant && inst.d.kind == IrOpKind::Constant) + { + double c = function.doubleOp(inst.c); + double d = function.doubleOp(inst.d); + + substitute(function, inst, c == d ? inst.b : inst.a); + } + break; case IrCmd::NOT_ANY: if (inst.a.kind == IrOpKind::Constant) { diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 941db252..b4f74132 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -44,25 +44,25 @@ struct NativeContext void (*luaV_dolen)(lua_State* L, StkId ra, const TValue* rb) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; - void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; + void (*luaV_getimport)(lua_State* L, LuaTable* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; void (*luaV_concat)(lua_State* L, int total, int last) = nullptr; - int (*luaH_getn)(Table* t) = nullptr; - Table* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr; - Table* (*luaH_clone)(lua_State* L, Table* tt) = nullptr; - void (*luaH_resizearray)(lua_State* L, Table* t, int nasize) = nullptr; - TValue* (*luaH_setnum)(lua_State* L, Table* t, int key); + int (*luaH_getn)(LuaTable* t) = nullptr; + LuaTable* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr; + LuaTable* (*luaH_clone)(lua_State* L, LuaTable* tt) = nullptr; + void (*luaH_resizearray)(lua_State* L, LuaTable* t, int nasize) = nullptr; + TValue* (*luaH_setnum)(lua_State* L, LuaTable* t, int key); - void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr; + void (*luaC_barriertable)(lua_State* L, LuaTable* t, GCObject* v) = nullptr; void (*luaC_barrierf)(lua_State* L, GCObject* o, GCObject* v) = nullptr; void (*luaC_barrierback)(lua_State* L, GCObject* o, GCObject** gclist) = nullptr; size_t (*luaC_step)(lua_State* L, bool assist) = nullptr; void (*luaF_close)(lua_State* L, StkId level) = nullptr; UpVal* (*luaF_findupval)(lua_State* L, StkId level) = nullptr; - Closure* (*luaF_newLclosure)(lua_State* L, int nelems, Table* e, Proto* p) = nullptr; + Closure* (*luaF_newLclosure)(lua_State* L, int nelems, LuaTable* e, Proto* p) = nullptr; - const TValue* (*luaT_gettm)(Table* events, TMS event, TString* ename) = nullptr; + const TValue* (*luaT_gettm)(LuaTable* events, TMS event, TString* ename) = nullptr; const TString* (*luaT_objtypenamestr)(lua_State* L, const TValue* o) = nullptr; double (*libm_exp)(double) = nullptr; @@ -87,8 +87,8 @@ struct NativeContext double (*libm_modf)(double, double*) = nullptr; // Helper functions - bool (*forgLoopTableIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; - bool (*forgLoopNodeIter)(lua_State* L, Table* h, int index, TValue* ra) = nullptr; + bool (*forgLoopTableIter)(lua_State* L, LuaTable* h, int index, TValue* ra) = nullptr; + bool (*forgLoopNodeIter)(lua_State* L, LuaTable* h, int index, TValue* ra) = nullptr; bool (*forgLoopNonTableFallback)(lua_State* L, int insnA, int aux) = nullptr; void (*forgPrepXnextFallback)(lua_State* L, TValue* ra, int pc) = nullptr; Closure* (*callProlog)(lua_State* L, TValue* ra, StkId argtop, int nresults) = nullptr; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index f3271d3f..8cdd1dc8 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -9,7 +9,9 @@ #include "lua.h" #include +#include +#include #include #include #include @@ -17,7 +19,9 @@ LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTINTVARIABLE(LuauCodeGenReuseUdataTagLimit, 64) -LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) +LUAU_FASTINTVARIABLE(LuauCodeGenLiveSlotReuseLimit, 8) +LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks) +LUAU_FASTFLAG(LuauVectorLibNativeDot) namespace Luau { @@ -47,6 +51,14 @@ struct RegisterLink uint32_t version = 0; }; +// Reference to an instruction together with the position of that instruction in the current block chain and the last position of reuse +struct NumberedInstruction +{ + uint32_t instIdx = 0; + uint32_t startPos = 0; + uint32_t finishPos = 0; +}; + // Data we know about the current VM state struct ConstPropState { @@ -361,7 +373,7 @@ struct ConstPropState return; // To avoid captured register invalidation tracking in lowering later, values from loads from captured registers are not propagated - // This prevents the case where load value location is linked to memory in case of a spill and is then cloberred in a user call + // This prevents the case where load value location is linked to memory in case of a spill and is then clobbered in a user call if (function.cfg.captured.regs.test(vmRegOp(loadInst.a))) return; @@ -375,7 +387,7 @@ struct ConstPropState if (!instLink.contains(*prevIdx)) createRegLink(*prevIdx, loadInst.a); - // Substitute load instructon with the previous value + // Substitute load instruction with the previous value substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); return; } @@ -398,7 +410,7 @@ struct ConstPropState return; // To avoid captured register invalidation tracking in lowering later, values from stores into captured registers are not propagated - // This prevents the case where store creates an alternative value location in case of a spill and is then cloberred in a user call + // This prevents the case where store creates an alternative value location in case of a spill and is then clobbered in a user call if (function.cfg.captured.regs.test(vmRegOp(storeInst.a))) return; @@ -406,12 +418,69 @@ struct ConstPropState valueMap[versionedVmRegLoad(loadCmd, storeInst.a)] = storeInst.b.index; } + // Used to compute the pressure of the cached value 'set' on the spill registers + // We want to find out the maximum live range intersection count between the cached value at 'slot' and current instruction + // Note that this pressure is approximate, as some values that might have been live at one point could have been marked dead later + int getMaxInternalOverlap(std::vector& set, size_t slot) + { + // Start with one live range for the slot we want to reuse + int curr = 1; + + // For any slots where lifetime began before the slot of interest, mark as live if lifetime end is still active + // This saves us from processing slots [0; slot] in the range sweep later, which requires sorting the lifetime end points + for (size_t i = 0; i < slot; i++) + { + if (set[i].finishPos >= set[slot].startPos) + curr++; + } + + int max = curr; + + // Collect lifetime end points and sort them + rangeEndTemp.clear(); + + for (size_t i = slot + 1; i < set.size(); i++) + rangeEndTemp.push_back(set[i].finishPos); + + std::sort(rangeEndTemp.begin(), rangeEndTemp.end()); + + // Go over the lifetime begin/end ranges that we store as separate array and walk based on the smallest of values + for (size_t i1 = slot + 1, i2 = 0; i1 < set.size() && i2 < rangeEndTemp.size();) + { + if (rangeEndTemp[i2] == set[i1].startPos) + { + i1++; + i2++; + } + else if (rangeEndTemp[i2] < set[i1].startPos) + { + CODEGEN_ASSERT(curr > 0); + + curr--; + i2++; + } + else + { + curr++; + i1++; + + if (curr > max) + max = curr; + } + } + + // We might have unprocessed lifetime end entries, but we will never have unprocessed lifetime start entries + // Not that lifetime end entries can only decrease the current value and do not affect the end result (maximum) + return max; + } + void clear() { for (int i = 0; i <= maxReg; ++i) regs[i] = RegisterInfo(); maxReg = 0; + instPos = 0u; inSafeEnv = false; checkedGc = false; @@ -433,6 +502,9 @@ struct ConstPropState // For range/full invalidations, we only want to visit a limited number of data that we have recorded int maxReg = 0; + // Number of the instruction being processed + uint32_t instPos = 0; + bool inSafeEnv = false; bool checkedGc = false; @@ -444,7 +516,7 @@ struct ConstPropState std::vector tryNumToIndexCache; // Fallback block argument might be different // Heap changes might affect table state - std::vector getSlotNodeCache; // Additionally, pcpos argument might be different + std::vector getSlotNodeCache; // Additionally, pcpos argument might be different std::vector checkSlotMatchCache; // Additionally, fallback block argument might be different std::vector getArrAddrCache; @@ -454,6 +526,8 @@ struct ConstPropState // Userdata tag cache can point to both NEW_USERDATA and CHECK_USERDATA_TAG instructions std::vector useradataTagCache; // Additionally, fallback block argument might be different + + std::vector rangeEndTemp; }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -536,6 +610,18 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_BUFFER_WRITEF32: case LBF_BUFFER_READF64: case LBF_BUFFER_WRITEF64: + case LBF_VECTOR_MAGNITUDE: + case LBF_VECTOR_NORMALIZE: + case LBF_VECTOR_CROSS: + case LBF_VECTOR_DOT: + case LBF_VECTOR_FLOOR: + case LBF_VECTOR_CEIL: + case LBF_VECTOR_ABS: + case LBF_VECTOR_SIGN: + case LBF_VECTOR_CLAMP: + case LBF_VECTOR_MIN: + case LBF_VECTOR_MAX: + case LBF_MATH_LERP: break; case LBF_TABLE_INSERT: state.invalidateHeap(); @@ -555,6 +641,8 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& function, IrBlock& block, IrInst& inst, uint32_t index) { + state.instPos++; + switch (inst.cmd) { case IrCmd::LOAD_TAG: @@ -757,7 +845,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (tag == LUA_TBOOLEAN && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int))) canSplitTvalueStore = true; - else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) + else if (tag == LUA_TNUMBER && + (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double))) canSplitTvalueStore = true; else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst) canSplitTvalueStore = true; @@ -1160,29 +1249,82 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.getArrAddrCache.push_back(index); break; case IrCmd::GET_SLOT_NODE_ADDR: - for (uint32_t prevIdx : state.getSlotNodeCache) + for (size_t i = 0; i < state.getSlotNodeCache.size(); i++) { + auto&& [prevIdx, num, lastNum] = state.getSlotNodeCache[i]; + const IrInst& prev = function.instructions[prevIdx]; if (prev.a == inst.a && prev.c == inst.c) { + // Check if this reuse will increase the overall register pressure over the limit + int limit = FInt::LuauCodeGenLiveSlotReuseLimit; + + if (int(state.getSlotNodeCache.size()) > limit && state.getMaxInternalOverlap(state.getSlotNodeCache, i) > limit) + return; + + // Update live range of the value from the optimization standpoint + lastNum = state.instPos; + substitute(function, inst, IrOp{IrOpKind::Inst, prevIdx}); return; // Break out from both the loop and the switch } } if (int(state.getSlotNodeCache.size()) < FInt::LuauCodeGenReuseSlotLimit) - state.getSlotNodeCache.push_back(index); + state.getSlotNodeCache.push_back({index, state.instPos, state.instPos}); break; case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::GET_CLOSURE_UPVAL_ADDR: break; case IrCmd::ADD_INT: case IrCmd::SUB_INT: + state.substituteOrRecord(inst, index); + break; case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + // a + 0.0 and a - (-0.0) can't be folded since the behavior is different for negative zero + // however, a - 0.0 and a + (-0.0) can be folded into a + if (*k == 0.0 && bool(signbit(*k)) == (inst.cmd == IrCmd::ADD_NUM)) + substitute(function, inst, inst.a); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::MUL_NUM: + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a * 1.0 = a + substitute(function, inst, inst.a); + else if (*k == 2.0) // a * 2.0 = a + a + replace(function, block, index, {IrCmd::ADD_NUM, inst.a, inst.a}); + else if (*k == -1.0) // a * -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::DIV_NUM: + if (std::optional k = function.asDoubleOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b))) + { + if (*k == 1.0) // a / 1.0 = a + substitute(function, inst, inst.a); + else if (*k == -1.0) // a / -1.0 = -a + replace(function, block, index, {IrCmd::UNM_NUM, inst.a}); + else if (int exp = 0; frexp(*k, &exp) == 0.5 && exp >= -1000 && exp <= 1000) // a / 2^k = a * 2^-k + replace(function, block, index, {IrCmd::MUL_NUM, inst.a, build.constDouble(1.0 / *k)}); + else + state.substituteOrRecord(inst, index); + } + else + state.substituteOrRecord(inst, index); + break; case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: @@ -1194,6 +1336,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SQRT_NUM: case IrCmd::ABS_NUM: case IrCmd::SIGN_NUM: + case IrCmd::SELECT_NUM: case IrCmd::NOT_ANY: state.substituteOrRecord(inst, index); break; @@ -1331,6 +1474,10 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUB_VEC: case IrCmd::MUL_VEC: case IrCmd::DIV_VEC: + case IrCmd::DOT_VEC: + if (inst.cmd == IrCmd::DOT_VEC) + LUAU_ASSERT(FFlag::LuauVectorLibNativeDot); + if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR) replace(function, inst.a, a->a); diff --git a/CodeGen/src/OptimizeDeadStore.cpp b/CodeGen/src/OptimizeDeadStore.cpp index b4b4c7b5..1483e4a2 100644 --- a/CodeGen/src/OptimizeDeadStore.cpp +++ b/CodeGen/src/OptimizeDeadStore.cpp @@ -324,8 +324,21 @@ static bool tryReplaceTagWithFullStore( // And value store has to follow, as the pre-DSO code would not allow GC to observe an incomplete stack variable if (tag != LUA_TNIL && regInfo.valueInstIdx != ~0u) { - IrOp prevValueOp = function.instructions[regInfo.valueInstIdx].b; - replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + IrInst& prevValueInst = function.instructions[regInfo.valueInstIdx]; + + if (prevValueInst.cmd == IrCmd::STORE_VECTOR) + { + CODEGEN_ASSERT(prevValueInst.e.kind == IrOpKind::None); + IrOp prevValueX = prevValueInst.b; + IrOp prevValueY = prevValueInst.c; + IrOp prevValueZ = prevValueInst.d; + replace(function, block, instIndex, IrInst{IrCmd::STORE_VECTOR, targetOp, prevValueX, prevValueY, prevValueZ, tagOp}); + } + else + { + IrOp prevValueOp = prevValueInst.b; + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, tagOp, prevValueOp}); + } } state.killTagStore(regInfo); @@ -356,6 +369,25 @@ static bool tryReplaceTagWithFullStore( state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; + regInfo.maybeGco = isGCO(tag); + regInfo.knownTag = tag; + state.hasGcoToClear |= regInfo.maybeGco; + return true; + } + else if (prev.cmd == IrCmd::STORE_VECTOR) + { + // If the 'nil' is stored, we keep 'STORE_TAG Rn, tnil' as it writes the 'full' TValue + if (tag != LUA_TNIL) + { + IrOp prevValueX = prev.b; + IrOp prevValueY = prev.c; + IrOp prevValueZ = prev.d; + replace(function, block, instIndex, IrInst{IrCmd::STORE_VECTOR, targetOp, prevValueX, prevValueY, prevValueZ, tagOp}); + } + + state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; regInfo.maybeGco = isGCO(tag); regInfo.knownTag = tag; @@ -410,6 +442,92 @@ static bool tryReplaceValueWithFullStore( state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; + return true; + } + else if (prev.cmd == IrCmd::STORE_VECTOR) + { + IrOp prevTagOp = prev.e; + CODEGEN_ASSERT(prevTagOp.kind != IrOpKind::None); + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + replace(function, block, instIndex, IrInst{IrCmd::STORE_SPLIT_TVALUE, targetOp, prevTagOp, valueOp}); + + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + } + + return false; +} + +static bool tryReplaceVectorValueWithFullStore( + RemoveDeadStoreState& state, + IrBuilder& build, + IrFunction& function, + IrBlock& block, + uint32_t instIndex, + StoreRegInfo& regInfo +) +{ + // If the tag+value pair is established, we can mark both as dead and use a single split TValue store + if (regInfo.tagInstIdx != ~0u && regInfo.valueInstIdx != ~0u) + { + IrOp prevTagOp = function.instructions[regInfo.tagInstIdx].b; + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + + IrInst& storeInst = function.instructions[instIndex]; + CODEGEN_ASSERT(storeInst.cmd == IrCmd::STORE_VECTOR); + replace(function, storeInst.e, prevTagOp); + + state.killTagStore(regInfo); + state.killValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + + // We can also replace a dead split TValue store with a new one, while keeping the value the same + if (regInfo.tvalueInstIdx != ~0u) + { + IrInst& prev = function.instructions[regInfo.tvalueInstIdx]; + + if (prev.cmd == IrCmd::STORE_SPLIT_TVALUE) + { + IrOp prevTagOp = prev.b; + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + CODEGEN_ASSERT(prev.d.kind == IrOpKind::None); + + IrInst& storeInst = function.instructions[instIndex]; + CODEGEN_ASSERT(storeInst.cmd == IrCmd::STORE_VECTOR); + replace(function, storeInst.e, prevTagOp); + + state.killTValueStore(regInfo); + + regInfo.tvalueInstIdx = instIndex; + return true; + } + else if (prev.cmd == IrCmd::STORE_VECTOR) + { + IrOp prevTagOp = prev.e; + CODEGEN_ASSERT(prevTagOp.kind != IrOpKind::None); + uint8_t prevTag = function.tagOp(prevTagOp); + + CODEGEN_ASSERT(regInfo.knownTag == prevTag); + + IrInst& storeInst = function.instructions[instIndex]; + CODEGEN_ASSERT(storeInst.cmd == IrCmd::STORE_VECTOR); + replace(function, storeInst.e, prevTagOp); + + state.killTValueStore(regInfo); + regInfo.tvalueInstIdx = instIndex; return true; } @@ -499,10 +617,24 @@ static void markDeadStoresInInst(RemoveDeadStoreState& state, IrBuilder& build, } break; case IrCmd::STORE_VECTOR: - // Partial vector value store cannot be combined into a STORE_SPLIT_TVALUE, so we skip dead store optimization for it if (inst.a.kind == IrOpKind::VmReg) { - state.useReg(vmRegOp(inst.a)); + int reg = vmRegOp(inst.a); + + if (function.cfg.captured.regs.test(reg)) + return; + + StoreRegInfo& regInfo = state.info[reg]; + + if (tryReplaceVectorValueWithFullStore(state, build, function, block, index, regInfo)) + break; + + // Partial value store can be removed by a new one if the tag is known + if (regInfo.knownTag != kUnknownTag) + state.killValueStore(regInfo); + + regInfo.valueInstIdx = index; + regInfo.maybeGco = false; } break; case IrCmd::STORE_TVALUE: diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 82185e7f..a151056c 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -600,6 +600,22 @@ enum LuauBuiltinFunction LBF_BUFFER_WRITEF32, LBF_BUFFER_READF64, LBF_BUFFER_WRITEF64, + + // vector. + LBF_VECTOR_MAGNITUDE, + LBF_VECTOR_NORMALIZE, + LBF_VECTOR_CROSS, + LBF_VECTOR_DOT, + LBF_VECTOR_FLOOR, + LBF_VECTOR_CEIL, + LBF_VECTOR_ABS, + LBF_VECTOR_SIGN, + LBF_VECTOR_CLAMP, + LBF_VECTOR_MIN, + LBF_VECTOR_MAX, + + // math.lerp + LBF_MATH_LERP, }; // Capture type, used in LOP_CAPTURE diff --git a/Common/include/Luau/Common.h b/Common/include/Luau/Common.h index 2f4f1df8..b4bbf0f7 100644 --- a/Common/include/Luau/Common.h +++ b/Common/include/Luau/Common.h @@ -106,10 +106,10 @@ FValue* FValue::list = nullptr; { \ extern Luau::FValue flag; \ } -#define LUAU_FASTFLAGVARIABLE(flag, def) \ +#define LUAU_FASTFLAGVARIABLE(flag) \ namespace FFlag \ { \ - Luau::FValue flag(#flag, def, false); \ + Luau::FValue flag(#flag, false, false); \ } #define LUAU_FASTINT(flag) \ namespace FInt \ diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index c534bcb4..68ae1e8c 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,7 +13,8 @@ inline bool isFlagExperimental(const char* flag) static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative - "StudioReportLuauAny2", // takes telemetry data for usage of any types + "StudioReportLuauAny2", // takes telemetry data for usage of any types + "LuauTableCloneClonesType3", // requires fixes in lua-apps code, terrifyingly "LuauSolverV2", // makes sure we always have at least one entry nullptr, diff --git a/Common/include/Luau/Variant.h b/Common/include/Luau/Variant.h index 88722257..14eb8c4e 100644 --- a/Common/include/Luau/Variant.h +++ b/Common/include/Luau/Variant.h @@ -19,7 +19,7 @@ class Variant static_assert(std::disjunction_v...> == false, "variant does not allow references as an alternative type"); static_assert(std::disjunction_v...> == false, "variant does not allow arrays as an alternative type"); -private: +public: template static constexpr int getTypeId() { @@ -35,6 +35,7 @@ private: return -1; } +private: template struct First { diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 403fa6dd..2c82116d 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,6 +13,16 @@ struct ParseResult; class BytecodeBuilder; class BytecodeEncoder; +using CompileConstant = void*; + +// return a type identifier for a global library member +// values are defined by 'enum LuauBytecodeType' in Bytecode.h +using LibraryMemberTypeCallback = int (*)(const char* library, const char* member); + +// setup a value of a constant for a global library member +// use setCompileConstant*** set of functions for values +using LibraryMemberConstantCallback = void (*)(const char* library, const char* member, CompileConstant* constant); + // Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { @@ -37,11 +47,11 @@ struct CompileOptions // 2 - statement and expression coverage (verbose) int coverageLevel = 0; - // global builtin to construct vectors; disabled by default + // alternative global builtin to construct vectors, in addition to default builtin 'vector.create' const char* vectorLib = nullptr; const char* vectorCtor = nullptr; - // vector type name for type tables; disabled by default + // alternative vector type name for type tables, in addition to default type 'vector' const char* vectorType = nullptr; // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these @@ -49,6 +59,15 @@ struct CompileOptions // null-terminated array of userdata types that will be included in the type information const char* const* userdataTypes = nullptr; + + // null-terminated array of globals which act as libraries and have members with known type and/or constant value + // when an import of one of these libraries is accessed, callbacks below will be called to receive that information + const char* const* librariesWithKnownMembers = nullptr; + LibraryMemberTypeCallback libraryMemberTypeCb = nullptr; + LibraryMemberConstantCallback libraryMemberConstantCb = nullptr; + + // null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name") + const char* const* disabledBuiltins = nullptr; }; class CompileError : public std::exception @@ -81,4 +100,10 @@ std::string compile( BytecodeEncoder* encoder = nullptr ); +void setCompileConstantNil(CompileConstant* constant); +void setCompileConstantBoolean(CompileConstant* constant, bool b); +void setCompileConstantNumber(CompileConstant* constant, double n); +void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w); +void setCompileConstantString(CompileConstant* constant, const char* s, size_t l); + } // namespace Luau diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index 1440a699..4445af43 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -3,12 +3,21 @@ #include -// Can be used to reconfigure visibility/exports for public APIs +// can be used to reconfigure visibility/exports for public APIs #ifndef LUACODE_API #define LUACODE_API extern #endif typedef struct lua_CompileOptions lua_CompileOptions; +typedef void* lua_CompileConstant; + +// return a type identifier for a global library member +// values are defined by 'enum LuauBytecodeType' in Bytecode.h +typedef int (*lua_LibraryMemberTypeCallback)(const char* library, const char* member); + +// setup a value of a constant for a global library member +// use luau_set_compile_constant_*** set of functions for values +typedef void (*lua_LibraryMemberConstantCallback)(const char* library, const char* member, lua_CompileConstant* constant); struct lua_CompileOptions { @@ -33,11 +42,11 @@ struct lua_CompileOptions // 2 - statement and expression coverage (verbose) int coverageLevel; // default=0 - // global builtin to construct vectors; disabled by default + // alternative global builtin to construct vectors, in addition to default builtin 'vector.create' const char* vectorLib; const char* vectorCtor; - // vector type name for type tables; disabled by default + // alternative vector type name for type tables, in addition to default type 'vector' const char* vectorType; // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these @@ -45,7 +54,25 @@ struct lua_CompileOptions // null-terminated array of userdata types that will be included in the type information const char* const* userdataTypes; + + // null-terminated array of globals which act as libraries and have members with known type and/or constant value + // when an import of one of these libraries is accessed, callbacks below will be called to receive that information + const char* const* librariesWithKnownMembers; + lua_LibraryMemberTypeCallback libraryMemberTypeCb; + lua_LibraryMemberConstantCallback libraryMemberConstantCb; + + // null-terminated array of library functions that should not be compiled into a built-in fastcall ("name" "lib.name") + const char* const* disabledBuiltins; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); + +// when libraryMemberConstantCb is called, these methods can be used to set a value of the opaque lua_CompileConstant struct +// vector component 'w' is not visible to VM runtime configured with LUA_VECTOR_SIZE == 3, but can affect constant folding during compilation +// string storage must outlive the invocation of 'luau_compile' which used the callback +LUACODE_API void luau_set_compile_constant_nil(lua_CompileConstant* constant); +LUACODE_API void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b); +LUACODE_API void luau_set_compile_constant_number(lua_CompileConstant* constant, double n); +LUACODE_API void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w); +LUACODE_API void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l); diff --git a/Compiler/src/BuiltinFolding.cpp b/Compiler/src/BuiltinFolding.cpp index 0886e94a..28b812f5 100644 --- a/Compiler/src/BuiltinFolding.cpp +++ b/Compiler/src/BuiltinFolding.cpp @@ -471,14 +471,28 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count) break; case LBF_VECTOR: - if (count >= 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + if (count >= 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number) { - if (count == 3) + if (count == 2) + return cvector(args[0].valueNumber, args[1].valueNumber, 0.0, 0.0); + else if (count == 3 && args[2].type == Constant::Type_Number) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, 0.0); - else if (count == 4 && args[3].type == Constant::Type_Number) + else if (count == 4 && args[2].type == Constant::Type_Number && args[3].type == Constant::Type_Number) return cvector(args[0].valueNumber, args[1].valueNumber, args[2].valueNumber, args[3].valueNumber); } break; + + case LBF_MATH_LERP: + if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) + { + double a = args[0].valueNumber; + double b = args[1].valueNumber; + double t = args[2].valueNumber; + + double v = (t == 1.0) ? b : a + (b - a) * t; + return cnum(v); + } + break; } return cvar(); diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 90bf72c4..bc342bd3 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -3,6 +3,9 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +#include "Luau/Lexer.h" + +#include namespace Luau { @@ -134,6 +137,8 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_MATH_SIGN; if (builtin.method == "round") return LBF_MATH_ROUND; + if (builtin.method == "lerp") + return LBF_MATH_LERP; } if (builtin.object == "bit32") @@ -220,6 +225,34 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op return LBF_BUFFER_WRITEF64; } + if (builtin.object == "vector") + { + if (builtin.method == "create") + return LBF_VECTOR; + if (builtin.method == "magnitude") + return LBF_VECTOR_MAGNITUDE; + if (builtin.method == "normalize") + return LBF_VECTOR_NORMALIZE; + if (builtin.method == "cross") + return LBF_VECTOR_CROSS; + if (builtin.method == "dot") + return LBF_VECTOR_DOT; + if (builtin.method == "floor") + return LBF_VECTOR_FLOOR; + if (builtin.method == "ceil") + return LBF_VECTOR_CEIL; + if (builtin.method == "abs") + return LBF_VECTOR_ABS; + if (builtin.method == "sign") + return LBF_VECTOR_SIGN; + if (builtin.method == "clamp") + return LBF_VECTOR_CLAMP; + if (builtin.method == "min") + return LBF_VECTOR_MIN; + if (builtin.method == "max") + return LBF_VECTOR_MAX; + } + if (options.vectorCtor) { if (options.vectorLib) @@ -240,23 +273,58 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op struct BuiltinVisitor : AstVisitor { DenseHashMap& result; + std::array builtinIsDisabled; const DenseHashMap& globals; const DenseHashMap& variables; const CompileOptions& options; + const AstNameTable& names; BuiltinVisitor( DenseHashMap& result, const DenseHashMap& globals, const DenseHashMap& variables, - const CompileOptions& options + const CompileOptions& options, + const AstNameTable& names ) : result(result) , globals(globals) , variables(variables) , options(options) + , names(names) { + builtinIsDisabled.fill(false); + + if (const char* const* ptr = options.disabledBuiltins) + { + for (; *ptr; ++ptr) + { + if (const char* dot = strchr(*ptr, '.')) + { + AstName library = names.getWithType(*ptr, dot - *ptr).first; + AstName name = names.get(dot + 1); + + if (library.value && name.value && getGlobalState(globals, name) == Global::Default) + { + Builtin builtin = Builtin{library, name}; + + if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0) + builtinIsDisabled[bfid] = true; + } + } + else + { + if (AstName name = names.get(*ptr); name.value && getGlobalState(globals, name) == Global::Default) + { + Builtin builtin = Builtin{AstName(), name}; + + if (int bfid = getBuiltinFunctionId(builtin, options); bfid >= 0) + builtinIsDisabled[bfid] = true; + } + } + } + } } bool visit(AstExprCall* node) override @@ -267,6 +335,9 @@ struct BuiltinVisitor : AstVisitor int bfid = getBuiltinFunctionId(builtin, options); + if (bfid >= 0 && builtinIsDisabled[bfid]) + bfid = -1; + // getBuiltinFunctionId optimistically assumes all select() calls are builtin but actually the second argument must be a vararg if (bfid == LBF_SELECT_VARARG && !(node->args.size == 2 && node->args.data[1]->is())) bfid = -1; @@ -283,10 +354,11 @@ void analyzeBuiltins( const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, - AstNode* root + AstNode* root, + const AstNameTable& names ) { - BuiltinVisitor visitor{result, globals, variables, options}; + BuiltinVisitor visitor{result, globals, variables, options, names}; root->visit(&visitor); } @@ -463,6 +535,26 @@ BuiltinInfo getBuiltinInfo(int bfid) case LBF_BUFFER_WRITEF32: case LBF_BUFFER_WRITEF64: return {3, 0, BuiltinInfo::Flag_NoneSafe}; + + case LBF_VECTOR_MAGNITUDE: + case LBF_VECTOR_NORMALIZE: + return {1, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_VECTOR_CROSS: + case LBF_VECTOR_DOT: + return {2, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_VECTOR_FLOOR: + case LBF_VECTOR_CEIL: + case LBF_VECTOR_ABS: + case LBF_VECTOR_SIGN: + return {1, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_VECTOR_CLAMP: + return {3, 1, BuiltinInfo::Flag_NoneSafe}; + case LBF_VECTOR_MIN: + case LBF_VECTOR_MAX: + return {-1, 1}; // variadic + + case LBF_MATH_LERP: + return {3, 1, BuiltinInfo::Flag_NoneSafe}; } LUAU_UNREACHABLE(); diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h index e6427c2a..cef48fa5 100644 --- a/Compiler/src/Builtins.h +++ b/Compiler/src/Builtins.h @@ -41,7 +41,8 @@ void analyzeBuiltins( const DenseHashMap& globals, const DenseHashMap& variables, const CompileOptions& options, - AstNode* root + AstNode* root, + const AstNameTable& names ); struct BuiltinInfo diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 685d94fa..46985628 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1751,7 +1751,8 @@ void BytecodeBuilder::validateVariadic() const // variadic sequence since they are never executed if FASTCALL does anything, so it's okay to skip their validation until CALL // (we can't simply start a variadic sequence here because that would trigger assertions during linked CALL validation) } - else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL || op == LOP_GETTABLEKS || op == LOP_COVERAGE) + else if (op == LOP_CLOSEUPVALS || op == LOP_NAMECALL || op == LOP_GETIMPORT || op == LOP_MOVE || op == LOP_GETUPVAL || op == LOP_GETGLOBAL || + op == LOP_GETTABLEKS || op == LOP_COVERAGE) { // instructions inside a variadic sequence must be neutral (can't change L->top) // while there are many neutral instructions like this, here we check that the instruction is one of the few diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7ed70d14..5ef2bf93 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,8 +26,6 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAG(LuauNativeAttribute) - namespace Luau { @@ -285,7 +283,7 @@ struct Compiler if (func->functionDepth == 0 && !hasLoops) protoflags |= LPF_NATIVE_COLD; - if (FFlag::LuauNativeAttribute && func->hasNativeAttribute()) + if (func->hasNativeAttribute()) protoflags |= LPF_NATIVE_FUNCTION; bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size()), protoflags); @@ -725,7 +723,7 @@ struct Compiler inlineFrames.push_back({func, oldLocals, target, targetCount}); // fold constant values updated above into expressions in the function body - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body); bool usedFallthrough = false; @@ -770,7 +768,7 @@ struct Compiler var->type = Constant::Type_Unknown; } - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, func->body); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, func->body); } void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) @@ -1623,6 +1621,24 @@ struct Compiler return; } } + else if (options.optimizationLevel >= 2 && (expr->op == AstExprBinary::Add || expr->op == AstExprBinary::Mul)) + { + // Optimization: replace k*r with r*k when r is known to be a number (otherwise metamethods may be called) + if (LuauBytecodeType* ty = exprTypes.find(expr); ty && *ty == LBC_TYPE_NUMBER) + { + int32_t lc = getConstantNumber(expr->left); + + if (lc >= 0 && lc <= 255) + { + uint8_t rr = compileExprAuto(expr->right, rs); + + bytecode.emitABC(getBinaryOpArith(expr->op, /* k= */ true), target, rr, uint8_t(lc)); + + hintTemporaryExprRegType(expr->right, rr, LBC_TYPE_NUMBER, /* instLength */ 1); + return; + } + } + } uint8_t rl = compileExprAuto(expr->left, rs); uint8_t rr = compileExprAuto(expr->right, rs); @@ -3034,7 +3050,7 @@ struct Compiler locstants[var].type = Constant::Type_Number; locstants[var].valueNumber = from + iv * step; - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat); size_t iterJumps = loopJumps.size(); @@ -3062,7 +3078,7 @@ struct Compiler // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again locstants[var].type = Constant::Type_Unknown; - foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldMathK, stat); + foldConstants(constants, variables, locstants, builtinsFold, builtinsFoldLibraryK, options.libraryMemberConstantCb, stat); } void compileStatFor(AstStatFor* stat) @@ -3634,6 +3650,10 @@ struct Compiler { // do nothing } + else if (node->is()) + { + // do nothing + } else { LUAU_ASSERT(!"Unknown statement type"); @@ -3904,7 +3924,7 @@ struct Compiler // this makes sure all functions that are used when compiling this one have been already added to the vector functions.push_back(node); - if (FFlag::LuauNativeAttribute && !hasNativeFunction && node->hasNativeAttribute()) + if (!hasNativeFunction && node->hasNativeAttribute()) hasNativeFunction = true; return false; @@ -4119,7 +4139,7 @@ struct Compiler BuiltinAstTypes builtinTypes; const DenseHashMap* builtinsFold = nullptr; - bool builtinsFoldMathK = false; + bool builtinsFoldLibraryK = false; // compileFunction state, gets reset for every function unsigned int regTop = 0; @@ -4199,16 +4219,37 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.builtinsFold = &compiler.builtins; if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) - compiler.builtinsFoldMathK = true; + { + compiler.builtinsFoldLibraryK = true; + } + else if (const char* const* ptr = options.librariesWithKnownMembers) + { + for (; *ptr; ++ptr) + { + if (AstName name = names.get(*ptr); name.value && getGlobalState(compiler.globals, name) == Global::Default) + { + compiler.builtinsFoldLibraryK = true; + break; + } + } + } } if (options.optimizationLevel >= 1) { // this pass tracks which calls are builtins and can be compiled more efficiently - analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root); + analyzeBuiltins(compiler.builtins, compiler.globals, compiler.variables, options, root, names); // this pass analyzes constantness of expressions - foldConstants(compiler.constants, compiler.variables, compiler.locstants, compiler.builtinsFold, compiler.builtinsFoldMathK, root); + foldConstants( + compiler.constants, + compiler.variables, + compiler.locstants, + compiler.builtinsFold, + compiler.builtinsFoldLibraryK, + options.libraryMemberConstantCb, + root + ); // this pass analyzes table assignments to estimate table shapes for initially empty tables predictTableShapes(compiler.tableShapes, root); @@ -4239,6 +4280,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.builtinTypes, compiler.builtins, compiler.globals, + options.libraryMemberTypeCb, bytecode ); @@ -4249,15 +4291,15 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c // If a function has native attribute and the whole module is not native, we set LPF_NATIVE_FUNCTION flag // This ensures that LPF_NATIVE_MODULE and LPF_NATIVE_FUNCTION are exclusive. - if (FFlag::LuauNativeAttribute && (protoflags & LPF_NATIVE_FUNCTION) && !(mainFlags & LPF_NATIVE_MODULE)) + if ((protoflags & LPF_NATIVE_FUNCTION) && !(mainFlags & LPF_NATIVE_MODULE)) mainFlags |= LPF_NATIVE_FUNCTION; } AstExprFunction main( root->location, - /*attributes=*/AstArray({nullptr, 0}), - /*generics= */ AstArray(), - /*genericPacks= */ AstArray(), + /* attributes= */ AstArray({nullptr, 0}), + /* generics= */ AstArray(), + /* genericPacks= */ AstArray(), /* self= */ nullptr, AstArray(), /* vararg= */ true, @@ -4318,4 +4360,50 @@ std::string compile(const std::string& source, const CompileOptions& options, co } } +void setCompileConstantNil(CompileConstant* constant) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Nil; +} + +void setCompileConstantBoolean(CompileConstant* constant, bool b) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Boolean; + target->valueBoolean = b; +} + +void setCompileConstantNumber(CompileConstant* constant, double n) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Number; + target->valueNumber = n; +} + +void setCompileConstantVector(CompileConstant* constant, float x, float y, float z, float w) +{ + Compile::Constant* target = reinterpret_cast(constant); + + target->type = Compile::Constant::Type_Vector; + target->valueVector[0] = x; + target->valueVector[1] = y; + target->valueVector[2] = z; + target->valueVector[3] = w; +} + +void setCompileConstantString(CompileConstant* constant, const char* s, size_t l) +{ + Compile::Constant* target = reinterpret_cast(constant); + + if (l > std::numeric_limits::max()) + CompileError::raise({}, "Exceeded custom string constant length limit"); + + target->type = Compile::Constant::Type_String; + target->stringLength = unsigned(l); + target->valueString = s; +} + } // namespace Luau diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 2895bf08..818a5bf7 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -57,6 +57,14 @@ static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg result.type = Constant::Type_Number; result.valueNumber = -arg.valueNumber; } + else if (arg.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = -arg.valueVector[0]; + result.valueVector[1] = -arg.valueVector[1]; + result.valueVector[2] = -arg.valueVector[2]; + result.valueVector[3] = -arg.valueVector[3]; + } break; case AstExprUnary::Len: @@ -82,6 +90,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber + ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] + ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] + ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] + ra.valueVector[2]; + result.valueVector[3] = la.valueVector[3] + ra.valueVector[3]; + } break; case AstExprBinary::Sub: @@ -90,6 +106,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber - ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] - ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] - ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] - ra.valueVector[2]; + result.valueVector[3] = la.valueVector[3] - ra.valueVector[3]; + } break; case AstExprBinary::Mul: @@ -98,6 +122,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber * ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] * ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] * ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] * ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] * ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = float(la.valueNumber) * ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = float(la.valueNumber) * ra.valueVector[0]; + result.valueVector[1] = float(la.valueNumber) * ra.valueVector[1]; + result.valueVector[2] = float(la.valueNumber) * ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] * float(ra.valueNumber); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] * float(ra.valueNumber); + result.valueVector[1] = la.valueVector[1] * float(ra.valueNumber); + result.valueVector[2] = la.valueVector[2] * float(ra.valueNumber); + result.valueVector[3] = resultW; + } + } break; case AstExprBinary::Div: @@ -106,6 +172,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = la.valueNumber / ra.valueNumber; } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] / ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] / ra.valueVector[0]; + result.valueVector[1] = la.valueVector[1] / ra.valueVector[1]; + result.valueVector[2] = la.valueVector[2] / ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = float(la.valueNumber) / ra.valueVector[3]; + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = float(la.valueNumber) / ra.valueVector[0]; + result.valueVector[1] = float(la.valueNumber) / ra.valueVector[1]; + result.valueVector[2] = float(la.valueNumber) / ra.valueVector[2]; + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = la.valueVector[3] / float(ra.valueNumber); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = la.valueVector[0] / float(ra.valueNumber); + result.valueVector[1] = la.valueVector[1] / float(ra.valueNumber); + result.valueVector[2] = la.valueVector[2] / float(ra.valueNumber); + result.valueVector[3] = resultW; + } + } break; case AstExprBinary::FloorDiv: @@ -114,6 +222,48 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l result.type = Constant::Type_Number; result.valueNumber = floor(la.valueNumber / ra.valueNumber); } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Vector) + { + bool hadW = la.valueVector[3] != 0.0f || ra.valueVector[3] != 0.0f; + float resultW = floor(la.valueVector[3] / ra.valueVector[3]); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(la.valueVector[0] / ra.valueVector[0]); + result.valueVector[1] = floor(la.valueVector[1] / ra.valueVector[1]); + result.valueVector[2] = floor(la.valueVector[2] / ra.valueVector[2]); + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Number && ra.type == Constant::Type_Vector) + { + bool hadW = ra.valueVector[3] != 0.0f; + float resultW = floor(float(la.valueNumber) / ra.valueVector[3]); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(float(la.valueNumber) / ra.valueVector[0]); + result.valueVector[1] = floor(float(la.valueNumber) / ra.valueVector[1]); + result.valueVector[2] = floor(float(la.valueNumber) / ra.valueVector[2]); + result.valueVector[3] = resultW; + } + } + else if (la.type == Constant::Type_Vector && ra.type == Constant::Type_Number) + { + bool hadW = la.valueVector[3] != 0.0f; + float resultW = floor(la.valueVector[3] / float(ra.valueNumber)); + + if (resultW == 0.0f || hadW) + { + result.type = Constant::Type_Vector; + result.valueVector[0] = floor(la.valueVector[0] / float(ra.valueNumber)); + result.valueVector[1] = floor(la.valueVector[1] / float(ra.valueNumber)); + result.valueVector[2] = floor(la.valueVector[2] / float(ra.valueNumber)); + result.valueVector[3] = floor(la.valueVector[3] / float(ra.valueNumber)); + } + } break; case AstExprBinary::Mod: @@ -209,7 +359,8 @@ struct ConstantVisitor : AstVisitor DenseHashMap& locals; const DenseHashMap* builtins; - bool foldMathK = false; + bool foldLibraryK = false; + LibraryMemberConstantCallback libraryMemberConstantCb; bool wasEmpty = false; @@ -220,13 +371,15 @@ struct ConstantVisitor : AstVisitor DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb ) : constants(constants) , variables(variables) , locals(locals) , builtins(builtins) - , foldMathK(foldMathK) + , foldLibraryK(foldLibraryK) + , libraryMemberConstantCb(libraryMemberConstantCb) { // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries wasEmpty = constants.empty() && locals.empty(); @@ -316,11 +469,16 @@ struct ConstantVisitor : AstVisitor { analyze(expr->expr); - if (foldMathK) + if (foldLibraryK) { - if (AstExprGlobal* eg = expr->expr->as(); eg && eg->name == "math") + if (AstExprGlobal* eg = expr->expr->as()) { - result = foldBuiltinMath(expr->index); + if (eg->name == "math") + result = foldBuiltinMath(expr->index); + + // if we have a custom handler and the constant hasn't been resolved + if (libraryMemberConstantCb && result.type == Constant::Type_Unknown) + libraryMemberConstantCb(eg->name.value, expr->index.value, reinterpret_cast(&result)); } } } @@ -468,11 +626,12 @@ void foldConstants( DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK, + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb, AstNode* root ) { - ConstantVisitor visitor{constants, variables, locals, builtins, foldMathK}; + ConstantVisitor visitor{constants, variables, locals, builtins, foldLibraryK, libraryMemberConstantCb}; root->visit(&visitor); } diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h index e4eb6428..2653c064 100644 --- a/Compiler/src/ConstantFolding.h +++ b/Compiler/src/ConstantFolding.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Compiler.h" + #include "ValueTracking.h" namespace Luau @@ -49,7 +51,8 @@ void foldConstants( DenseHashMap& variables, DenseHashMap& locals, const DenseHashMap* builtins, - bool foldMathK, + bool foldLibraryK, + LibraryMemberConstantCallback libraryMemberConstantCb, AstNode* root ); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 4c8e13c6..04adf3e3 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -130,7 +130,8 @@ struct CostVisitor : AstVisitor { return model(expr->expr); } - else if (node->is() || node->is() || node->is() || node->is()) + else if (node->is() || node->is() || node->is() || + node->is()) { return Cost(0, Cost::kLiteral); } diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index 18dc248f..34a27f4f 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -6,10 +6,10 @@ namespace Luau { -static bool isGeneric(AstName name, const AstArray& generics) +static bool isGeneric(AstName name, const AstArray& generics) { - for (const AstGenericType& gt : generics) - if (gt.name == name) + for (const AstGenericType* gt : generics) + if (gt->name == name) return true; return false; @@ -29,6 +29,8 @@ static LuauBytecodeType getPrimitiveType(AstName name) return LBC_TYPE_THREAD; else if (name == "buffer") return LBC_TYPE_BUFFER; + else if (name == "vector") + return LBC_TYPE_VECTOR; else if (name == "any" || name == "unknown") return LBC_TYPE_ANY; else @@ -37,10 +39,10 @@ static LuauBytecodeType getPrimitiveType(AstName name) static LuauBytecodeType getType( const AstType* ty, - const AstArray& generics, + const AstArray& generics, const DenseHashMap& typeAliases, bool resolveAliases, - const char* vectorType, + const char* hostVectorType, const DenseHashMap& userdataTypes, BytecodeBuilder& bytecode ) @@ -54,7 +56,7 @@ static LuauBytecodeType getType( { // note: we only resolve aliases to the depth of 1 to avoid dealing with recursive aliases if (resolveAliases) - return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false, vectorType, userdataTypes, bytecode); + return getType((*alias)->type, (*alias)->generics, typeAliases, /* resolveAliases= */ false, hostVectorType, userdataTypes, bytecode); else return LBC_TYPE_ANY; } @@ -62,7 +64,7 @@ static LuauBytecodeType getType( if (isGeneric(ref->name, generics)) return LBC_TYPE_ANY; - if (vectorType && ref->name == vectorType) + if (hostVectorType && ref->name == hostVectorType) return LBC_TYPE_VECTOR; if (LuauBytecodeType prim = getPrimitiveType(ref->name); prim != LBC_TYPE_INVALID) @@ -92,7 +94,7 @@ static LuauBytecodeType getType( for (AstType* ty : un->types) { - LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, vectorType, userdataTypes, bytecode); + LuauBytecodeType et = getType(ty, generics, typeAliases, resolveAliases, hostVectorType, userdataTypes, bytecode); if (et == LBC_TYPE_NIL) { @@ -119,6 +121,10 @@ static LuauBytecodeType getType( { return LBC_TYPE_ANY; } + else if (const AstTypeGroup* group = ty->as()) + { + return getType(group->type, generics, typeAliases, resolveAliases, hostVectorType, userdataTypes, bytecode); + } return LBC_TYPE_ANY; } @@ -126,7 +132,7 @@ static LuauBytecodeType getType( static std::string getFunctionType( const AstExprFunction* func, const DenseHashMap& typeAliases, - const char* vectorType, + const char* hostVectorType, const DenseHashMap& userdataTypes, BytecodeBuilder& bytecode ) @@ -146,8 +152,9 @@ static std::string getFunctionType( for (AstLocal* arg : func->args) { LuauBytecodeType ty = - arg->annotation ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode) - : LBC_TYPE_ANY; + arg->annotation + ? getType(arg->annotation, func->generics, typeAliases, /* resolveAliases= */ true, hostVectorType, userdataTypes, bytecode) + : LBC_TYPE_ANY; if (ty != LBC_TYPE_ANY) haveNonAnyParam = true; @@ -170,16 +177,30 @@ static bool isMatchingGlobal(const DenseHashMap& globa return false; } +static bool isMatchingGlobalMember( + const DenseHashMap& globals, + AstExprIndexName* expr, + const char* library, + const char* member +) +{ + if (AstExprGlobal* object = expr->expr->as()) + return getGlobalState(globals, object->name) == Compile::Global::Default && object->name == library && expr->index == member; + + return false; +} + struct TypeMapVisitor : AstVisitor { DenseHashMap& functionTypes; DenseHashMap& localTypes; DenseHashMap& exprTypes; - const char* vectorType; + const char* hostVectorType = nullptr; const DenseHashMap& userdataTypes; const BuiltinAstTypes& builtinTypes; const DenseHashMap& builtinCalls; const DenseHashMap& globals; + LibraryMemberTypeCallback libraryMemberTypeCb = nullptr; BytecodeBuilder& bytecode; DenseHashMap typeAliases; @@ -191,21 +212,23 @@ struct TypeMapVisitor : AstVisitor DenseHashMap& functionTypes, DenseHashMap& localTypes, DenseHashMap& exprTypes, - const char* vectorType, + const char* hostVectorType, const DenseHashMap& userdataTypes, const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ) : functionTypes(functionTypes) , localTypes(localTypes) , exprTypes(exprTypes) - , vectorType(vectorType) + , hostVectorType(hostVectorType) , userdataTypes(userdataTypes) , builtinTypes(builtinTypes) , builtinCalls(builtinCalls) , globals(globals) + , libraryMemberTypeCb(libraryMemberTypeCb) , bytecode(bytecode) , typeAliases(AstName()) , resolvedLocals(nullptr) @@ -271,7 +294,7 @@ struct TypeMapVisitor : AstVisitor resolvedExprs[expr] = ty; - LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, hostVectorType, userdataTypes, bytecode); exprTypes[expr] = bty; return bty; } @@ -282,7 +305,7 @@ struct TypeMapVisitor : AstVisitor resolvedLocals[local] = ty; - LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, vectorType, userdataTypes, bytecode); + LuauBytecodeType bty = getType(ty, {}, typeAliases, /* resolveAliases= */ true, hostVectorType, userdataTypes, bytecode); if (bty != LBC_TYPE_ANY) localTypes[local] = bty; @@ -370,7 +393,7 @@ struct TypeMapVisitor : AstVisitor bool visit(AstExprFunction* node) override { - std::string type = getFunctionType(node, typeAliases, vectorType, userdataTypes, bytecode); + std::string type = getFunctionType(node, typeAliases, hostVectorType, userdataTypes, bytecode); if (!type.empty()) functionTypes[node] = std::move(type); @@ -456,7 +479,48 @@ struct TypeMapVisitor : AstVisitor if (*typeBcPtr == LBC_TYPE_VECTOR) { if (node->index == "X" || node->index == "Y" || node->index == "Z") + { recordResolvedType(node, &builtinTypes.numberType); + return false; + } + } + } + + if (isMatchingGlobalMember(globals, node, "vector", "zero") || isMatchingGlobalMember(globals, node, "vector", "one")) + { + recordResolvedType(node, &builtinTypes.vectorType); + return false; + } + + if (libraryMemberTypeCb) + { + if (AstExprGlobal* object = node->expr->as()) + { + if (LuauBytecodeType ty = LuauBytecodeType(libraryMemberTypeCb(object->name.value, node->index.value)); ty != LBC_TYPE_ANY) + { + // TODO: 'resolvedExprs' is more limited than 'exprTypes' which limits full inference of more complex types that a user + // callback can return + switch (ty) + { + case LBC_TYPE_BOOLEAN: + resolvedExprs[node] = &builtinTypes.booleanType; + break; + case LBC_TYPE_NUMBER: + resolvedExprs[node] = &builtinTypes.numberType; + break; + case LBC_TYPE_STRING: + resolvedExprs[node] = &builtinTypes.stringType; + break; + case LBC_TYPE_VECTOR: + resolvedExprs[node] = &builtinTypes.vectorType; + break; + default: + break; + } + + exprTypes[node] = ty; + return false; + } } } @@ -675,6 +739,9 @@ struct TypeMapVisitor : AstVisitor case LBF_BUFFER_READU32: case LBF_BUFFER_READF32: case LBF_BUFFER_READF64: + case LBF_VECTOR_MAGNITUDE: + case LBF_VECTOR_DOT: + case LBF_MATH_LERP: recordResolvedType(node, &builtinTypes.numberType); break; @@ -691,6 +758,15 @@ struct TypeMapVisitor : AstVisitor break; case LBF_VECTOR: + case LBF_VECTOR_NORMALIZE: + case LBF_VECTOR_CROSS: + case LBF_VECTOR_FLOOR: + case LBF_VECTOR_CEIL: + case LBF_VECTOR_ABS: + case LBF_VECTOR_SIGN: + case LBF_VECTOR_CLAMP: + case LBF_VECTOR_MIN: + case LBF_VECTOR_MAX: recordResolvedType(node, &builtinTypes.vectorType); break; } @@ -712,15 +788,18 @@ void buildTypeMap( DenseHashMap& localTypes, DenseHashMap& exprTypes, AstNode* root, - const char* vectorType, + const char* hostVectorType, const DenseHashMap& userdataTypes, const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ) { - TypeMapVisitor visitor(functionTypes, localTypes, exprTypes, vectorType, userdataTypes, builtinTypes, builtinCalls, globals, bytecode); + TypeMapVisitor visitor( + functionTypes, localTypes, exprTypes, hostVectorType, userdataTypes, builtinTypes, builtinCalls, globals, libraryMemberTypeCb, bytecode + ); root->visit(&visitor); } diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index a310bfcc..e60b3b93 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -3,6 +3,7 @@ #include "Luau/Ast.h" #include "Luau/Bytecode.h" +#include "Luau/Compiler.h" #include "Luau/DenseHash.h" #include "ValueTracking.h" @@ -14,16 +15,18 @@ class BytecodeBuilder; struct BuiltinAstTypes { - BuiltinAstTypes(const char* vectorType) - : vectorType{{}, std::nullopt, AstName{vectorType}, std::nullopt, {}} + BuiltinAstTypes(const char* hostVectorType) + : hostVectorType{{}, std::nullopt, AstName{hostVectorType}, std::nullopt, {}} { } - // AstName use here will not match the AstNameTable, but the was we use them here always force a full string compare + // AstName use here will not match the AstNameTable, but the way we use them here always forces a full string compare AstTypeReference booleanType{{}, std::nullopt, AstName{"boolean"}, std::nullopt, {}}; AstTypeReference numberType{{}, std::nullopt, AstName{"number"}, std::nullopt, {}}; AstTypeReference stringType{{}, std::nullopt, AstName{"string"}, std::nullopt, {}}; - AstTypeReference vectorType; + AstTypeReference vectorType{{}, std::nullopt, AstName{"vector"}, std::nullopt, {}}; + + AstTypeReference hostVectorType; }; void buildTypeMap( @@ -31,11 +34,12 @@ void buildTypeMap( DenseHashMap& localTypes, DenseHashMap& exprTypes, AstNode* root, - const char* vectorType, + const char* hostVectorType, const DenseHashMap& userdataTypes, const BuiltinAstTypes& builtinTypes, const DenseHashMap& builtinCalls, const DenseHashMap& globals, + LibraryMemberTypeCallback libraryMemberTypeCb, BytecodeBuilder& bytecode ); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp index ee150b17..ff2edc3d 100644 --- a/Compiler/src/lcode.cpp +++ b/Compiler/src/lcode.cpp @@ -27,3 +27,28 @@ char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, *outsize = result.size(); return copy; } + +void luau_set_compile_constant_nil(lua_CompileConstant* constant) +{ + Luau::setCompileConstantNil(constant); +} + +void luau_set_compile_constant_boolean(lua_CompileConstant* constant, int b) +{ + Luau::setCompileConstantBoolean(constant, b != 0); +} + +void luau_set_compile_constant_number(lua_CompileConstant* constant, double n) +{ + Luau::setCompileConstantNumber(constant, n); +} + +void luau_set_compile_constant_vector(lua_CompileConstant* constant, float x, float y, float z, float w) +{ + Luau::setCompileConstantVector(constant, x, y, z, w); +} + +void luau_set_compile_constant_string(lua_CompileConstant* constant, const char* s, size_t l) +{ + Luau::setCompileConstantString(constant, s, l); +} diff --git a/Config/include/Luau/Config.h b/Config/include/Luau/Config.h index 6333c55a..89d018d2 100644 --- a/Config/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -1,12 +1,14 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/LinterConfig.h" #include "Luau/ParseOptions.h" +#include #include #include -#include +#include #include namespace Luau @@ -19,6 +21,10 @@ constexpr const char* kConfigName = ".luaurc"; struct Config { Config(); + Config(const Config& other); + Config& operator=(const Config& other); + Config(Config&& other) = default; + Config& operator=(Config&& other) = default; Mode mode = Mode::Nonstrict; @@ -32,8 +38,20 @@ struct Config std::vector globals; - std::vector paths; - std::unordered_map aliases; + struct AliasInfo + { + std::string value; + std::string_view configLocation; + std::string originalCase; // The alias in its original case. + }; + + DenseHashMap aliases{""}; + + void setAlias(std::string alias, std::string value, const std::string& configLocation); + +private: + // Prevents making unnecessary copies of the same config location string. + DenseHashMap> configLocationCache{""}; }; struct ConfigResolver @@ -61,6 +79,18 @@ std::optional parseLintRuleString( bool isValidAlias(const std::string& alias); -std::optional parseConfig(const std::string& contents, Config& config, bool compat = false); +struct ConfigOptions +{ + bool compat = false; + + struct AliasOptions + { + std::string configLocation; + bool overwriteAliases; + }; + std::optional aliasOptions = std::nullopt; +}; + +std::optional parseConfig(const std::string& contents, Config& config, const ConfigOptions& options = ConfigOptions{}); } // namespace Luau diff --git a/Config/include/Luau/LinterConfig.h b/Config/include/Luau/LinterConfig.h index 3a68c0d7..e9305009 100644 --- a/Config/include/Luau/LinterConfig.h +++ b/Config/include/Luau/LinterConfig.h @@ -15,7 +15,7 @@ struct HotComment; struct LintWarning { - // Make sure any new lint codes are documented here: https://luau-lang.org/lint + // Make sure any new lint codes are documented here: https://luau.org/lint // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints enum Code { diff --git a/Config/src/Config.cpp b/Config/src/Config.cpp index 7d010265..44cbe2e5 100644 --- a/Config/src/Config.cpp +++ b/Config/src/Config.cpp @@ -4,7 +4,8 @@ #include "Luau/Lexer.h" #include "Luau/StringUtils.h" #include -#include +#include +#include namespace Luau { @@ -16,6 +17,54 @@ Config::Config() enabledLint.setDefaults(); } +Config::Config(const Config& other) + : mode(other.mode) + , parseOptions(other.parseOptions) + , enabledLint(other.enabledLint) + , fatalLint(other.fatalLint) + , lintErrors(other.lintErrors) + , typeErrors(other.typeErrors) + , globals(other.globals) +{ + for (const auto& [_, aliasInfo] : other.aliases) + { + setAlias(aliasInfo.originalCase, aliasInfo.value, std::string(aliasInfo.configLocation)); + } +} + +Config& Config::operator=(const Config& other) +{ + if (this != &other) + { + Config copy(other); + std::swap(*this, copy); + } + return *this; +} + +void Config::setAlias(std::string alias, std::string value, const std::string& configLocation) +{ + std::string lowercasedAlias = alias; + std::transform( + lowercasedAlias.begin(), + lowercasedAlias.end(), + lowercasedAlias.begin(), + [](unsigned char c) + { + return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; + } + ); + + AliasInfo& info = aliases[lowercasedAlias]; + info.value = std::move(value); + info.originalCase = std::move(alias); + + if (!configLocationCache.contains(configLocation)) + configLocationCache[configLocation] = std::make_unique(configLocation); + + info.configLocation = *configLocationCache[configLocation]; +} + static Error parseBoolean(bool& result, const std::string& value) { if (value == "true") @@ -136,22 +185,21 @@ bool isValidAlias(const std::string& alias) return true; } -Error parseAlias(std::unordered_map& aliases, std::string aliasKey, const std::string& aliasValue) +static Error parseAlias( + Config& config, + const std::string& aliasKey, + const std::string& aliasValue, + const std::optional& aliasOptions +) { if (!isValidAlias(aliasKey)) return Error{"Invalid alias " + aliasKey}; - std::transform( - aliasKey.begin(), - aliasKey.end(), - aliasKey.begin(), - [](unsigned char c) - { - return ('A' <= c && c <= 'Z') ? (c + ('a' - 'A')) : c; - } - ); - if (!aliases.count(aliasKey)) - aliases[std::move(aliasKey)] = aliasValue; + if (!aliasOptions) + return Error("Cannot parse aliases without alias options"); + + if (aliasOptions->overwriteAliases || !config.aliases.contains(aliasKey)) + config.setAlias(aliasKey, aliasValue, aliasOptions->configLocation); return std::nullopt; } @@ -257,7 +305,8 @@ static Error parseJson(const std::string& contents, Action action) arrayTop = (lexer.current().type == '['); next(lexer); } - else if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::ReservedTrue || lexer.current().type == Lexeme::ReservedFalse) + else if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::ReservedTrue || + lexer.current().type == Lexeme::ReservedFalse) { std::string value = lexer.current().type == Lexeme::QuotedString ? std::string(lexer.current().data, lexer.current().getLength()) @@ -285,16 +334,16 @@ static Error parseJson(const std::string& contents, Action action) return {}; } -Error parseConfig(const std::string& contents, Config& config, bool compat) +Error parseConfig(const std::string& contents, Config& config, const ConfigOptions& options) { return parseJson( contents, [&](const std::vector& keys, const std::string& value) -> Error { if (keys.size() == 1 && keys[0] == "languageMode") - return parseModeString(config.mode, value, compat); + return parseModeString(config.mode, value, options.compat); else if (keys.size() == 2 && keys[0] == "lint") - return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat); + return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, options.compat); else if (keys.size() == 1 && keys[0] == "lintErrors") return parseBoolean(config.lintErrors, value); else if (keys.size() == 1 && keys[0] == "typeErrors") @@ -304,15 +353,10 @@ Error parseConfig(const std::string& contents, Config& config, bool compat) config.globals.push_back(value); return std::nullopt; } - else if (keys.size() == 1 && keys[0] == "paths") - { - config.paths.push_back(value); - return std::nullopt; - } else if (keys.size() == 2 && keys[0] == "aliases") - return parseAlias(config.aliases, keys[1], value); - else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") - return parseModeString(config.mode, value, compat); + return parseAlias(config, keys[1], value, options.aliasOptions); + else if (options.compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") + return parseModeString(config.mode, value, options.compat); else { std::vector keysv(keys.begin(), keys.end()); diff --git a/EqSat/include/Luau/EGraph.h b/EqSat/include/Luau/EGraph.h index 480aa07d..c3bc5ab1 100644 --- a/EqSat/include/Luau/EGraph.h +++ b/EqSat/include/Luau/EGraph.h @@ -23,6 +23,13 @@ struct Analysis final using D = typename N::Data; + Analysis() = default; + + Analysis(N a) + : analysis(std::move(a)) + { + } + template static D fnMake(const N& analysis, const EGraph& egraph, const L& enode) { @@ -44,13 +51,70 @@ struct Analysis final } }; +template +struct Node +{ + L node; + bool boring = false; + + struct Hash + { + size_t operator()(const Node& node) const + { + return typename L::Hash{}(node.node); + } + }; +}; + +template +struct NodeIterator +{ +private: + using iterator = std::vector>; + iterator iter; + +public: + L& operator*() + { + return iter->node; + } + + const L& operator*() const + { + return iter->node; + } + + iterator& operator++() + { + ++iter; + return *this; + } + + iterator operator++(int) + { + iterator copy = *this; + ++*this; + return copy; + } + + bool operator==(const iterator& rhs) const + { + return iter == rhs.iter; + } + + bool operator!=(const iterator& rhs) const + { + return iter != rhs.iter; + } +}; + /// Each e-class is a set of e-nodes representing equivalent terms from a given language, /// and an e-node is a function symbol paired with a list of children e-classes. template struct EClass final { Id id; - std::vector nodes; + std::vector> nodes; D data; std::vector> parents; }; @@ -59,6 +123,15 @@ struct EClass final template struct EGraph final { + using EClassT = EClass; + + EGraph() = default; + + explicit EGraph(N analysis) + : analysis(std::move(analysis)) + { + } + Id find(Id id) const { return unionfind.find(id); @@ -85,33 +158,59 @@ struct EGraph final return id; } - void merge(Id id1, Id id2) + // Returns true if the two IDs were not previously merged. + bool merge(Id id1, Id id2) { id1 = find(id1); id2 = find(id2); if (id1 == id2) - return; + return false; - unionfind.merge(id1, id2); + const Id mergedId = unionfind.merge(id1, id2); - EClass& eclass1 = get(id1); - EClass eclass2 = std::move(get(id2)); + // Ensure that id1 is the Id that we keep, and id2 is the id that we drop. + if (mergedId == id2) + std::swap(id1, id2); + + EClassT& eclass1 = get(id1); + EClassT eclass2 = std::move(get(id2)); classes.erase(id2); - worklist.reserve(worklist.size() + eclass2.parents.size()); - for (auto [enode, id] : eclass2.parents) - worklist.push_back({std::move(enode), id}); + eclass1.nodes.insert(eclass1.nodes.end(), eclass2.nodes.begin(), eclass2.nodes.end()); + eclass1.parents.insert(eclass1.parents.end(), eclass2.parents.begin(), eclass2.parents.end()); + + std::sort( + eclass1.nodes.begin(), + eclass1.nodes.end(), + [](const Node& left, const Node& right) + { + return left.node.index() < right.node.index(); + } + ); + + worklist.reserve(worklist.size() + eclass1.parents.size()); + for (const auto& [eclass, id] : eclass1.parents) + worklist.push_back(id); analysis.join(eclass1.data, eclass2.data); + + return true; } void rebuild() { + std::unordered_set seen; + while (!worklist.empty()) { - auto [enode, id] = worklist.back(); + Id id = worklist.back(); worklist.pop_back(); - repair(get(find(id))); + + const bool isFresh = seen.insert(id).second; + if (!isFresh) + continue; + + repair(find(id)); } } @@ -120,16 +219,26 @@ struct EGraph final return classes.size(); } - EClass& operator[](Id id) + EClassT& operator[](Id id) { return get(find(id)); } - const EClass& operator[](Id id) const + const EClassT& operator[](Id id) const { return const_cast(this)->get(find(id)); } + const std::unordered_map& getAllClasses() const + { + return classes; + } + + void markBoring(Id id, size_t index) + { + get(id).nodes[index].boring = true; + } + private: Analysis analysis; @@ -139,20 +248,25 @@ private: /// The e-class map ð‘€ maps e-class ids to e-classes. All equivalent e-class ids map to the same /// e-class, i.e., 𑎠≡id ð‘ iff ð‘€[ð‘Ž] is the same set as ð‘€[ð‘]. An e-class id ð‘Ž is said to refer to the /// e-class ð‘€[find(ð‘Ž)]. - std::unordered_map> classes; + std::unordered_map classes; /// The hashcons ð» is a map from e-nodes to e-class ids. std::unordered_map hashcons; - std::vector> worklist; + std::vector worklist; private: void canonicalize(L& enode) { // An e-node ð‘› is canonical iff ð‘› = canonicalize(ð‘›), where // canonicalize(ð‘“(ð‘Ž1, ð‘Ž2, ...)) = ð‘“(find(ð‘Ž1), find(ð‘Ž2), ...). - for (Id& id : enode.operands()) - id = find(id); + Luau::EqSat::canonicalize( + enode, + [&](Id id) + { + return find(id); + } + ); } bool isCanonical(const L& enode) const @@ -171,9 +285,9 @@ private: classes.insert_or_assign( id, - EClass{ + EClassT{ id, - {enode}, + {Node{enode, false}}, analysis.make(*this, enode), {}, } @@ -182,7 +296,7 @@ private: for (Id operand : enode.operands()) get(operand).parents.push_back({enode, id}); - worklist.emplace_back(enode, id); + worklist.emplace_back(id); hashcons.insert_or_assign(enode, id); return id; @@ -190,12 +304,13 @@ private: // Looks up for an eclass from a given non-canonicalized `id`. // For a canonicalized eclass, use `get(find(id))` or `egraph[id]`. - EClass& get(Id id) + EClassT& get(Id id) { + LUAU_ASSERT(classes.count(id)); return classes.at(id); } - void repair(EClass& eclass) + void repair(Id id) { // In the egg paper, the `repair` function makes use of two loops over the `eclass.parents` // by first erasing the old enode entry, and adding back the canonicalized enode with the canonical id. @@ -204,26 +319,62 @@ private: // Here, we unify the two loops. I think it's equivalent? // After canonicalizing the enodes, the eclass may contain multiple enodes that are equivalent. - std::unordered_map map; - for (auto& [enode, id] : eclass.parents) + std::unordered_map newParents; + + // The eclass can be deallocated if it is merged into another eclass, so + // we take what we need from it and avoid retaining a pointer. + std::vector> parents = get(id).parents; + for (auto& pair : parents) { + L& parentNode = pair.first; + Id parentId = pair.second; + // By removing the old enode from the hashcons map, we will always find our new canonicalized eclass id. - hashcons.erase(enode); - canonicalize(enode); - hashcons.insert_or_assign(enode, find(id)); + hashcons.erase(parentNode); + canonicalize(parentNode); + hashcons.insert_or_assign(parentNode, find(parentId)); - if (auto it = map.find(enode); it != map.end()) - merge(id, it->second); + if (auto it = newParents.find(parentNode); it != newParents.end()) + merge(parentId, it->second); - map.insert_or_assign(enode, find(id)); + newParents.insert_or_assign(parentNode, find(parentId)); } - eclass.parents.clear(); - for (auto it = map.begin(); it != map.end();) + // We reacquire the pointer because the prior loop potentially merges + // the eclass into another, which might move it around in memory. + EClassT* eclass = &get(find(id)); + + eclass->parents.clear(); + + for (const auto& [node, id] : newParents) + eclass->parents.emplace_back(std::move(node), std::move(id)); + + std::unordered_map newNodes; + for (Node node : eclass->nodes) { - auto node = map.extract(it++); - eclass.parents.emplace_back(std::move(node.key()), node.mapped()); + canonicalize(node.node); + + bool& b = newNodes[std::move(node.node)]; + b = b || node.boring; } + + eclass->nodes.clear(); + + while (!newNodes.empty()) + { + auto n = newNodes.extract(newNodes.begin()); + eclass->nodes.push_back(Node{n.key(), n.mapped()}); + } + + // FIXME: Extract into sortByTag() + std::sort( + eclass->nodes.begin(), + eclass->nodes.end(), + [](const Node& left, const Node& right) + { + return left.node.index() < right.node.index(); + } + ); } }; diff --git a/EqSat/include/Luau/Id.h b/EqSat/include/Luau/Id.h index c56a6ab6..7069f23c 100644 --- a/EqSat/include/Luau/Id.h +++ b/EqSat/include/Luau/Id.h @@ -2,6 +2,7 @@ #pragma once #include +#include #include namespace Luau::EqSat @@ -9,15 +10,17 @@ namespace Luau::EqSat struct Id final { - explicit Id(size_t id); + explicit Id(uint32_t id); - explicit operator size_t() const; + explicit operator uint32_t() const; bool operator==(Id rhs) const; bool operator!=(Id rhs) const; + bool operator<(Id rhs) const; + private: - size_t id; + uint32_t id; }; } // namespace Luau::EqSat diff --git a/EqSat/include/Luau/Language.h b/EqSat/include/Luau/Language.h index 8855d851..f9d3aa4d 100644 --- a/EqSat/include/Luau/Language.h +++ b/EqSat/include/Luau/Language.h @@ -6,9 +6,19 @@ #include "Luau/Slice.h" #include "Luau/Variant.h" +#include #include #include +#include #include +#include + +#define LUAU_EQSAT_UNIT(name) \ + struct name : ::Luau::EqSat::Unit \ + { \ + static constexpr const char* tag = #name; \ + using Unit::Unit; \ + } #define LUAU_EQSAT_ATOM(name, t) \ struct name : public ::Luau::EqSat::Atom \ @@ -31,21 +41,57 @@ using NodeVector::NodeVector; \ } -#define LUAU_EQSAT_FIELD(name) \ - struct name : public ::Luau::EqSat::Field \ - { \ - } - -#define LUAU_EQSAT_NODE_FIELDS(name, ...) \ - struct name : public ::Luau::EqSat::NodeFields \ +#define LUAU_EQSAT_NODE_SET(name) \ + struct name : public ::Luau::EqSat::NodeSet> \ { \ static constexpr const char* tag = #name; \ - using NodeFields::NodeFields; \ + using NodeSet::NodeSet; \ + } + +#define LUAU_EQSAT_NODE_ATOM_WITH_VECTOR(name, t) \ + struct name : public ::Luau::EqSat::NodeAtomAndVector> \ + { \ + static constexpr const char* tag = #name; \ + using NodeAtomAndVector::NodeAtomAndVector; \ } namespace Luau::EqSat { +template +struct Unit +{ + Slice mutableOperands() + { + return {}; + } + + Slice operands() const + { + return {}; + } + + bool operator==(const Unit& rhs) const + { + return true; + } + + bool operator!=(const Unit& rhs) const + { + return false; + } + + struct Hash + { + size_t operator()(const Unit& value) const + { + // chosen by fair dice roll. + // guaranteed to be random. + return 4; + } + }; +}; + template struct Atom { @@ -60,7 +106,7 @@ struct Atom } public: - Slice operands() + Slice mutableOperands() { return {}; } @@ -92,6 +138,62 @@ private: T _value; }; +template +struct NodeAtomAndVector +{ + template + NodeAtomAndVector(const X& value, Args&&... args) + : _value(value) + , vector{std::forward(args)...} + { + } + + Id operator[](size_t i) const + { + return vector[i]; + } + +public: + const X& value() const + { + return _value; + } + + Slice mutableOperands() + { + return Slice{vector.data(), vector.size()}; + } + + Slice operands() const + { + return Slice{vector.data(), vector.size()}; + } + + bool operator==(const NodeAtomAndVector& rhs) const + { + return _value == rhs._value && vector == rhs.vector; + } + + bool operator!=(const NodeAtomAndVector& rhs) const + { + return !(*this == rhs); + } + + struct Hash + { + size_t operator()(const NodeAtomAndVector& value) const + { + size_t result = languageHash(value._value); + hashCombine(result, languageHash(value.vector)); + return result; + } + }; + +private: + X _value; + T vector; +}; + template struct NodeVector { @@ -107,7 +209,7 @@ struct NodeVector } public: - Slice operands() + Slice mutableOperands() { return Slice{vector.data(), vector.size()}; } @@ -139,93 +241,70 @@ private: T vector; }; -/// Empty base class just for static_asserts. -struct FieldBase +template +struct NodeSet { - FieldBase() = delete; + template + friend void canonicalize(NodeSet& node, Find&& find); - FieldBase(FieldBase&&) = delete; - FieldBase& operator=(FieldBase&&) = delete; - - FieldBase(const FieldBase&) = delete; - FieldBase& operator=(const FieldBase&) = delete; -}; - -template -struct Field : FieldBase -{ -}; - -template -struct NodeFields -{ - static_assert(std::conjunction...>::value); - - template - static constexpr int getIndex() + template + NodeSet(Args&&... args) + : vector{std::forward(args)...} { - constexpr int N = sizeof...(Fields); - constexpr bool is[N] = {std::is_same_v, Fields>...}; + std::sort(begin(vector), end(vector)); + auto it = std::unique(begin(vector), end(vector)); + vector.erase(it, end(vector)); + } - for (int i = 0; i < N; ++i) - if (is[i]) - return i; - - return -1; + Id operator[](size_t i) const + { + return vector[i]; } public: - template - NodeFields(Args&&... args) - : array{std::forward(args)...} + Slice mutableOperands() { - } - - Slice operands() - { - return Slice{array}; + return Slice{vector.data(), vector.size()}; } Slice operands() const { - return Slice{array.data(), array.size()}; + return Slice{vector.data(), vector.size()}; } - template - Id field() const + bool operator==(const NodeSet& rhs) const { - static_assert(std::disjunction_v, Fields>...>); - return array[getIndex()]; + return vector == rhs.vector; } - bool operator==(const NodeFields& rhs) const - { - return array == rhs.array; - } - - bool operator!=(const NodeFields& rhs) const + bool operator!=(const NodeSet& rhs) const { return !(*this == rhs); } struct Hash { - size_t operator()(const NodeFields& value) const + size_t operator()(const NodeSet& value) const { - return languageHash(value.array); + return languageHash(value.vector); } }; -private: - std::array array; +protected: + T vector; }; template struct Language final { + using VariantTy = Luau::Variant; + template using WithinDomain = std::disjunction, Ts>...>; + template + friend void canonicalize(Language& enode, Find&& find); + template Language(T&& t, std::enable_if_t::value>* = 0) noexcept : v(std::forward(t)) @@ -237,14 +316,14 @@ struct Language final return v.index(); } - /// You should never call this function with the intention of mutating the `Id`. - /// Reading is ok, but you should also never assume that these `Id`s are stable. - Slice operands() noexcept + /// This should only be used in canonicalization! + /// Always prefer operands() + Slice mutableOperands() noexcept { return visit( [](auto&& v) -> Slice { - return v.operands(); + return v.mutableOperands(); }, v ); @@ -306,7 +385,40 @@ public: }; private: - Variant v; + VariantTy v; }; +template +void canonicalize(Node& node, Find&& find) +{ + // An e-node ð‘› is canonical iff ð‘› = canonicalize(ð‘›), where + // canonicalize(ð‘“(ð‘Ž1, ð‘Ž2, ...)) = ð‘“(find(ð‘Ž1), find(ð‘Ž2), ...). + for (Id& id : node.mutableOperands()) + id = find(id); +} + +// Canonicalizing the Ids in a NodeSet may result in the set decreasing in size. +template +void canonicalize(NodeSet& node, Find&& find) +{ + for (Id& id : node.vector) + id = find(id); + + std::sort(begin(node.vector), end(node.vector)); + auto endIt = std::unique(begin(node.vector), end(node.vector)); + node.vector.erase(endIt, end(node.vector)); +} + +template +void canonicalize(Language& enode, Find&& find) +{ + visit( + [&](auto&& v) + { + Luau::EqSat::canonicalize(v, find); + }, + enode.v + ); +} + } // namespace Luau::EqSat diff --git a/EqSat/include/Luau/LanguageHash.h b/EqSat/include/Luau/LanguageHash.h index 506f352b..cfc33b83 100644 --- a/EqSat/include/Luau/LanguageHash.h +++ b/EqSat/include/Luau/LanguageHash.h @@ -3,6 +3,7 @@ #include #include +#include #include namespace Luau::EqSat diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index 559ee119..22a61628 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -14,7 +14,9 @@ struct UnionFind final Id makeSet(); Id find(Id id) const; Id find(Id id); - void merge(Id a, Id b); + + // Merge aSet with bSet and return the canonicalized Id into the merged set. + Id merge(Id aSet, Id bSet); private: std::vector parents; diff --git a/EqSat/src/Id.cpp b/EqSat/src/Id.cpp index 960249ba..eae6a974 100644 --- a/EqSat/src/Id.cpp +++ b/EqSat/src/Id.cpp @@ -1,15 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Id.h" +#include "Luau/Common.h" namespace Luau::EqSat { -Id::Id(size_t id) +Id::Id(uint32_t id) : id(id) { } -Id::operator size_t() const +Id::operator uint32_t() const { return id; } @@ -24,9 +25,14 @@ bool Id::operator!=(Id rhs) const return id != rhs.id; } +bool Id::operator<(Id rhs) const +{ + return id < rhs.id; +} + } // namespace Luau::EqSat size_t std::hash::operator()(Luau::EqSat::Id id) const { - return std::hash()(size_t(id)); + return std::hash()(uint32_t(id)); } diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 619c3f47..6a952999 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -3,12 +3,16 @@ #include "Luau/Common.h" +#include + namespace Luau::EqSat { Id UnionFind::makeSet() { - Id id{parents.size()}; + LUAU_ASSERT(parents.size() < std::numeric_limits::max()); + + Id id{uint32_t(parents.size())}; parents.push_back(id); ranks.push_back(0); @@ -25,42 +29,44 @@ Id UnionFind::find(Id id) Id set = canonicalize(id); // An e-class id ð‘Ž is canonical iff find(ð‘Ž) = ð‘Ž. - while (id != parents[size_t(id)]) + while (id != parents[uint32_t(id)]) { // Note: we don't update the ranks here since a rank // represents the upper bound on the maximum depth of a tree - Id parent = parents[size_t(id)]; - parents[size_t(id)] = set; + Id parent = parents[uint32_t(id)]; + parents[uint32_t(id)] = set; id = parent; } return set; } -void UnionFind::merge(Id a, Id b) +Id UnionFind::merge(Id a, Id b) { Id aSet = find(a); Id bSet = find(b); if (aSet == bSet) - return; + return aSet; // Ensure that the rank of set A is greater than the rank of set B - if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + if (ranks[uint32_t(aSet)] > ranks[uint32_t(bSet)]) std::swap(aSet, bSet); - parents[size_t(bSet)] = aSet; + parents[uint32_t(bSet)] = aSet; - if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) - ranks[size_t(aSet)]++; + if (ranks[uint32_t(aSet)] == ranks[uint32_t(bSet)]) + ranks[uint32_t(aSet)]++; + + return aSet; } Id UnionFind::canonicalize(Id id) const { - LUAU_ASSERT(size_t(id) < parents.size()); + LUAU_ASSERT(uint32_t(id) < parents.size()); // An e-class id ð‘Ž is canonical iff find(ð‘Ž) = ð‘Ž. - while (id != parents[size_t(id)]) - id = parents[size_t(id)]; + while (id != parents[uint32_t(id)]) + id = parents[uint32_t(id)]; return id; } diff --git a/Makefile b/Makefile index 3e6b85ad..2ad0fc00 100644 --- a/Makefile +++ b/Makefile @@ -42,23 +42,23 @@ ISOCLINE_SOURCES=extern/isocline/src/isocline.c ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o) ISOCLINE_TARGET=$(BUILD)/libisocline.a -TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/Require.cpp +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/Require.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp CLI/Require.cpp +REPL_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Profiler.cpp CLI/src/Coverage.cpp CLI/src/Repl.cpp CLI/src/ReplEntry.cpp CLI/src/Require.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau -ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Analyze.cpp +ANALYZE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Require.cpp CLI/src/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze -COMPILE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Compile.cpp +COMPILE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Compile.cpp COMPILE_CLI_OBJECTS=$(COMPILE_CLI_SOURCES:%=$(BUILD)/%.o) COMPILE_CLI_TARGET=$(BUILD)/luau-compile -BYTECODE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Flags.cpp CLI/Bytecode.cpp +BYTECODE_CLI_SOURCES=CLI/src/FileUtils.cpp CLI/src/Flags.cpp CLI/src/Bytecode.cpp BYTECODE_CLI_OBJECTS=$(BYTECODE_CLI_SOURCES:%=$(BUILD)/%.o) BYTECODE_CLI_TARGET=$(BUILD)/luau-bytecode @@ -82,8 +82,10 @@ LDFLAGS= # some gcc versions treat var in `if (type var = val)` as unused # some gcc versions treat variables used in constexpr if blocks as unused +# some gcc versions warn maybe uninitalized on optional members on structs ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) CXXFLAGS+=-Wno-unused + CXXFLAGS+=-Wno-maybe-uninitialized endif # enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere @@ -142,16 +144,16 @@ endif $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include $(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include -$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include +$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -ICompiler/include -IVM/include $(EQSAT_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IEqSat/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY -$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern -$(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -$(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -IEqSat/include -ICodeGen/include -IVM/include -ICLI/include -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY +$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -ICLI/include +$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IEqSat/include -IConfig/include -Iextern -ICLI/include +$(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -ICLI/include +$(BYTECODE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -ICLI/include $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IEqSat/include -IVM/include -ICodeGen/include -IConfig/include $(TESTS_TARGET): LDFLAGS+=-lpthread @@ -227,7 +229,7 @@ luau-tests: $(TESTS_TARGET) # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(EQSAT_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILER_TARGET) $(VM_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(BYTECODE_CLI_TARGET): $(BYTECODE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) diff --git a/README.md b/README.md index ba337585..df28d2ae 100644 --- a/README.md +++ b/README.md @@ -3,19 +3,19 @@ Luau ![CI](https://github.com/luau-lang/luau/actions/workflows/build.yml/badge.s Luau (lowercase u, /ˈlu.aÊŠ/) is a fast, small, safe, gradually typed embeddable scripting language derived from [Lua](https://lua.org). -It is designed to be backwards compatible with Lua 5.1, as well as incorporating [some features](https://luau-lang.org/compatibility) from future Lua releases, but also expands the feature set (most notably with type annotations). Luau is largely implemented from scratch, with the language runtime being a very heavily modified version of Lua 5.1 runtime, with completely rewritten interpreter and other [performance innovations](https://luau-lang.org/performance). The runtime mostly preserves Lua 5.1 API, so existing bindings should be more or less compatible with a few caveats. +It is designed to be backwards compatible with Lua 5.1, as well as incorporating [some features](https://luau.org/compatibility) from future Lua releases, but also expands the feature set (most notably with type annotations and a state-of-the-art type inference system). Luau is largely implemented from scratch, with the language runtime being a very heavily modified version of Lua 5.1 runtime, with completely rewritten interpreter and other [performance innovations](https://luau.org/performance). The runtime mostly preserves Lua 5.1 API, so existing bindings should be more or less compatible with a few caveats. -Luau is used by Roblox game developers to write game code, as well as by Roblox engineers to implement large parts of the user-facing application code as well as portions of the editor (Roblox Studio) as plugins. Roblox chose to open-source Luau to foster collaboration within the Roblox community as well as to allow other companies and communities to benefit from the ongoing language and runtime innovation. As a consequence, Luau is now also used by games like Alan Wake 2 and Warframe. +Luau is used by Roblox game developers to write game code, and by Roblox engineers to implement large parts of the user-facing application code as well as portions of the editor (Roblox Studio) as plugins. Roblox chose to open-source Luau to foster collaboration within the Roblox community as well as to allow other companies and communities to benefit from the ongoing language and runtime innovation. More recently, Luau has seen adoption in games like Alan Wake 2, Farming Simulator 2025, Second Life, and Warframe. -This repository hosts source code for the language implementation and associated tooling. Documentation for the language is available at https://luau-lang.org/ and accepts contributions via [site repository](https://github.com/luau-lang/site); the language is evolved through RFCs that are located in [rfcs repository](https://github.com/luau-lang/rfcs). +This repository hosts source code for the language implementation and associated tooling. Documentation for the language is available at https://luau.org/ and accepts contributions via [site repository](https://github.com/luau-lang/site); the language is evolved through RFCs that are located in [rfcs repository](https://github.com/luau-lang/rfcs). # Usage -Luau is an embeddable language, but it also comes with two command-line tools by default, `luau` and `luau-analyze`. +Luau is an embeddable programming language, but it also comes with two command-line tools by default, `luau` and `luau-analyze`. `luau` is a command-line REPL and can also run input files. Note that REPL runs in a sandboxed environment and as such doesn't have access to the underlying file system except for ability to `require` modules. -`luau-analyze` is a command-line type checker and linter; given a set of input files, it produces errors/warnings according to the file configuration, which can be customized by using `--!` comments in the files or [`.luaurc`](https://rfcs.luau-lang.org/config-luaurc) files. For details please refer to [type checking]( https://luau-lang.org/typecheck) and [linting](https://luau-lang.org/lint) documentation. +`luau-analyze` is a command-line type checker and linter; given a set of input files, it produces errors/warnings according to the file configuration, which can be customized by using `--!` comments in the files or [`.luaurc`](https://rfcs.luau.org/config-luaurc) files. For details, please refer to our [type checking](https://luau.org/typecheck) and [linting](https://luau.org/lint) documentation. Our community maintains a language server frontend for `luau-analyze` called [luau-lsp](https://github.com/JohnnyMorganz/luau-lsp) for use with text editors. # Installation @@ -28,7 +28,7 @@ Alternatively, you can use one of the packaged distributions (note that these ar - Alpine Linux: [Enable community repositories](https://wiki.alpinelinux.org/w/index.php?title=Enable_Community_Repository) and run `apk add luau` - Gentoo Linux: Luau is [officially packaged by Gentoo](https://packages.gentoo.org/packages/dev-lang/luau) and can be installed using `emerge dev-lang/luau`. You may have to unmask the package first before installing it (which can be done by including the `--autounmask=y` option in the `emerge` command). -After installing, you will want to validate the installation was successful by running the test case [here](https://luau-lang.org/getting-started). +After installing, you will want to validate the installation was successful by running the test case [here](https://luau.org/getting-started). ## Building @@ -41,13 +41,13 @@ cmake --build . --target Luau.Repl.CLI --config RelWithDebInfo cmake --build . --target Luau.Analyze.CLI --config RelWithDebInfo ``` -Alternatively, on Linux/macOS you can use `make`: +Alternatively, on Linux and macOS, you can also use `make`: ```sh make config=release luau luau-analyze ``` -To integrate Luau into your CMake application projects as a library, at the minimum you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this: +To integrate Luau into your CMake application projects as a library, at the minimum, you'll need to depend on `Luau.Compiler` and `Luau.VM` projects. From there you need to create a new Luau state (using Lua 5.x API such as `lua_newstate`), compile source to bytecode and load it into the VM like this: ```cpp // needs lua.h and luacode.h @@ -60,24 +60,24 @@ if (result == 0) return 1; /* return chunk main function */ ``` -For more details about the use of host API you currently need to consult [Lua 5.x API](https://www.lua.org/manual/5.1/manual.html#3). Luau closely tracks that API but has a few deviations, such as the need to compile source separately (which is important to be able to deploy VM without a compiler), or lack of `__gc` support (use `lua_newuserdatadtor` instead). +For more details about the use of the host API, you currently need to consult [Lua 5.x API](https://www.lua.org/manual/5.1/manual.html#3). Luau closely tracks that API but has a few deviations, such as the need to compile source separately (which is important to be able to deploy VM without a compiler), and the lack of `__gc` support (use `lua_newuserdatadtor` instead). -To gain advantage of many performance improvements it's highly recommended to use `safeenv` feature, which sandboxes individual scripts' global tables from each other as well as protects builtin libraries from monkey-patching. For this to work you need to call `luaL_sandbox` for the global state and `luaL_sandboxthread` for each new script's execution thread. +To gain advantage of many performance improvements, it's highly recommended to use the `safeenv` feature, which sandboxes individual scripts' global tables from each other, and protects builtin libraries from monkey-patching. For this to work, you must call `luaL_sandbox` on the global state and `luaL_sandboxthread` for each new script's execution thread. # Testing -Luau has an internal test suite; in CMake builds it is split into two targets, `Luau.UnitTest` (for bytecode compiler and type checker/linter tests) and `Luau.Conformance` (for VM tests). The unit tests are written in C++, whereas the conformance tests are largely written in Luau (see `tests/conformance`). +Luau has an internal test suite; in CMake builds, it is split into two targets, `Luau.UnitTest` (for the bytecode compiler and type checker/linter tests) and `Luau.Conformance` (for the VM tests). The unit tests are written in C++, whereas the conformance tests are largely written in Luau (see `tests/conformance`). -Makefile builds combine both into a single target and can be ran via `make test`. +Makefile builds combine both into a single target that can be run via `make test`. # Dependencies -Luau uses C++ as its implementation language. The runtime requires C++11, whereas the compiler and analysis components require C++17. It should build without issues using Microsoft Visual Studio 2017 or later, or gcc-7 or clang-7 or later. +Luau uses C++ as its implementation language. The runtime requires C++11, while the compiler and analysis components require C++17. It should build without issues using Microsoft Visual Studio 2017 or later, or gcc-7 or clang-7 or later. -Other than the STL/CRT, Luau library components don't have external dependencies. The test suite depends on [doctest](https://github.com/onqtam/doctest) testing framework, and the REPL command-line depends on [isocline](https://github.com/daanx/isocline). +Other than the STL/CRT, Luau library components don't have external dependencies. The test suite depends on the [doctest](https://github.com/onqtam/doctest) testing framework, and the REPL command-line depends on [isocline](https://github.com/daanx/isocline). # License -Luau implementation is distributed under the terms of [MIT License](https://github.com/luau-lang/luau/blob/master/LICENSE.txt). It is based on Lua 5.x implementation that is MIT licensed as well. +Luau implementation is distributed under the terms of [MIT License](https://github.com/luau-lang/luau/blob/master/LICENSE.txt). It is based on the Lua 5.x implementation, also under the MIT License. -When Luau is integrated into external projects, we ask to honor the license agreement and include Luau attribution into the user-facing product documentation. The attribution using [Luau logo](https://github.com/luau-lang/site/blob/master/logo.svg) is also encouraged. +When Luau is integrated into external projects, we ask that you honor the license agreement and include Luau attribution into the user-facing product documentation. Attribution making use of the [Luau logo](https://github.com/luau-lang/site/blob/master/logo.svg) is also encouraged when reasonable. diff --git a/SECURITY.md b/SECURITY.md index ca3f5923..7b190324 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -6,9 +6,9 @@ Any source code can not result in memory safety errors or crashes during its com Note that Luau does not provide termination guarantees - some code may exhaust CPU or RAM resources on the system during compilation or execution. -The runtime expects valid bytecode as an input. Feeding bytecode that was not produced by Luau compiler into the VM is not supported and +The runtime expects valid bytecode as an input. Feeding bytecode that was not produced by Luau compiler into the VM is not supported, and doesn't come with any security guarantees; make sure to sign and/or encrypt the bytecode when it crosses a network or file system boundary to avoid tampering. # Reporting a Vulnerability -You can report security bugs via [Hackerone](https://hackerone.com/roblox). Please refer to the linked page for rules of the bounty program. +You can report security bugs via [HackerOne](https://hackerone.com/roblox). Please refer to the linked page for rules of the bounty program. diff --git a/Sources.cmake b/Sources.cmake index 80bcd5b2..1c312cb9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -14,8 +14,10 @@ endif() # Luau.Ast Sources target_sources(Luau.Ast PRIVATE + Ast/include/Luau/Allocator.h Ast/include/Luau/Ast.h Ast/include/Luau/Confusables.h + Ast/include/Luau/Cst.h Ast/include/Luau/Lexer.h Ast/include/Luau/Location.h Ast/include/Luau/ParseOptions.h @@ -24,8 +26,10 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/StringUtils.h Ast/include/Luau/TimeTrace.h + Ast/src/Allocator.cpp Ast/src/Ast.cpp Ast/src/Confusables.cpp + Ast/src/Cst.cpp Ast/src/Lexer.cpp Ast/src/Location.cpp Ast/src/Parser.cpp @@ -76,6 +80,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/CodeBlockUnwind.h CodeGen/include/Luau/CodeGen.h CodeGen/include/Luau/CodeGenCommon.h + CodeGen/include/Luau/CodeGenOptions.h CodeGen/include/Luau/ConditionA64.h CodeGen/include/Luau/ConditionX64.h CodeGen/include/Luau/IrAnalysis.h @@ -87,6 +92,7 @@ target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/IrUtils.h CodeGen/include/Luau/IrVisitUseDef.h CodeGen/include/Luau/Label.h + CodeGen/include/Luau/LoweringStats.h CodeGen/include/Luau/NativeProtoExecData.h CodeGen/include/Luau/OperandX64.h CodeGen/include/Luau/OptimizeConstProp.h @@ -168,6 +174,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstQuery.h Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/AutocompleteTypes.h Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Clone.h @@ -181,7 +188,9 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Differ.h Analysis/include/Luau/Documentation.h Analysis/include/Luau/Error.h + Analysis/include/Luau/EqSatSimplification.h Analysis/include/Luau/FileResolver.h + Analysis/include/Luau/FragmentAutocomplete.h Analysis/include/Luau/Frontend.h Analysis/include/Luau/Generalization.h Analysis/include/Luau/GlobalTypes.h @@ -223,6 +232,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypedAllocator.h Analysis/include/Luau/TypeFunction.h Analysis/include/Luau/TypeFunctionReductionGuesser.h + Analysis/include/Luau/TypeFunctionRuntime.h + Analysis/include/Luau/TypeFunctionRuntimeBuilder.h Analysis/include/Luau/TypeFwd.h Analysis/include/Luau/TypeInfer.h Analysis/include/Luau/TypeOrPack.h @@ -242,6 +253,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/AstJsonEncoder.cpp Analysis/src/AstQuery.cpp Analysis/src/Autocomplete.cpp + Analysis/src/AutocompleteCore.cpp Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp Analysis/src/Constraint.cpp @@ -253,6 +265,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Differ.cpp Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp + Analysis/src/EqSatSimplification.cpp + Analysis/src/FragmentAutocomplete.cpp Analysis/src/Frontend.cpp Analysis/src/Generalization.cpp Analysis/src/GlobalTypes.cpp @@ -287,6 +301,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/TypedAllocator.cpp Analysis/src/TypeFunction.cpp Analysis/src/TypeFunctionReductionGuesser.cpp + Analysis/src/TypeFunctionRuntime.cpp + Analysis/src/TypeFunctionRuntimeBuilder.cpp Analysis/src/TypeInfer.cpp Analysis/src/TypeOrPack.cpp Analysis/src/TypePack.cpp @@ -345,6 +361,7 @@ target_sources(Luau.VM PRIVATE VM/src/ltm.cpp VM/src/ludata.cpp VM/src/lutf8lib.cpp + VM/src/lveclib.cpp VM/src/lvmexecute.cpp VM/src/lvmload.cpp VM/src/lvmutils.cpp @@ -376,41 +393,46 @@ target_sources(isocline PRIVATE # Common sources shared between all CLI apps target_sources(Luau.CLI.lib PRIVATE - CLI/FileUtils.cpp - CLI/Flags.cpp - CLI/Flags.h - CLI/FileUtils.h + CLI/include/Luau/FileUtils.h + CLI/include/Luau/Flags.h + CLI/include/Luau/Require.h + + CLI/src/FileUtils.cpp + CLI/src/Flags.cpp + CLI/src/Require.cpp ) if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE - CLI/Coverage.h - CLI/Coverage.cpp - CLI/Profiler.h - CLI/Profiler.cpp - CLI/Repl.cpp - CLI/ReplEntry.cpp - CLI/Require.cpp) + CLI/include/Luau/Coverage.h + CLI/include/Luau/Profiler.h + + CLI/src/Coverage.cpp + CLI/src/Profiler.cpp + CLI/src/Repl.cpp + CLI/src/ReplEntry.cpp + ) endif() if(TARGET Luau.Analyze.CLI) # Luau.Analyze.CLI Sources target_sources(Luau.Analyze.CLI PRIVATE - CLI/Analyze.cpp) + CLI/src/Analyze.cpp + ) endif() if(TARGET Luau.Ast.CLI) # Luau.Ast.CLI Sources target_sources(Luau.Ast.CLI PRIVATE - CLI/Ast.cpp + CLI/src/Ast.cpp ) endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE - tests/AnyTypeSummary.test.cpp + tests/AnyTypeSummary.test.cpp tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp @@ -437,9 +459,11 @@ if(TARGET Luau.UnitTest) tests/EqSat.language.test.cpp tests/EqSat.propositional.test.cpp tests/EqSat.slice.test.cpp + tests/EqSatSimplification.test.cpp tests/Error.test.cpp tests/Fixture.cpp tests/Fixture.h + tests/FragmentAutocomplete.test.cpp tests/Frontend.test.cpp tests/Generalization.test.cpp tests/InsertionOrderedMap.test.cpp @@ -474,6 +498,7 @@ if(TARGET Luau.UnitTest) tests/Transpiler.test.cpp tests/TxnLog.test.cpp tests/TypeFunction.test.cpp + tests/TypeFunction.user.test.cpp tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.anyerror.test.cpp @@ -525,12 +550,12 @@ endif() if(TARGET Luau.CLI.Test) # Luau.CLI.Test Sources target_sources(Luau.CLI.Test PRIVATE - CLI/Coverage.h - CLI/Coverage.cpp - CLI/Profiler.h - CLI/Profiler.cpp - CLI/Repl.cpp - CLI/Require.cpp + CLI/include/Luau/Coverage.h + CLI/include/Luau/Profiler.h + + CLI/src/Coverage.cpp + CLI/src/Profiler.cpp + CLI/src/Repl.cpp tests/RegisterCallbacks.h tests/RegisterCallbacks.cpp @@ -542,24 +567,24 @@ endif() if(TARGET Luau.Web) # Luau.Web Sources target_sources(Luau.Web PRIVATE - CLI/Web.cpp) + CLI/src/Web.cpp) endif() if(TARGET Luau.Reduce.CLI) # Luau.Reduce.CLI Sources target_sources(Luau.Reduce.CLI PRIVATE - CLI/Reduce.cpp + CLI/src/Reduce.cpp ) endif() if(TARGET Luau.Compile.CLI) # Luau.Compile.CLI Sources target_sources(Luau.Compile.CLI PRIVATE - CLI/Compile.cpp) + CLI/src/Compile.cpp) endif() if(TARGET Luau.Bytecode.CLI) # Luau.Bytecode.CLI Sources target_sources(Luau.Bytecode.CLI PRIVATE - CLI/Bytecode.cpp) + CLI/src/Bytecode.cpp) endif() diff --git a/VM/include/lua.h b/VM/include/lua.h index 4ee9306e..303d7162 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -154,6 +154,7 @@ LUA_API const float* lua_tovector(lua_State* L, int idx); LUA_API int lua_toboolean(lua_State* L, int idx); LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len); LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom); +LUA_API const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom); LUA_API const char* lua_namecallatom(lua_State* L, int* atom); LUA_API int lua_objlen(lua_State* L, int idx); LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); @@ -189,6 +190,7 @@ LUA_API int lua_pushthread(lua_State* L); LUA_API void lua_pushlightuserdatatagged(lua_State* L, void* p, int tag); LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatataggedwithmetatable(lua_State* L, size_t sz, int tag); // metatable fetched with lua_getuserdatametatable LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API void* lua_newbuffer(lua_State* L, size_t sz); @@ -334,6 +336,7 @@ LUA_API const char* lua_getlightuserdataname(lua_State* L, int tag); LUA_API void lua_clonefunction(lua_State* L, int idx); LUA_API void lua_cleartable(lua_State* L, int idx); +LUA_API void lua_clonetable(lua_State* L, int idx); LUA_API lua_Alloc lua_getallocf(lua_State* L, void** ud); @@ -453,6 +456,8 @@ struct lua_Callbacks void (*debugstep)(lua_State* L, lua_Debug* ar); // gets called after each instruction in single step mode void (*debuginterrupt)(lua_State* L, lua_Debug* ar); // gets called when thread execution is interrupted by break in another thread void (*debugprotectederror)(lua_State* L); // gets called when protected call results in an error + + void (*onallocate)(lua_State* L, size_t osize, size_t nsize); // gets called when memory is allocated }; typedef struct lua_Callbacks lua_Callbacks; diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 367a0281..5860c613 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -136,6 +136,9 @@ LUALIB_API int luaopen_math(lua_State* L); #define LUA_DBLIBNAME "debug" LUALIB_API int luaopen_debug(lua_State* L); +#define LUA_VECLIBNAME "vector" +LUALIB_API int luaopen_vector(lua_State* L); + // open all builtin libraries LUALIB_API void luaL_openlibs(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 87f85af8..a956fa94 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -39,8 +39,8 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; -const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" - "$URL: luau-lang.org $\n"; +const char* luau_ident = "$Luau: Copyright (C) 2019-2024 Roblox Corporation $\n" + "$URL: luau.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) @@ -64,7 +64,7 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; \ } -static Table* getcurrenv(lua_State* L) +static LuaTable* getcurrenv(lua_State* L) { if (L->ci == L->base_ci) // no enclosing function? return L->gt; // use global table as environment @@ -454,6 +454,29 @@ const char* lua_tostringatom(lua_State* L, int idx, int* atom) return getstr(s); } +const char* lua_tolstringatom(lua_State* L, int idx, size_t* len, int* atom) +{ + StkId o = index2addr(L, idx); + + if (!ttisstring(o)) + { + if (len) + *len = 0; + return NULL; + } + + TString* s = tsvalue(o); + if (len) + *len = s->len; + if (atom) + { + updateatom(L, s); + *atom = s->atom; + } + + return getstr(s); +} + const char* lua_namecallatom(lua_State* L, int* atom) { TString* s = L->namecall; @@ -762,7 +785,7 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); - Table* t = hvalue(o); + LuaTable* t = hvalue(o); api_check(L, t != hvalue(registry(L))); t->readonly = bool(enabled); } @@ -771,7 +794,7 @@ int lua_getreadonly(lua_State* L, int objindex) { const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); - Table* t = hvalue(o); + LuaTable* t = hvalue(o); int res = t->readonly; return res; } @@ -780,14 +803,14 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); - Table* t = hvalue(o); + LuaTable* t = hvalue(o); t->safeenv = bool(enabled); } int lua_getmetatable(lua_State* L, int objindex) { luaC_threadbarrier(L); - Table* mt = NULL; + LuaTable* mt = NULL; const TValue* obj = index2addr(L, objindex); switch (ttype(obj)) { @@ -894,7 +917,7 @@ int lua_setmetatable(lua_State* L, int objindex) api_checknelems(L, 1); TValue* obj = index2addr(L, objindex); api_checkvalidindex(L, obj); - Table* mt = NULL; + LuaTable* mt = NULL; if (!ttisnil(L->top - 1)) { api_check(L, ttistable(L->top - 1)); @@ -1214,7 +1237,7 @@ int lua_rawiter(lua_State* L, int idx, int iter) api_check(L, ttistable(t)); api_check(L, iter >= 0); - Table* h = hvalue(t); + LuaTable* h = hvalue(t); int sizearray = h->sizearray; // first we advance iter through the array portion @@ -1283,6 +1306,26 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) return u->data; } +void* lua_newuserdatataggedwithmetatable(lua_State* L, size_t sz, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + luaC_checkGC(L); + luaC_threadbarrier(L); + Udata* u = luaU_newudata(L, sz, tag); + + // currently, we always allocate unmarked objects, so forward barrier can be skipped + LUAU_ASSERT(!isblack(obj2gco(u))); + + LuaTable* h = L->global->udatamt[tag]; + api_check(L, h != nullptr); + + u->metatable = h; + + setuvalue(L, L->top, u); + api_incr_top(L); + return u->data; +} + void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); @@ -1374,7 +1417,7 @@ int lua_ref(lua_State* L, int idx) StkId p = index2addr(L, idx); if (!ttisnil(p)) { - Table* reg = hvalue(registry(L)); + LuaTable* reg = hvalue(registry(L)); if (g->registryfree != 0) { // reuse existing slot @@ -1401,7 +1444,7 @@ void lua_unref(lua_State* L, int ref) return; global_State* g = L->global; - Table* reg = hvalue(registry(L)); + LuaTable* reg = hvalue(registry(L)); TValue* slot = luaH_setnum(L, reg, ref); setnvalue(slot, g->registryfree); // NB: no barrier needed because value isn't collectable g->registryfree = ref; @@ -1442,7 +1485,7 @@ void lua_getuserdatametatable(lua_State* L, int tag) api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_threadbarrier(L); - if (Table* h = L->global->udatamt[tag]) + if (LuaTable* h = L->global->udatamt[tag]) { sethvalue(L, L->top, h); } @@ -1490,12 +1533,22 @@ void lua_cleartable(lua_State* L, int idx) { StkId t = index2addr(L, idx); api_check(L, ttistable(t)); - Table* tt = hvalue(t); + LuaTable* tt = hvalue(t); if (tt->readonly) luaG_readonlyerror(L); luaH_clear(tt); } +void lua_clonetable(lua_State* L, int idx) +{ + StkId t = index2addr(L, idx); + api_check(L, ttistable(t)); + + LuaTable* tt = luaH_clone(L, hvalue(t)); + sethvalue(L, L->top, tt); + api_incr_top(L); +} + lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 4262eb49..1d23b155 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,6 +11,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauLibWhereErrorAutoreserve) + // convert a stack index to positive #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -67,6 +69,7 @@ static l_noret tag_error(lua_State* L, int narg, int tag) luaL_typeerrorL(L, narg, lua_typename(L, tag)); } +// Can be called without stack space reservation void luaL_where(lua_State* L, int level) { lua_Debug ar; @@ -75,9 +78,14 @@ void luaL_where(lua_State* L, int level) lua_pushfstring(L, "%s:%d: ", ar.short_src, ar.currentline); return; } + + if (FFlag::LuauLibWhereErrorAutoreserve) + lua_rawcheckstack(L, 1); + lua_pushliteral(L, ""); // else, no information available... } +// Can be called without stack space reservation l_noret luaL_errorL(lua_State* L, const char* fmt, ...) { va_list argp; diff --git a/VM/src/lbuflib.cpp b/VM/src/lbuflib.cpp index 178261fb..ec14eb27 100644 --- a/VM/src/lbuflib.cpp +++ b/VM/src/lbuflib.cpp @@ -247,6 +247,87 @@ static int buffer_fill(lua_State* L) return 0; } +static int buffer_readbits(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int64_t bitoffset = (int64_t)luaL_checknumber(L, 2); + int bitcount = luaL_checkinteger(L, 3); + + if (bitoffset < 0) + luaL_error(L, "buffer access out of bounds"); + + if (unsigned(bitcount) > 32) + luaL_error(L, "bit count is out of range of [0; 32]"); + + if (uint64_t(bitoffset + bitcount) > uint64_t(len) * 8) + luaL_error(L, "buffer access out of bounds"); + + unsigned startbyte = unsigned(bitoffset / 8); + unsigned endbyte = unsigned((bitoffset + bitcount + 7) / 8); + + uint64_t data = 0; + +#if defined(LUAU_BIG_ENDIAN) + for (int i = int(endbyte) - 1; i >= int(startbyte); i--) + data = (data << 8) + uint8_t(((char*)buf)[i]); +#else + memcpy(&data, (char*)buf + startbyte, endbyte - startbyte); +#endif + + uint64_t subbyteoffset = bitoffset & 0x7; + uint64_t mask = (1ull << bitcount) - 1; + + lua_pushunsigned(L, unsigned((data >> subbyteoffset) & mask)); + return 1; +} + +static int buffer_writebits(lua_State* L) +{ + size_t len = 0; + void* buf = luaL_checkbuffer(L, 1, &len); + int64_t bitoffset = (int64_t)luaL_checknumber(L, 2); + int bitcount = luaL_checkinteger(L, 3); + unsigned value = luaL_checkunsigned(L, 4); + + if (bitoffset < 0) + luaL_error(L, "buffer access out of bounds"); + + if (unsigned(bitcount) > 32) + luaL_error(L, "bit count is out of range of [0; 32]"); + + if (uint64_t(bitoffset + bitcount) > uint64_t(len) * 8) + luaL_error(L, "buffer access out of bounds"); + + unsigned startbyte = unsigned(bitoffset / 8); + unsigned endbyte = unsigned((bitoffset + bitcount + 7) / 8); + + uint64_t data = 0; + +#if defined(LUAU_BIG_ENDIAN) + for (int i = int(endbyte) - 1; i >= int(startbyte); i--) + data = data * 256 + uint8_t(((char*)buf)[i]); +#else + memcpy(&data, (char*)buf + startbyte, endbyte - startbyte); +#endif + + uint64_t subbyteoffset = bitoffset & 0x7; + uint64_t mask = ((1ull << bitcount) - 1) << subbyteoffset; + + data = (data & ~mask) | ((uint64_t(value) << subbyteoffset) & mask); + +#if defined(LUAU_BIG_ENDIAN) + for (int i = int(startbyte); i < int(endbyte); i++) + { + ((char*)buf)[i] = data & 0xff; + data >>= 8; + } +#else + memcpy((char*)buf + startbyte, &data, endbyte - startbyte); +#endif + return 0; +} + static const luaL_Reg bufferlib[] = { {"create", buffer_create}, {"fromstring", buffer_fromstring}, @@ -272,6 +353,8 @@ static const luaL_Reg bufferlib[] = { {"len", buffer_len}, {"copy", buffer_copy}, {"fill", buffer_fill}, + {"readbits", buffer_readbits}, + {"writebits", buffer_writebits}, {NULL, NULL}, }; diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index e28bb169..3fc687e1 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -998,7 +998,7 @@ static int luauF_rawset(lua_State* L, StkId res, TValue* arg0, int nresults, Stk else if (ttisvector(key) && luai_vecisnan(vvalue(key))) return -1; - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); if (t->readonly) return -1; @@ -1015,7 +1015,7 @@ static int luauF_tinsert(lua_State* L, StkId res, TValue* arg0, int nresults, St { if (nparams == 2 && nresults <= 0 && ttistable(arg0)) { - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); if (t->readonly) return -1; @@ -1032,7 +1032,7 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St { if (nparams >= 1 && nresults < 0 && ttistable(arg0)) { - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); int n = -1; if (nparams == 1) @@ -1055,23 +1055,30 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { - if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) { - double x = nvalue(arg0); - double y = nvalue(args); - double z = nvalue(args + 1); + float x = (float)nvalue(arg0); + float y = (float)nvalue(args); + float z = 0.0f; + + if (nparams >= 3) + { + if (!ttisnumber(args + 1)) + return -1; + z = (float)nvalue(args + 1); + } #if LUA_VECTOR_SIZE == 4 - double w = 0.0; + float w = 0.0f; if (nparams >= 4) { if (!ttisnumber(args + 2)) return -1; - w = nvalue(args + 2); + w = (float)nvalue(args + 2); } - setvvalue(res, float(x), float(y), float(z), float(w)); + setvvalue(res, x, y, z, w); #else - setvvalue(res, float(x), float(y), float(z), 0.0f); + setvvalue(res, x, y, z, 0.0f); #endif return 1; @@ -1160,7 +1167,7 @@ static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, Stk { if (ttistable(arg0)) { - Table* h = hvalue(arg0); + LuaTable* h = hvalue(arg0); setnvalue(res, double(luaH_getn(h))); return 1; } @@ -1204,7 +1211,7 @@ static int luauF_getmetatable(lua_State* L, StkId res, TValue* arg0, int nresult { if (nparams >= 1 && nresults <= 1) { - Table* mt = NULL; + LuaTable* mt = NULL; if (ttistable(arg0)) mt = hvalue(arg0)->metatable; else if (ttisuserdata(arg0)) @@ -1239,11 +1246,11 @@ static int luauF_setmetatable(lua_State* L, StkId res, TValue* arg0, int nresult // note: setmetatable(_, nil) is rare so we use fallback for it to optimize the fast path if (nparams >= 2 && nresults <= 1 && ttistable(arg0) && ttistable(args)) { - Table* t = hvalue(arg0); + LuaTable* t = hvalue(arg0); if (t->readonly || t->metatable != NULL) return -1; // note: overwriting non-null metatable is very rare but it requires __metatable check - Table* mt = hvalue(args); + LuaTable* mt = hvalue(args); t->metatable = mt; luaC_objbarrier(L, t, mt); @@ -1437,6 +1444,280 @@ static int luauF_writefp(lua_State* L, StkId res, TValue* arg0, int nresults, St return -1; } +static int luauF_vectormagnitude(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisvector(arg0)) + { + const float* v = vvalue(arg0); + +#if LUA_VECTOR_SIZE == 4 + setnvalue(res, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3])); +#else + setnvalue(res, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectornormalize(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisvector(arg0)) + { + const float* v = vvalue(arg0); + +#if LUA_VECTOR_SIZE == 4 + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3]); + + setvvalue(res, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt, v[3] * invSqrt); +#else + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + + setvvalue(res, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt, 0.0f); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectorcross(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisvector(arg0) && ttisvector(args)) + { + const float* a = vvalue(arg0); + const float* b = vvalue(args); + + // same for 3- and 4- wide vectors + setvvalue(res, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0], 0.0f); + return 1; + } + + return -1; +} + +static int luauF_vectordot(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisvector(arg0) && ttisvector(args)) + { + const float* a = vvalue(arg0); + const float* b = vvalue(args); + +#if LUA_VECTOR_SIZE == 4 + setnvalue(res, a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3]); +#else + setnvalue(res, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectorfloor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisvector(arg0)) + { + const float* v = vvalue(arg0); + +#if LUA_VECTOR_SIZE == 4 + setvvalue(res, floorf(v[0]), floorf(v[1]), floorf(v[2]), floorf(v[3])); +#else + setvvalue(res, floorf(v[0]), floorf(v[1]), floorf(v[2]), 0.0f); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectorceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisvector(arg0)) + { + const float* v = vvalue(arg0); + +#if LUA_VECTOR_SIZE == 4 + setvvalue(res, ceilf(v[0]), ceilf(v[1]), ceilf(v[2]), ceilf(v[3])); +#else + setvvalue(res, ceilf(v[0]), ceilf(v[1]), ceilf(v[2]), 0.0f); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectorabs(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisvector(arg0)) + { + const float* v = vvalue(arg0); + +#if LUA_VECTOR_SIZE == 4 + setvvalue(res, fabsf(v[0]), fabsf(v[1]), fabsf(v[2]), fabsf(v[3])); +#else + setvvalue(res, fabsf(v[0]), fabsf(v[1]), fabsf(v[2]), 0.0f); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectorsign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisvector(arg0)) + { + const float* v = vvalue(arg0); + +#if LUA_VECTOR_SIZE == 4 + setvvalue(res, luaui_signf(v[0]), luaui_signf(v[1]), luaui_signf(v[2]), luaui_signf(v[3])); +#else + setvvalue(res, luaui_signf(v[0]), luaui_signf(v[1]), luaui_signf(v[2]), 0.0f); +#endif + + return 1; + } + + return -1; +} + +static int luauF_vectorclamp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisvector(arg0) && ttisvector(args) && ttisvector(args + 1)) + { + const float* v = vvalue(arg0); + const float* min = vvalue(args); + const float* max = vvalue(args + 1); + + if (min[0] <= max[0] && min[1] <= max[1] && min[2] <= max[2]) + { +#if LUA_VECTOR_SIZE == 4 + setvvalue( + res, + luaui_clampf(v[0], min[0], max[0]), + luaui_clampf(v[1], min[1], max[1]), + luaui_clampf(v[2], min[2], max[2]), + luaui_clampf(v[3], min[3], max[3]) + ); +#else + setvvalue(res, luaui_clampf(v[0], min[0], max[0]), luaui_clampf(v[1], min[1], max[1]), luaui_clampf(v[2], min[2], max[2]), 0.0f); +#endif + + return 1; + } + } + + return -1; +} + +static int luauF_vectormin(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisvector(arg0) && ttisvector(args)) + { + const float* a = vvalue(arg0); + const float* b = vvalue(args); + + float result[4]; + + result[0] = (b[0] < a[0]) ? b[0] : a[0]; + result[1] = (b[1] < a[1]) ? b[1] : a[1]; + result[2] = (b[2] < a[2]) ? b[2] : a[2]; + +#if LUA_VECTOR_SIZE == 4 + result[3] = (b[3] < a[3]) ? b[3] : a[3]; +#else + result[3] = 0.0f; +#endif + + for (int i = 3; i <= nparams; ++i) + { + if (!ttisvector(args + (i - 2))) + return -1; + + const float* c = vvalue(args + (i - 2)); + + result[0] = (c[0] < result[0]) ? c[0] : result[0]; + result[1] = (c[1] < result[1]) ? c[1] : result[1]; + result[2] = (c[2] < result[2]) ? c[2] : result[2]; +#if LUA_VECTOR_SIZE == 4 + result[3] = (c[3] < result[3]) ? c[3] : result[3]; +#endif + } + + setvvalue(res, result[0], result[1], result[2], result[3]); + return 1; + } + + return -1; +} + +static int luauF_vectormax(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisvector(arg0) && ttisvector(args)) + { + const float* a = vvalue(arg0); + const float* b = vvalue(args); + + float result[4]; + + result[0] = (b[0] > a[0]) ? b[0] : a[0]; + result[1] = (b[1] > a[1]) ? b[1] : a[1]; + result[2] = (b[2] > a[2]) ? b[2] : a[2]; + +#if LUA_VECTOR_SIZE == 4 + result[3] = (b[3] > a[3]) ? b[3] : a[3]; +#else + result[3] = 0.0f; +#endif + + for (int i = 3; i <= nparams; ++i) + { + if (!ttisvector(args + (i - 2))) + return -1; + + const float* c = vvalue(args + (i - 2)); + + result[0] = (c[0] > result[0]) ? c[0] : result[0]; + result[1] = (c[1] > result[1]) ? c[1] : result[1]; + result[2] = (c[2] > result[2]) ? c[2] : result[2]; +#if LUA_VECTOR_SIZE == 4 + result[3] = (c[3] > result[3]) ? c[3] : result[3]; +#endif + } + + setvvalue(res, result[0], result[1], result[2], result[3]); + return 1; + } + + return -1; +} + +static int luauF_lerp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double a = nvalue(arg0); + double b = nvalue(args); + double t = nvalue(args + 1); + + double r = (t == 1.0) ? b : a + (b - a) * t; + + setnvalue(res, r); + return 1; + } + + return -1; +} + static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { return -1; @@ -1620,6 +1901,20 @@ const luau_FastFunction luauF_table[256] = { luauF_readfp, luauF_writefp, + luauF_vectormagnitude, + luauF_vectornormalize, + luauF_vectorcross, + luauF_vectordot, + luauF_vectorfloor, + luauF_vectorceil, + luauF_vectorabs, + luauF_vectorsign, + luauF_vectorclamp, + luauF_vectormin, + luauF_vectormax, + + luauF_lerp, + // When adding builtins, add them above this line; what follows is 64 "dummy" entries with luauF_missing fallback. // This is important so that older versions of the runtime that don't support newer builtins automatically fall back via luauF_missing. // Given the builtin addition velocity this should always provide a larger compatibility window than bytecode versions suggest. diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 3d39a2de..5a372aec 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -2,9 +2,12 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "ldebug.h" #include "lstate.h" #include "lvm.h" +LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) + #define CO_STATUS_ERROR -1 #define CO_STATUS_BREAK -2 @@ -37,6 +40,12 @@ static int auxresume(lua_State* L, lua_State* co, int narg) luaL_error(L, "too many arguments to resume"); lua_xmove(L, co, narg); } + else + { + // coroutine might be completely full already + if ((co->top - co->base) > LUAI_MAXCSTACK) + luaL_error(L, "too many arguments to resume"); + } co->singlestep = L->singlestep; @@ -227,8 +236,22 @@ static int coclose(lua_State* L) else { lua_pushboolean(L, false); - if (lua_gettop(co)) - lua_xmove(co, L, 1); // move error message + + if (DFFlag::LuauStackLimit) + { + if (co->status == LUA_ERRMEM) + lua_pushstring(L, LUA_MEMERRMSG); + else if (co->status == LUA_ERRERR) + lua_pushstring(L, LUA_ERRERRMSG); + else if (lua_gettop(co)) + lua_xmove(co, L, 1); // move error message + } + else + { + if (lua_gettop(co)) + lua_xmove(co, L, 1); // move error message + } + lua_resetthread(co); return 2; } diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index dfc61e4d..ff9fdd76 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -107,6 +107,10 @@ static int db_info(lua_State* L) break; default: + // restore stack state of another thread as 'f' option might not have been visited yet + if (L != L1) + lua_settop(L1, l1top); + luaL_argerror(L, arg + 2, "invalid option"); } } diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 07cc117e..44da57c2 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -422,6 +422,20 @@ int luaG_isnative(lua_State* L, int level) return (ci->flags & LUA_CALLINFO_NATIVE) != 0 ? 1 : 0; } +int luaG_hasnative(lua_State* L, int level) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + + Proto* proto = getluaproto(ci); + if (proto == nullptr) + return 0; + + return (proto->execdata != nullptr); +} + void lua_singlestep(lua_State* L, int enabled) { L->singlestep = bool(enabled); diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h index 49b1ca88..f215e815 100644 --- a/VM/src/ldebug.h +++ b/VM/src/ldebug.h @@ -31,3 +31,4 @@ LUAI_FUNC bool luaG_onbreak(lua_State* L); LUAI_FUNC int luaG_getline(Proto* p, int pc); LUAI_FUNC int luaG_isnative(lua_State* L, int level); +LUAI_FUNC int luaG_hasnative(lua_State* L, int level); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 0cffec40..f9fe30d6 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,7 +17,11 @@ #include -LUAU_FASTFLAGVARIABLE(LuauErrorResumeCleanupArgs, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStackLimit, false) +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauPopIncompleteCi, false) + +// keep max stack allocation request under 1GB +#define MAX_STACK_SIZE (int(1024 / sizeof(TValue)) * 1024 * 1024) /* ** {====================================================== @@ -176,8 +180,24 @@ static void correctstack(lua_State* L, TValue* oldstack) L->base = (L->base - oldstack) + L->stack; } -void luaD_reallocstack(lua_State* L, int newsize) +void luaD_reallocstack(lua_State* L, int newsize, int fornewci) { + // throw 'out of memory' error because space for a custom error message cannot be guaranteed here + if (DFFlag::LuauStackLimit && newsize > MAX_STACK_SIZE) + { + // reallocation was performaed to setup a new CallInfo frame, which we have to remove + if (DFFlag::LuauPopIncompleteCi && fornewci) + { + CallInfo* cip = L->ci - 1; + + L->ci = cip; + L->base = cip->base; + L->top = cip->top; + } + + luaD_throw(L, LUA_ERRMEM); + } + TValue* oldstack = L->stack; int realsize = newsize + EXTRA_STACK; LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); @@ -201,10 +221,17 @@ void luaD_reallocCI(lua_State* L, int newsize) void luaD_growstack(lua_State* L, int n) { - if (n <= L->stacksize) // double size is enough? - luaD_reallocstack(L, 2 * L->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_reallocstack(L, getgrownstacksize(L, n), 0); + } else - luaD_reallocstack(L, L->stacksize + n); + { + if (n <= L->stacksize) // double size is enough? + luaD_reallocstack(L, 2 * L->stacksize, 0); + else + luaD_reallocstack(L, L->stacksize + n, 0); + } } CallInfo* luaD_growCI(lua_State* L) @@ -430,11 +457,7 @@ static void resume_handle(lua_State* L, void* ud) static int resume_error(lua_State* L, const char* msg, int narg) { - if (FFlag::LuauErrorResumeCleanupArgs) - L->top -= narg; - else - L->top = L->ci->base; - + L->top -= narg; setsvalue(L, L->top, luaS_new(L, msg)); incr_top(L); return LUA_ERRRUN; diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 0f7b42ad..707af0ee 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -7,11 +7,21 @@ #include "luaconf.h" #include "ldebug.h" +// returns target stack for 'n' extra elements to reallocate +// if possible, stack size growth factor is 2x +#define getgrownstacksize(L, n) ((n) <= L->stacksize ? 2 * L->stacksize : L->stacksize + (n)) + +#define luaD_checkstackfornewci(L, n) \ + if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ + luaD_reallocstack(L, getgrownstacksize(L, (n)), 1); \ + else \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK, 1)); + #define luaD_checkstack(L, n) \ if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ luaD_growstack(L, n); \ else \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK, 0)); #define incr_top(L) \ { \ @@ -47,7 +57,7 @@ LUAI_FUNC CallInfo* luaD_growCI(lua_State* L); LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nresults); LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, ptrdiff_t ef); LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize); -LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize); +LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize, int fornewci); LUAI_FUNC void luaD_growstack(lua_State* L, int n); LUAI_FUNC void luaD_checkCstack(lua_State* L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 2a1e45c4..b172d0ad 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -55,7 +55,7 @@ Proto* luaF_newproto(lua_State* L) return f; } -Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) +Closure* luaF_newLclosure(lua_State* L, int nelems, LuaTable* e, Proto* p) { Closure* c = luaM_newgco(L, Closure, sizeLclosure(nelems), L->activememcat); luaC_init(L, c, LUA_TFUNCTION); @@ -70,7 +70,7 @@ Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) return c; } -Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) +Closure* luaF_newCclosure(lua_State* L, int nelems, LuaTable* e) { Closure* c = luaM_newgco(L, Closure, sizeCclosure(nelems), L->activememcat); luaC_init(L, c, LUA_TFUNCTION); diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 679836e7..453cf581 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -8,8 +8,8 @@ #define sizeLclosure(n) (offsetof(Closure, l.uprefs) + sizeof(TValue) * (n)) LUAI_FUNC Proto* luaF_newproto(lua_State* L); -LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p); -LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e); +LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, LuaTable* e, Proto* p); +LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, LuaTable* e); LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); LUAI_FUNC void luaF_closeupval(lua_State* L, UpVal* uv, bool dead); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 4473f04f..c5e16e43 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -244,7 +244,7 @@ static void reallymarkobject(global_State* g, GCObject* o) } case LUA_TUSERDATA: { - Table* mt = gco2u(o)->metatable; + LuaTable* mt = gco2u(o)->metatable; gray2black(o); // udata are never gray if (mt) markobject(g, mt); @@ -292,7 +292,7 @@ static void reallymarkobject(global_State* g, GCObject* o) } } -static const char* gettablemode(global_State* g, Table* h) +static const char* gettablemode(global_State* g, LuaTable* h) { const TValue* mode = gfasttm(g, h->metatable, TM_MODE); @@ -302,13 +302,13 @@ static const char* gettablemode(global_State* g, Table* h) return NULL; } -static int traversetable(global_State* g, Table* h) +static int traversetable(global_State* g, LuaTable* h) { int i; int weakkey = 0; int weakvalue = 0; if (h->metatable) - markobject(g, cast_to(Table*, h->metatable)); + markobject(g, cast_to(LuaTable*, h->metatable)); // is there a weak mode? if (const char* modev = gettablemode(g, h)) @@ -436,11 +436,13 @@ static void shrinkstack(lua_State* L) int s_used = cast_int(lim - L->stack); // part of stack in use if (L->size_ci > LUAI_MAXCALLS) // handling overflow? return; // do not touch the stacks - if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) + + if (3 * size_t(ci_used) < size_t(L->size_ci) && 2 * BASIC_CI_SIZE < L->size_ci) luaD_reallocCI(L, L->size_ci / 2); // still big enough... condhardstacktests(luaD_reallocCI(L, ci_used + 1)); - if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) - luaD_reallocstack(L, L->stacksize / 2); // still big enough... + + if (3 * size_t(s_used) < size_t(L->stacksize) && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2, 0); // still big enough... condhardstacktests(luaD_reallocstack(L, s_used)); } @@ -457,11 +459,11 @@ static size_t propagatemark(global_State* g) { case LUA_TTABLE: { - Table* h = gco2h(o); + LuaTable* h = gco2h(o); g->gray = h->gclist; if (traversetable(g, h)) // table is weak? black2gray(o); // keep it gray - return sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + return sizeof(LuaTable) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); } case LUA_TFUNCTION: { @@ -551,8 +553,8 @@ static size_t cleartable(lua_State* L, GCObject* l) size_t work = 0; while (l) { - Table* h = gco2h(l); - work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + LuaTable* h = gco2h(l); + work += sizeof(LuaTable) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); int i = h->sizearray; while (i--) @@ -1153,7 +1155,7 @@ void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) makewhite(g, o); // mark as white just to avoid other barriers } -void luaC_barriertable(lua_State* L, Table* t, GCObject* v) +void luaC_barriertable(lua_State* L, LuaTable* t, GCObject* v) { global_State* g = L->global; GCObject* o = obj2gco(t); diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 722de9d1..683542b6 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -131,7 +131,7 @@ LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); LUAI_FUNC void luaC_upvalclosed(lua_State* L, UpVal* uv); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); -LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); +LUAI_FUNC void luaC_barriertable(lua_State* L, LuaTable* t, GCObject* v); LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 768561cb..7a47ab86 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -34,7 +34,7 @@ static void validateref(global_State* g, GCObject* f, TValue* v) } } -static void validatetable(global_State* g, Table* h) +static void validatetable(global_State* g, LuaTable* h) { int sizenode = 1 << h->lsizenode; @@ -290,9 +290,9 @@ static void dumpstring(FILE* f, TString* ts) fprintf(f, "\"}"); } -static void dumptable(FILE* f, Table* h) +static void dumptable(FILE* f, LuaTable* h) { - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + size_t size = sizeof(LuaTable) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); @@ -654,9 +654,9 @@ static void enumstring(EnumContext* ctx, TString* ts) enumnode(ctx, obj2gco(ts), ts->len, NULL); } -static void enumtable(EnumContext* ctx, Table* h) +static void enumtable(EnumContext* ctx, LuaTable* h) { - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + size_t size = sizeof(LuaTable) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); // Provide a name for a special registry table enumnode(ctx, obj2gco(h), size, h == hvalue(registry(ctx->L)) ? "registry" : NULL); @@ -754,7 +754,7 @@ static void enumudata(EnumContext* ctx, Udata* u) { const char* name = NULL; - if (Table* h = u->metatable) + if (LuaTable* h = u->metatable) { if (h->node != &luaH_dummynode) { diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index aad6513f..efcf1904 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -15,6 +15,7 @@ static const luaL_Reg lualibs[] = { {LUA_UTF8LIBNAME, luaopen_utf8}, {LUA_BITLIBNAME, luaopen_bit32}, {LUA_BUFFERLIBNAME, luaopen_buffer}, + {LUA_VECLIBNAME, luaopen_vector}, {NULL, NULL}, }; diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 7adaf0b4..546725ca 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -403,6 +403,30 @@ static int math_round(lua_State* L) return 1; } +static int math_map(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double inmin = luaL_checknumber(L, 2); + double inmax = luaL_checknumber(L, 3); + double outmin = luaL_checknumber(L, 4); + double outmax = luaL_checknumber(L, 5); + + double result = outmin + (x - inmin) * (outmax - outmin) / (inmax - inmin); + lua_pushnumber(L, result); + return 1; +} + +static int math_lerp(lua_State* L) +{ + double a = luaL_checknumber(L, 1); + double b = luaL_checknumber(L, 2); + double t = luaL_checknumber(L, 3); + + double r = (t == 1.0) ? b : a + (b - a) * t; + lua_pushnumber(L, r); + return 1; +} + static const luaL_Reg mathlib[] = { {"abs", math_abs}, {"acos", math_acos}, @@ -436,6 +460,8 @@ static const luaL_Reg mathlib[] = { {"clamp", math_clamp}, {"sign", math_sign}, {"round", math_round}, + {"map", math_map}, + {"lerp", math_lerp}, {NULL, NULL}, }; @@ -455,5 +481,6 @@ int luaopen_math(lua_State* L) lua_setfield(L, -2, "pi"); lua_pushnumber(L, HUGE_VAL); lua_setfield(L, -2, "huge"); + return 1; } diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 5ff5de72..0738840b 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -121,7 +121,7 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for userdata header"); -static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); +static_assert(sizeof(LuaTable) == ABISWITCH(48, 32, 32), "size mismatch for table header"); static_assert(offsetof(Buffer, data) == ABISWITCH(8, 8, 8), "size mismatch for buffer header"); const size_t kSizeClasses = LUA_SIZECLASSES; @@ -192,7 +192,7 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; // size class for a block of size sz; returns -1 for size=0 because empty allocations take no space -#define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSizeUsed ? kSizeClassConfig.classForSize[sz] : -1) +#define sizeclass(sz) (size_t((sz) - 1) < kMaxSmallSizeUsed ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) @@ -504,6 +504,11 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) g->totalbytes += nsize; g->memcatbytes[memcat] += nsize; + if (LUAU_UNLIKELY(!!g->cb.onallocate)) + { + g->cb.onallocate(L, 0, nsize); + } + return block; } @@ -539,6 +544,11 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) g->totalbytes += nsize; g->memcatbytes[memcat] += nsize; + if (LUAU_UNLIKELY(!!g->cb.onallocate)) + { + g->cb.onallocate(L, 0, nsize); + } + return (GCObject*)block; } @@ -618,6 +628,12 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 LUAU_ASSERT((nsize == 0) == (result == NULL)); g->totalbytes = (g->totalbytes - osize) + nsize; g->memcatbytes[memcat] += nsize - osize; + + if (LUAU_UNLIKELY(!!g->cb.onallocate)) + { + g->cb.onallocate(L, osize, nsize); + } + return result; } diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 38bfb322..de56bb09 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -33,6 +33,17 @@ inline bool luai_vecisnan(const float* a) #endif } +inline float luaui_signf(float v) +{ + return v > 0.0f ? 1.0f : v < 0.0f ? -1.0f : 0.0f; +} + +inline float luaui_clampf(float v, float min, float max) +{ + float r = v < min ? min : v; + return r > max ? max : r; +} + LUAU_FASTMATH_BEGIN inline double luai_nummod(double a, double b) { diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index f685b235..e4202d70 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -116,6 +116,13 @@ const char* luaO_pushfstring(lua_State* L, const char* fmt, ...) return msg; } +// Possible chunkname prefixes: +// +// '=' prefix: meant to represent custom chunknames. When truncation is needed, +// the beginning of the chunkname is kept. +// +// '@' prefix: meant to represent filepaths. When truncation is needed, the end +// of the filepath is kept, as this is more useful for identifying the file. const char* luaO_chunkid(char* buf, size_t buflen, const char* source, size_t srclen) { if (*source == '=') diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 18c69641..6719faaf 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -263,7 +263,7 @@ typedef struct Udata int len; - struct Table* metatable; + struct LuaTable* metatable; union { @@ -390,7 +390,7 @@ typedef struct Closure uint8_t preload; GCObject* gclist; - struct Table* env; + struct LuaTable* env; union { @@ -454,7 +454,7 @@ typedef struct LuaNode } // clang-format off -typedef struct Table +typedef struct LuaTable { CommonHeader; @@ -473,11 +473,11 @@ typedef struct Table }; - struct Table* metatable; + struct LuaTable* metatable; TValue* array; // array part LuaNode* node; GCObject* gclist; -} Table; +} LuaTable; // clang-format on /* diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 6b7a9aa0..ddb1e12e 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -149,7 +149,7 @@ void lua_resetthread(lua_State* L) L->nCcalls = L->baseCcalls = 0; // clear thread stack if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) - luaD_reallocstack(L, BASIC_STACK_SIZE); + luaD_reallocstack(L, BASIC_STACK_SIZE, 0); for (int i = 0; i < L->stacksize; i++) setnilvalue(L->stack + i); } diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 3f4f9425..ad162391 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -198,7 +198,7 @@ typedef struct global_State struct lua_State* mainthread; UpVal uvhead; // head of double-linked list of all open upvalues - struct Table* mt[LUA_T_COUNT]; // metatables for basic types + struct LuaTable* mt[LUA_T_COUNT]; // metatables for basic types TString* ttname[LUA_T_COUNT]; // names for basic types TString* tmname[TM_N]; // array with tag-method names @@ -217,7 +217,7 @@ typedef struct global_State lua_ExecutionCallbacks ecb; void (*udatagc[LUA_UTAG_LIMIT])(lua_State*, void*); // for each userdata tag, a gc callback to be called immediately before freeing memory - Table* udatamt[LUA_UTAG_LIMIT]; // metatables for tagged userdata + LuaTable* udatamt[LUA_UTAG_LIMIT]; // metatables for tagged userdata TString* lightuserdataname[LUA_LUTAG_LIMIT]; // names for tagged lightuserdata @@ -266,7 +266,7 @@ struct lua_State int cachedslot; // when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? - Table* gt; // table of globals + LuaTable* gt; // table of globals UpVal* openupval; // list of open upvalues in this stack GCObject* gclist; @@ -285,7 +285,7 @@ union GCObject struct TString ts; struct Udata u; struct Closure cl; - struct Table h; + struct LuaTable h; struct Proto p; struct UpVal uv; struct lua_State th; // thread diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index b5a4bd13..5c9402f9 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauStringFormatFixC, false) + // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -999,8 +1001,17 @@ static int str_format(lua_State* L) { case 'c': { - snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); - break; + if (DFFlag::LuauStringFormatFixC) + { + int count = snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); + luaL_addlstring(&b, buff, count); + continue; // skip the 'luaL_addlstring' at the end + } + else + { + snprintf(buff, sizeof(buff), form, (int)luaL_checknumber(L, arg)); + break; + } } case 'd': case 'i': diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index dafb2b3f..ee5ae7ec 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -58,7 +58,7 @@ const LuaNode luaH_dummynode = { #define hashstr(t, str) hashpow2(t, (str)->hash) #define hashboolean(t, p) hashpow2(t, p) -static LuaNode* hashpointer(const Table* t, const void* p) +static LuaNode* hashpointer(const LuaTable* t, const void* p) { // we discard the high 32-bit portion of the pointer on 64-bit platforms as it doesn't carry much entropy anyway unsigned int h = unsigned(uintptr_t(p)); @@ -73,7 +73,7 @@ static LuaNode* hashpointer(const Table* t, const void* p) return hashpow2(t, h); } -static LuaNode* hashnum(const Table* t, double n) +static LuaNode* hashnum(const LuaTable* t, double n) { static_assert(sizeof(double) == sizeof(unsigned int) * 2, "expected a 8-byte double"); unsigned int i[2]; @@ -99,7 +99,7 @@ static LuaNode* hashnum(const Table* t, double n) return hashpow2(t, h2); } -static LuaNode* hashvec(const Table* t, const float* v) +static LuaNode* hashvec(const LuaTable* t, const float* v) { unsigned int i[LUA_VECTOR_SIZE]; memcpy(i, v, sizeof(i)); @@ -130,7 +130,7 @@ static LuaNode* hashvec(const Table* t, const float* v) ** returns the `main' position of an element in a table (that is, the index ** of its hash value) */ -static LuaNode* mainposition(const Table* t, const TValue* key) +static LuaNode* mainposition(const LuaTable* t, const TValue* key) { switch (ttype(key)) { @@ -166,7 +166,7 @@ static int arrayindex(double key) ** elements in the array part, then elements in the hash part. The ** beginning of a traversal is signalled by -1. */ -static int findindex(lua_State* L, Table* t, StkId key) +static int findindex(lua_State* L, LuaTable* t, StkId key) { int i; if (ttisnil(key)) @@ -194,7 +194,7 @@ static int findindex(lua_State* L, Table* t, StkId key) } } -int luaH_next(lua_State* L, Table* t, StkId key) +int luaH_next(lua_State* L, LuaTable* t, StkId key) { int i = findindex(L, t, key); // find original element for (i++; i < t->sizearray; i++) @@ -270,7 +270,7 @@ static int countint(double key, int* nums) return 0; } -static int numusearray(const Table* t, int* nums) +static int numusearray(const LuaTable* t, int* nums) { int lg; int ttlg; // 2^lg @@ -298,7 +298,7 @@ static int numusearray(const Table* t, int* nums) return ause; } -static int numusehash(const Table* t, int* nums, int* pnasize) +static int numusehash(const LuaTable* t, int* nums, int* pnasize) { int totaluse = 0; // total number of elements int ause = 0; // summation of `nums' @@ -317,7 +317,7 @@ static int numusehash(const Table* t, int* nums, int* pnasize) return totaluse; } -static void setarrayvector(lua_State* L, Table* t, int size) +static void setarrayvector(lua_State* L, LuaTable* t, int size) { if (size > MAXSIZE) luaG_runerror(L, "table overflow"); @@ -328,7 +328,7 @@ static void setarrayvector(lua_State* L, Table* t, int size) t->sizearray = size; } -static void setnodevector(lua_State* L, Table* t, int size) +static void setnodevector(lua_State* L, LuaTable* t, int size) { int lsize; if (size == 0) @@ -357,9 +357,9 @@ static void setnodevector(lua_State* L, Table* t, int size) t->lastfree = size; // all positions are free } -static TValue* newkey(lua_State* L, Table* t, const TValue* key); +static TValue* newkey(lua_State* L, LuaTable* t, const TValue* key); -static TValue* arrayornewkey(lua_State* L, Table* t, const TValue* key) +static TValue* arrayornewkey(lua_State* L, LuaTable* t, const TValue* key) { if (ttisnumber(key)) { @@ -373,7 +373,7 @@ static TValue* arrayornewkey(lua_State* L, Table* t, const TValue* key) return newkey(L, t, key); } -static void resize(lua_State* L, Table* t, int nasize, int nhsize) +static void resize(lua_State* L, LuaTable* t, int nasize, int nhsize) { if (nasize > MAXSIZE || nhsize > MAXSIZE) luaG_runerror(L, "table overflow"); @@ -424,7 +424,7 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); // free old array } -static int adjustasize(Table* t, int size, const TValue* ek) +static int adjustasize(LuaTable* t, int size, const TValue* ek) { bool tbound = t->node != dummynode || size < t->sizearray; int ekindex = ek && ttisnumber(ek) ? arrayindex(nvalue(ek)) : -1; @@ -434,19 +434,19 @@ static int adjustasize(Table* t, int size, const TValue* ek) return size; } -void luaH_resizearray(lua_State* L, Table* t, int nasize) +void luaH_resizearray(lua_State* L, LuaTable* t, int nasize) { int nsize = (t->node == dummynode) ? 0 : sizenode(t); int asize = adjustasize(t, nasize, NULL); resize(L, t, asize, nsize); } -void luaH_resizehash(lua_State* L, Table* t, int nhsize) +void luaH_resizehash(lua_State* L, LuaTable* t, int nhsize) { resize(L, t, t->sizearray, nhsize); } -static void rehash(lua_State* L, Table* t, const TValue* ek) +static void rehash(lua_State* L, LuaTable* t, const TValue* ek) { int nums[MAXBITS + 1]; // nums[i] = number of keys between 2^(i-1) and 2^i for (int i = 0; i <= MAXBITS; i++) @@ -491,9 +491,9 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) ** }============================================================= */ -Table* luaH_new(lua_State* L, int narray, int nhash) +LuaTable* luaH_new(lua_State* L, int narray, int nhash) { - Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); + LuaTable* t = luaM_newgco(L, LuaTable, sizeof(LuaTable), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; t->tmcache = cast_byte(~0); @@ -512,16 +512,16 @@ Table* luaH_new(lua_State* L, int narray, int nhash) return t; } -void luaH_free(lua_State* L, Table* t, lua_Page* page) +void luaH_free(lua_State* L, LuaTable* t, lua_Page* page) { if (t->node != dummynode) luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); if (t->array) luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); - luaM_freegco(L, t, sizeof(Table), t->memcat, page); + luaM_freegco(L, t, sizeof(LuaTable), t->memcat, page); } -static LuaNode* getfreepos(Table* t) +static LuaNode* getfreepos(LuaTable* t) { while (t->lastfree > 0) { @@ -541,7 +541,7 @@ static LuaNode* getfreepos(Table* t) ** put new key in its main position; otherwise (colliding node is in its main ** position), new key goes to an empty position. */ -static TValue* newkey(lua_State* L, Table* t, const TValue* key) +static TValue* newkey(lua_State* L, LuaTable* t, const TValue* key) { // enforce boundary invariant if (ttisnumber(key) && nvalue(key) == t->sizearray + 1) @@ -601,7 +601,7 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) /* ** search function for integers */ -const TValue* luaH_getnum(Table* t, int key) +const TValue* luaH_getnum(LuaTable* t, int key) { // (1 <= key && key <= t->sizearray) if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) @@ -627,7 +627,7 @@ const TValue* luaH_getnum(Table* t, int key) /* ** search function for strings */ -const TValue* luaH_getstr(Table* t, TString* key) +const TValue* luaH_getstr(LuaTable* t, TString* key) { LuaNode* n = hashstr(t, key); for (;;) @@ -644,7 +644,7 @@ const TValue* luaH_getstr(Table* t, TString* key) /* ** main search function */ -const TValue* luaH_get(Table* t, const TValue* key) +const TValue* luaH_get(LuaTable* t, const TValue* key) { switch (ttype(key)) { @@ -677,7 +677,7 @@ const TValue* luaH_get(Table* t, const TValue* key) } } -TValue* luaH_set(lua_State* L, Table* t, const TValue* key) +TValue* luaH_set(lua_State* L, LuaTable* t, const TValue* key) { const TValue* p = luaH_get(t, key); invalidateTMcache(t); @@ -687,7 +687,7 @@ TValue* luaH_set(lua_State* L, Table* t, const TValue* key) return luaH_newkey(L, t, key); } -TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key) +TValue* luaH_newkey(lua_State* L, LuaTable* t, const TValue* key) { if (ttisnil(key)) luaG_runerror(L, "table index is nil"); @@ -698,7 +698,7 @@ TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key) return newkey(L, t, key); } -TValue* luaH_setnum(lua_State* L, Table* t, int key) +TValue* luaH_setnum(lua_State* L, LuaTable* t, int key) { // (1 <= key && key <= t->sizearray) if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) @@ -715,7 +715,7 @@ TValue* luaH_setnum(lua_State* L, Table* t, int key) } } -TValue* luaH_setstr(lua_State* L, Table* t, TString* key) +TValue* luaH_setstr(lua_State* L, LuaTable* t, TString* key) { const TValue* p = luaH_getstr(t, key); invalidateTMcache(t); @@ -729,7 +729,7 @@ TValue* luaH_setstr(lua_State* L, Table* t, TString* key) } } -static int updateaboundary(Table* t, int boundary) +static int updateaboundary(LuaTable* t, int boundary) { if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) { @@ -752,7 +752,7 @@ static int updateaboundary(Table* t, int boundary) ** Try to find a boundary in table `t'. A `boundary' is an integer index ** such that t[i] is non-nil and t[i+1] is nil (and 0 if t[1] is nil). */ -int luaH_getn(Table* t) +int luaH_getn(LuaTable* t) { int boundary = getaboundary(t); @@ -793,9 +793,9 @@ int luaH_getn(Table* t) } } -Table* luaH_clone(lua_State* L, Table* tt) +LuaTable* luaH_clone(lua_State* L, LuaTable* tt) { - Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); + LuaTable* t = luaM_newgco(L, LuaTable, sizeof(LuaTable), L->activememcat); luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; t->tmcache = tt->tmcache; @@ -830,7 +830,7 @@ Table* luaH_clone(lua_State* L, Table* tt) return t; } -void luaH_clear(Table* tt) +void luaH_clear(LuaTable* tt) { // clear array part for (int i = 0; i < tt->sizearray; ++i) diff --git a/VM/src/ltable.h b/VM/src/ltable.h index 021f21bf..50d1e643 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -14,21 +14,21 @@ // reset cache of absent metamethods, cache is updated in luaT_gettm #define invalidateTMcache(t) t->tmcache = 0 -LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); -LUAI_FUNC TValue* luaH_setnum(lua_State* L, Table* t, int key); -LUAI_FUNC const TValue* luaH_getstr(Table* t, TString* key); -LUAI_FUNC TValue* luaH_setstr(lua_State* L, Table* t, TString* key); -LUAI_FUNC const TValue* luaH_get(Table* t, const TValue* key); -LUAI_FUNC TValue* luaH_set(lua_State* L, Table* t, const TValue* key); -LUAI_FUNC TValue* luaH_newkey(lua_State* L, Table* t, const TValue* key); -LUAI_FUNC Table* luaH_new(lua_State* L, int narray, int lnhash); -LUAI_FUNC void luaH_resizearray(lua_State* L, Table* t, int nasize); -LUAI_FUNC void luaH_resizehash(lua_State* L, Table* t, int nhsize); -LUAI_FUNC void luaH_free(lua_State* L, Table* t, struct lua_Page* page); -LUAI_FUNC int luaH_next(lua_State* L, Table* t, StkId key); -LUAI_FUNC int luaH_getn(Table* t); -LUAI_FUNC Table* luaH_clone(lua_State* L, Table* tt); -LUAI_FUNC void luaH_clear(Table* tt); +LUAI_FUNC const TValue* luaH_getnum(LuaTable* t, int key); +LUAI_FUNC TValue* luaH_setnum(lua_State* L, LuaTable* t, int key); +LUAI_FUNC const TValue* luaH_getstr(LuaTable* t, TString* key); +LUAI_FUNC TValue* luaH_setstr(lua_State* L, LuaTable* t, TString* key); +LUAI_FUNC const TValue* luaH_get(LuaTable* t, const TValue* key); +LUAI_FUNC TValue* luaH_set(lua_State* L, LuaTable* t, const TValue* key); +LUAI_FUNC TValue* luaH_newkey(lua_State* L, LuaTable* t, const TValue* key); +LUAI_FUNC LuaTable* luaH_new(lua_State* L, int narray, int lnhash); +LUAI_FUNC void luaH_resizearray(lua_State* L, LuaTable* t, int nasize); +LUAI_FUNC void luaH_resizehash(lua_State* L, LuaTable* t, int nhsize); +LUAI_FUNC void luaH_free(lua_State* L, LuaTable* t, struct lua_Page* page); +LUAI_FUNC int luaH_next(lua_State* L, LuaTable* t, StkId key); +LUAI_FUNC int luaH_getn(LuaTable* t); +LUAI_FUNC LuaTable* luaH_clone(lua_State* L, LuaTable* tt); +LUAI_FUNC void luaH_clear(LuaTable* tt); #define luaH_setslot(L, t, slot, key) (invalidateTMcache(t), (slot == luaO_nilobject ? luaH_newkey(L, t, key) : cast_to(TValue*, slot))) diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 75d9f400..dbe60e4e 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -53,7 +53,7 @@ static int maxn(lua_State* L) double max = 0; luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); for (int i = 0; i < t->sizearray; i++) { @@ -87,8 +87,8 @@ static int getn(lua_State* L) static void moveelements(lua_State* L, int srct, int dstt, int f, int e, int t) { - Table* src = hvalue(L->base + (srct - 1)); - Table* dst = hvalue(L->base + (dstt - 1)); + LuaTable* src = hvalue(L->base + (srct - 1)); + LuaTable* dst = hvalue(L->base + (dstt - 1)); if (dst->readonly) luaG_readonlyerror(L); @@ -213,7 +213,7 @@ static int tmove(lua_State* L) int n = e - f + 1; // number of elements to move luaL_argcheck(L, t <= INT_MAX - n + 1, 4, "destination wrap around"); - Table* dst = hvalue(L->base + (tt - 1)); + LuaTable* dst = hvalue(L->base + (tt - 1)); if (dst->readonly) // also checked in moveelements, but this blocks resizes of r/o tables luaG_readonlyerror(L); @@ -229,7 +229,7 @@ static int tmove(lua_State* L) return 1; } -static void addfield(lua_State* L, luaL_Strbuf* b, int i, Table* t) +static void addfield(lua_State* L, luaL_Strbuf* b, int i, LuaTable* t) { if (t && unsigned(i - 1) < unsigned(t->sizearray) && ttisstring(&t->array[i - 1])) { @@ -253,7 +253,7 @@ static int tconcat(lua_State* L) int i = luaL_optinteger(L, 3, 1); int last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); luaL_Strbuf b; luaL_buffinit(L, &b); @@ -274,7 +274,7 @@ static int tpack(lua_State* L) int n = lua_gettop(L); // number of elements to pack lua_createtable(L, n, 1); // create result table - Table* t = hvalue(L->top - 1); + LuaTable* t = hvalue(L->top - 1); for (int i = 0; i < n; ++i) { @@ -292,7 +292,7 @@ static int tpack(lua_State* L) static int tunpack(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); int i = luaL_optinteger(L, 2, 1); int e = luaL_opt(L, luaL_checkinteger, 3, lua_objlen(L, 1)); @@ -335,7 +335,7 @@ static int sort_func(lua_State* L, const TValue* l, const TValue* r) return !l_isfalse(L->top); } -inline void sort_swap(lua_State* L, Table* t, int i, int j) +inline void sort_swap(lua_State* L, LuaTable* t, int i, int j) { TValue* arr = t->array; int n = t->sizearray; @@ -348,7 +348,7 @@ inline void sort_swap(lua_State* L, Table* t, int i, int j) setobj2t(L, &arr[j], &temp); } -inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) +inline int sort_less(lua_State* L, LuaTable* t, int i, int j, SortPredicate pred) { TValue* arr = t->array; int n = t->sizearray; @@ -363,7 +363,7 @@ inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) return res; } -static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pred, int root) +static void sort_siftheap(lua_State* L, LuaTable* t, int l, int u, SortPredicate pred, int root) { LUAU_ASSERT(l <= u); int count = u - l + 1; @@ -389,7 +389,7 @@ static void sort_siftheap(lua_State* L, Table* t, int l, int u, SortPredicate pr sort_swap(L, t, l + root, l + lastleft); } -static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) +static void sort_heap(lua_State* L, LuaTable* t, int l, int u, SortPredicate pred) { LUAU_ASSERT(l <= u); int count = u - l + 1; @@ -404,7 +404,7 @@ static void sort_heap(lua_State* L, Table* t, int l, int u, SortPredicate pred) } } -static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredicate pred) +static void sort_rec(lua_State* L, LuaTable* t, int l, int u, int limit, SortPredicate pred) { // sort range [l..u] (inclusive, 0-based) while (l < u) @@ -477,7 +477,7 @@ static void sort_rec(lua_State* L, Table* t, int l, int u, int limit, SortPredic static int tsort(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); int n = luaH_getn(t); if (t->readonly) luaG_readonlyerror(L); @@ -504,7 +504,7 @@ static int tcreate(lua_State* L) if (!lua_isnoneornil(L, 2)) { lua_createtable(L, size, 0); - Table* t = hvalue(L->top - 1); + LuaTable* t = hvalue(L->top - 1); StkId v = L->base + 1; @@ -530,7 +530,7 @@ static int tfind(lua_State* L) if (init < 1) luaL_argerror(L, 3, "index out of range"); - Table* t = hvalue(L->base); + LuaTable* t = hvalue(L->base); StkId v = L->base + 1; for (int i = init;; ++i) @@ -554,7 +554,7 @@ static int tclear(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); - Table* tt = hvalue(L->base); + LuaTable* tt = hvalue(L->base); if (tt->readonly) luaG_readonlyerror(L); @@ -587,7 +587,7 @@ static int tclone(lua_State* L) luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); - Table* tt = luaH_clone(L, hvalue(L->base)); + LuaTable* tt = luaH_clone(L, hvalue(L->base)); TValue v; sethvalue(L, &v, tt); diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index 16775f9b..f6b0079a 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -10,8 +10,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauPreserveLudataRenaming, false) - // clang-format off const char* const luaT_typenames[] = { // ORDER TYPE @@ -88,7 +86,7 @@ void luaT_init(lua_State* L) ** function to be used with macro "fasttm": optimized for absence of ** tag methods. */ -const TValue* luaT_gettm(Table* events, TMS event, TString* ename) +const TValue* luaT_gettm(LuaTable* events, TMS event, TString* ename) { const TValue* tm = luaH_getstr(events, ename); LUAU_ASSERT(event <= TM_EQ); @@ -107,7 +105,7 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) NB: Tag-methods were replaced by meta-methods in Lua 5.0, but the old names are still around (this function, for example). */ - Table* mt; + LuaTable* mt; switch (ttype(o)) { case LUA_TTABLE: @@ -124,74 +122,40 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) { - if (FFlag::LuauPreserveLudataRenaming) + // Userdata created by the environment can have a custom type name set in the individual metatable + // If there is no custom name, 'userdata' is returned + if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) { - // Userdata created by the environment can have a custom type name set in the individual metatable - // If there is no custom name, 'userdata' is returned - if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) - { - const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); + const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); - if (ttisstring(type)) - return tsvalue(type); - - return L->global->ttname[ttype(o)]; - } - - // Tagged lightuserdata can be named using lua_setlightuserdataname - if (ttislightuserdata(o)) - { - int tag = lightuserdatatag(o); - - if (unsigned(tag) < LUA_LUTAG_LIMIT) - { - if (const TString* name = L->global->lightuserdataname[tag]) - return name; - } - } - - // For all types except userdata and table, a global metatable can be set with a global name override - if (Table* mt = L->global->mt[ttype(o)]) - { - const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); - } + if (ttisstring(type)) + return tsvalue(type); return L->global->ttname[ttype(o)]; } - else + + // Tagged lightuserdata can be named using lua_setlightuserdataname + if (ttislightuserdata(o)) { - if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) + int tag = lightuserdatatag(o); + + if (unsigned(tag) < LUA_LUTAG_LIMIT) { - const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); + if (const TString* name = L->global->lightuserdataname[tag]) + return name; } - else if (ttislightuserdata(o)) - { - int tag = lightuserdatatag(o); - - if (unsigned(tag) < LUA_LUTAG_LIMIT) - { - const TString* name = L->global->lightuserdataname[tag]; - - if (name) - return name; - } - } - else if (Table* mt = L->global->mt[ttype(o)]) - { - const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); - - if (ttisstring(type)) - return tsvalue(type); - } - - return L->global->ttname[ttype(o)]; } + + // For all types except userdata and table, a global metatable can be set with a global name override + if (LuaTable* mt = L->global->mt[ttype(o)]) + { + const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); + + if (ttisstring(type)) + return tsvalue(type); + } + + return L->global->ttname[ttype(o)]; } const char* luaT_objtypename(lua_State* L, const TValue* o) diff --git a/VM/src/ltm.h b/VM/src/ltm.h index 7dafd4ed..f3294b64 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -51,7 +51,7 @@ typedef enum LUAI_DATA const char* const luaT_typenames[]; LUAI_DATA const char* const luaT_eventname[]; -LUAI_FUNC const TValue* luaT_gettm(Table* events, TMS event, TString* ename); +LUAI_FUNC const TValue* luaT_gettm(LuaTable* events, TMS event, TString* ename); LUAI_FUNC const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event); LUAI_FUNC const TString* luaT_objtypenamestr(lua_State* L, const TValue* o); diff --git a/VM/src/lveclib.cpp b/VM/src/lveclib.cpp new file mode 100644 index 00000000..f24dc8b6 --- /dev/null +++ b/VM/src/lveclib.cpp @@ -0,0 +1,344 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lualib.h" + +#include "lcommon.h" +#include "lnumutils.h" + +#include + +static int vector_create(lua_State* L) +{ + // checking argument count to avoid accepting 'nil' as a valid value + int count = lua_gettop(L); + + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + double z = count >= 3 ? luaL_checknumber(L, 3) : 0.0; + +#if LUA_VECTOR_SIZE == 4 + double w = count >= 4 ? luaL_checknumber(L, 4) : 0.0; + + lua_pushvector(L, float(x), float(y), float(z), float(w)); +#else + lua_pushvector(L, float(x), float(y), float(z)); +#endif + + return 1; +} + +static int vector_magnitude(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3])); +#else + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); +#endif + + return 1; +} + +static int vector_normalize(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2] + v[3] * v[3]); + + lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt, v[3] * invSqrt); +#else + float invSqrt = 1.0f / sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]); + + lua_pushvector(L, v[0] * invSqrt, v[1] * invSqrt, v[2] * invSqrt); +#endif + + return 1; +} + +static int vector_cross(lua_State* L) +{ + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0], 0.0f); +#else + lua_pushvector(L, a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]); +#endif + + return 1; +} + +static int vector_dot(lua_State* L) +{ + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); + +#if LUA_VECTOR_SIZE == 4 + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3]); +#else + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); +#endif + + return 1; +} + +static int vector_angle(lua_State* L) +{ + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); + const float* axis = luaL_optvector(L, 3, nullptr); + + // cross(a, b) + float cross[] = {a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]}; + + double sinA = sqrt(cross[0] * cross[0] + cross[1] * cross[1] + cross[2] * cross[2]); + double cosA = a[0] * b[0] + a[1] * b[1] + a[2] * b[2]; + double angle = atan2(sinA, cosA); + + if (axis) + { + if (cross[0] * axis[0] + cross[1] * axis[1] + cross[2] * axis[2] < 0.0f) + angle = -angle; + } + + lua_pushnumber(L, angle); + return 1; +} + +static int vector_floor(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, floorf(v[0]), floorf(v[1]), floorf(v[2]), floorf(v[3])); +#else + lua_pushvector(L, floorf(v[0]), floorf(v[1]), floorf(v[2])); +#endif + + return 1; +} + +static int vector_ceil(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, ceilf(v[0]), ceilf(v[1]), ceilf(v[2]), ceilf(v[3])); +#else + lua_pushvector(L, ceilf(v[0]), ceilf(v[1]), ceilf(v[2])); +#endif + + return 1; +} + +static int vector_abs(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, fabsf(v[0]), fabsf(v[1]), fabsf(v[2]), fabsf(v[3])); +#else + lua_pushvector(L, fabsf(v[0]), fabsf(v[1]), fabsf(v[2])); +#endif + + return 1; +} + +static int vector_sign(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, luaui_signf(v[0]), luaui_signf(v[1]), luaui_signf(v[2]), luaui_signf(v[3])); +#else + lua_pushvector(L, luaui_signf(v[0]), luaui_signf(v[1]), luaui_signf(v[2])); +#endif + + return 1; +} + +static int vector_clamp(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + const float* min = luaL_checkvector(L, 2); + const float* max = luaL_checkvector(L, 3); + + luaL_argcheck(L, min[0] <= max[0], 3, "max.x must be greater than or equal to min.x"); + luaL_argcheck(L, min[1] <= max[1], 3, "max.y must be greater than or equal to min.y"); + luaL_argcheck(L, min[2] <= max[2], 3, "max.z must be greater than or equal to min.z"); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector( + L, + luaui_clampf(v[0], min[0], max[0]), + luaui_clampf(v[1], min[1], max[1]), + luaui_clampf(v[2], min[2], max[2]), + luaui_clampf(v[3], min[3], max[3]) + ); +#else + lua_pushvector(L, luaui_clampf(v[0], min[0], max[0]), luaui_clampf(v[1], min[1], max[1]), luaui_clampf(v[2], min[2], max[2])); +#endif + + return 1; +} + +static int vector_min(lua_State* L) +{ + int n = lua_gettop(L); + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + float result[] = {v[0], v[1], v[2], v[3]}; +#else + float result[] = {v[0], v[1], v[2]}; +#endif + + for (int i = 2; i <= n; i++) + { + const float* b = luaL_checkvector(L, i); + + if (b[0] < result[0]) + result[0] = b[0]; + if (b[1] < result[1]) + result[1] = b[1]; + if (b[2] < result[2]) + result[2] = b[2]; +#if LUA_VECTOR_SIZE == 4 + if (b[3] < result[3]) + result[3] = b[3]; +#endif + } + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, result[0], result[1], result[2], result[3]); +#else + lua_pushvector(L, result[0], result[1], result[2]); +#endif + + return 1; +} + +static int vector_max(lua_State* L) +{ + int n = lua_gettop(L); + const float* v = luaL_checkvector(L, 1); + +#if LUA_VECTOR_SIZE == 4 + float result[] = {v[0], v[1], v[2], v[3]}; +#else + float result[] = {v[0], v[1], v[2]}; +#endif + + for (int i = 2; i <= n; i++) + { + const float* b = luaL_checkvector(L, i); + + if (b[0] > result[0]) + result[0] = b[0]; + if (b[1] > result[1]) + result[1] = b[1]; + if (b[2] > result[2]) + result[2] = b[2]; +#if LUA_VECTOR_SIZE == 4 + if (b[3] > result[3]) + result[3] = b[3]; +#endif + } + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, result[0], result[1], result[2], result[3]); +#else + lua_pushvector(L, result[0], result[1], result[2]); +#endif + + return 1; +} + +static int vector_index(lua_State* L) +{ + const float* v = luaL_checkvector(L, 1); + size_t namelen = 0; + const char* name = luaL_checklstring(L, 2, &namelen); + + // field access implementation mirrors the fast-path we have in the VM + if (namelen == 1) + { + int ic = (name[0] | ' ') - 'x'; + +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE) + { + lua_pushnumber(L, v[ic]); + return 1; + } + } + + luaL_error(L, "attempt to index vector with '%s'", name); +} + +static const luaL_Reg vectorlib[] = { + {"create", vector_create}, + {"magnitude", vector_magnitude}, + {"normalize", vector_normalize}, + {"cross", vector_cross}, + {"dot", vector_dot}, + {"angle", vector_angle}, + {"floor", vector_floor}, + {"ceil", vector_ceil}, + {"abs", vector_abs}, + {"sign", vector_sign}, + {"clamp", vector_clamp}, + {"max", vector_max}, + {"min", vector_min}, + {NULL, NULL}, +}; + +static void createmetatable(lua_State* L) +{ + lua_createtable(L, 0, 1); // create metatable for vectors + + // push dummy vector +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else + lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif + + lua_pushvalue(L, -2); + lua_setmetatable(L, -2); // set vector metatable + lua_pop(L, 1); // pop dummy vector + + lua_pushcfunction(L, vector_index, nullptr); + lua_setfield(L, -2, "__index"); + + lua_setreadonly(L, -1, true); + lua_pop(L, 1); // pop the metatable +} + +int luaopen_vector(lua_State* L) +{ + luaL_register(L, LUA_VECLIBNAME, vectorlib); + +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); + lua_setfield(L, -2, "zero"); + lua_pushvector(L, 1.0f, 1.0f, 1.0f, 1.0f); + lua_setfield(L, -2, "one"); +#else + lua_pushvector(L, 0.0f, 0.0f, 0.0f); + lua_setfield(L, -2, "zero"); + lua_pushvector(L, 1.0f, 1.0f, 1.0f); + lua_setfield(L, -2, "one"); +#endif + + createmetatable(L); + + return 1; +} diff --git a/VM/src/lvm.h b/VM/src/lvm.h index 0b8690be..6989bcee 100644 --- a/VM/src/lvm.h +++ b/VM/src/lvm.h @@ -26,7 +26,7 @@ LUAI_FUNC int luaV_tostring(lua_State* L, StkId obj); LUAI_FUNC void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val); LUAI_FUNC void luaV_concat(lua_State* L, int total, int last); -LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil); +LUAI_FUNC void luaV_getimport(lua_State* L, LuaTable* env, TValue* k, StkId res, uint32_t id, bool propagatenil); LUAI_FUNC void luaV_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit); LUAI_FUNC void luaV_callTM(lua_State* L, int nparams, int res); LUAI_FUNC void luaV_tryfuncTM(lua_State* L, StkId func); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 0b26f079..ce07d878 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,6 +16,8 @@ #include +LUAU_DYNAMIC_FASTFLAG(LuauPopIncompleteCi) + // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -328,7 +330,7 @@ reentry: LUAU_ASSERT(ttisstring(kv)); // fast-path: value is in expected slot - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -359,7 +361,7 @@ reentry: LUAU_ASSERT(ttisstring(kv)); // fast-path: value is in expected slot - Table* h = cl->env; + LuaTable* h = cl->env; int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -449,7 +451,7 @@ reentry: // fast-path: built-in table if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -566,7 +568,7 @@ reentry: // fast-path: built-in table if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -640,7 +642,7 @@ reentry: // fast-path: array lookup if (ttistable(rb) && ttisnumber(rc)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); double indexd = nvalue(rc); int index = int(indexd); @@ -670,7 +672,7 @@ reentry: // fast-path: array assign if (ttistable(rb) && ttisnumber(rc)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); double indexd = nvalue(rc); int index = int(indexd); @@ -701,7 +703,7 @@ reentry: // fast-path: array lookup if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable)) { @@ -729,7 +731,7 @@ reentry: // fast-path: array assign if (ttistable(rb)) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable && !h->readonly)) { @@ -802,7 +804,7 @@ reentry: if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works // for predictive lookups LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; @@ -842,7 +844,7 @@ reentry: } else { - Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + LuaTable* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; const TValue* tmi = 0; // fast-path: metatable with __namecall @@ -856,7 +858,7 @@ reentry: } else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) { - Table* h = hvalue(tmi); + LuaTable* h = hvalue(tmi); int slot = LUAU_INSN_C(insn) & h->nodemask8; LuaNode* n = &h->node[slot]; @@ -935,7 +937,14 @@ reentry: // note: this reallocs stack, but we don't need to VM_PROTECT this // this is because we're going to modify base/savedpc manually anyhow // crucially, we can't use ra/argtop after this line - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } LUAU_ASSERT(ci->top <= L->stack_last); @@ -2117,7 +2126,7 @@ reentry: // fast-path #1: tables if (LUAU_LIKELY(ttistable(rb))) { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if (fastnotm(h->metatable, TM_LEN)) { @@ -2187,7 +2196,7 @@ reentry: L->top = L->ci->top; } - Table* h = hvalue(ra); + LuaTable* h = hvalue(ra); // TODO: we really don't need this anymore if (!ttistable(ra)) @@ -2272,7 +2281,7 @@ reentry: } else { - Table* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(Table*, NULL); + LuaTable* mt = ttistable(ra) ? hvalue(ra)->metatable : ttisuserdata(ra) ? uvalue(ra)->metatable : cast_to(LuaTable*, NULL); if (const TValue* fn = fasttm(L, mt, TM_ITER)) { @@ -2331,7 +2340,7 @@ reentry: // TODO: remove the table check per guarantee above if (ttisnil(ra) && ttistable(ra + 1)) { - Table* h = hvalue(ra + 1); + LuaTable* h = hvalue(ra + 1); int index = int(reinterpret_cast(pvalue(ra + 2))); int sizearray = h->sizearray; @@ -2923,10 +2932,13 @@ reentry: { VM_PROTECT_PC(); // f may fail due to OOM - setobj2s(L, L->top, arg2); - setobj2s(L, L->top + 1, arg3); + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 2 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top, arg2); + setobj2s(L, top + 1, arg3); - int n = f(L, ra, arg1, nresults, L->top, nparams); + int n = f(L, ra, arg1, nresults, top, nparams); if (n >= 0) { @@ -3068,7 +3080,14 @@ int luau_precall(lua_State* L, StkId func, int nresults) L->base = ci->base; // Note: L->top is assigned externally - luaD_checkstack(L, ccl->stacksize); + if (DFFlag::LuauPopIncompleteCi) + { + luaD_checkstackfornewci(L, ccl->stacksize); + } + else + { + luaD_checkstack(L, ccl->stacksize); + } LUAU_ASSERT(ci->top <= L->stack_last); if (!ccl->isC) diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index aa248fc1..2a3443eb 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -72,7 +72,7 @@ private: size_t originalThreshold = 0; }; -void luaV_getimport(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) +void luaV_getimport(lua_State* L, LuaTable* env, TValue* k, StkId res, uint32_t id, bool propagatenil) { int count = id >> 30; LUAU_ASSERT(count > 0); @@ -141,7 +141,7 @@ static TString* readString(TempBuffer& strings, const char* data, size return id == 0 ? NULL : strings[id - 1]; } -static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) +static void resolveImportSafe(lua_State* L, LuaTable* env, TValue* k, uint32_t id) { struct ResolveImport { @@ -273,7 +273,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size const ScopedSetGCThreshold pauseGC{L->global, SIZE_MAX}; // env is 0 for current environment and a stack index otherwise - Table* envt = (env == 0) ? L->gt : hvalue(luaA_toobject(L, env)); + LuaTable* envt = (env == 0) ? L->gt : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); @@ -481,7 +481,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size case LBC_CONSTANT_TABLE: { int keys = readVarInt(data, size, offset); - Table* h = luaH_new(L, 0, keys); + LuaTable* h = luaH_new(L, 0, keys); for (int i = 0; i < keys; ++i) { int key = readVarInt(data, size, offset); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 0cf9d206..5c49139f 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -101,7 +101,7 @@ void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val) const TValue* tm; if (ttistable(t)) { // `t' is a table? - Table* h = hvalue(t); + LuaTable* h = hvalue(t); const TValue* res = luaH_get(h, key); // do a primitive get @@ -137,7 +137,7 @@ void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val) const TValue* tm; if (ttistable(t)) { // `t' is a table? - Table* h = hvalue(t); + LuaTable* h = hvalue(t); const TValue* oldval = luaH_get(h, key); @@ -185,7 +185,7 @@ static int call_binTM(lua_State* L, const TValue* p1, const TValue* p2, StkId re return 1; } -static const TValue* get_compTM(lua_State* L, Table* mt1, Table* mt2, TMS event) +static const TValue* get_compTM(lua_State* L, LuaTable* mt1, LuaTable* mt2, TMS event) { const TValue* tm1 = fasttm(L, mt1, event); const TValue* tm2; @@ -533,7 +533,7 @@ void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) { case LUA_TTABLE: { - Table* h = hvalue(rb); + LuaTable* h = hvalue(rb); if ((tm = fasttm(L, h->metatable, TM_LEN)) == NULL) { setnvalue(ra, cast_num(luaH_getn(h))); diff --git a/bench/bench_support.lua b/bench/bench_support.lua index da637ac9..b731c2fc 100644 --- a/bench/bench_support.lua +++ b/bench/bench_support.lua @@ -66,7 +66,7 @@ end -- and 'false' otherwise. -- -- Example usage: --- local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +-- local function prequire(name) local success, result = pcall(require, name); return success and result end -- local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- function testFunc() -- ... diff --git a/bench/gc/test_BinaryTree.lua b/bench/gc/test_BinaryTree.lua index 36dff9de..b7a36d73 100644 --- a/bench/gc/test_BinaryTree.lua +++ b/bench/gc/test_BinaryTree.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Boehm_Trees.lua b/bench/gc/test_GC_Boehm_Trees.lua index 8170103d..3a3a3698 100644 --- a/bench/gc/test_GC_Boehm_Trees.lua +++ b/bench/gc/test_GC_Boehm_Trees.lua @@ -1,5 +1,5 @@ --!nonstrict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local stretchTreeDepth = 18 -- about 16Mb diff --git a/bench/gc/test_GC_Tree_Pruning_Eager.lua b/bench/gc/test_GC_Tree_Pruning_Eager.lua index 38aa7626..7a086254 100644 --- a/bench/gc/test_GC_Tree_Pruning_Eager.lua +++ b/bench/gc/test_GC_Tree_Pruning_Eager.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Tree_Pruning_Gen.lua b/bench/gc/test_GC_Tree_Pruning_Gen.lua index 85081f70..eb747e77 100644 --- a/bench/gc/test_GC_Tree_Pruning_Gen.lua +++ b/bench/gc/test_GC_Tree_Pruning_Gen.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_Tree_Pruning_Lazy.lua b/bench/gc/test_GC_Tree_Pruning_Lazy.lua index 834ec1ab..16b68083 100644 --- a/bench/gc/test_GC_Tree_Pruning_Lazy.lua +++ b/bench/gc/test_GC_Tree_Pruning_Lazy.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_GC_hashtable_Keyval.lua b/bench/gc/test_GC_hashtable_Keyval.lua index aa7481d3..6e59072c 100644 --- a/bench/gc/test_GC_hashtable_Keyval.lua +++ b/bench/gc/test_GC_hashtable_Keyval.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index a8beb4fd..be9977d6 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LargeTableCtor_array.lua b/bench/gc/test_LargeTableCtor_array.lua index 016dfd2d..35b6f449 100644 --- a/bench/gc/test_LargeTableCtor_array.lua +++ b/bench/gc/test_LargeTableCtor_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_LargeTableCtor_hash.lua b/bench/gc/test_LargeTableCtor_hash.lua index c46a7ab4..e2b11b4b 100644 --- a/bench/gc/test_LargeTableCtor_hash.lua +++ b/bench/gc/test_LargeTableCtor_hash.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_Pcall_pcall_yield.lua b/bench/gc/test_Pcall_pcall_yield.lua index ae0a4b46..2ae0baa6 100644 --- a/bench/gc/test_Pcall_pcall_yield.lua +++ b/bench/gc/test_Pcall_pcall_yield.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_SunSpider_3d-raytrace.lua b/bench/gc/test_SunSpider_3d-raytrace.lua index 3c050df7..d8f224c4 100644 --- a/bench/gc/test_SunSpider_3d-raytrace.lua +++ b/bench/gc/test_SunSpider_3d-raytrace.lua @@ -22,7 +22,7 @@ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_nil.lua b/bench/gc/test_TableCreate_nil.lua index 707a2750..546e9d6b 100644 --- a/bench/gc/test_TableCreate_nil.lua +++ b/bench/gc/test_TableCreate_nil.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_number.lua b/bench/gc/test_TableCreate_number.lua index 3e4305bd..fe8437b7 100644 --- a/bench/gc/test_TableCreate_number.lua +++ b/bench/gc/test_TableCreate_number.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableCreate_zerofill.lua b/bench/gc/test_TableCreate_zerofill.lua index fed439b4..e2cfda30 100644 --- a/bench/gc/test_TableCreate_zerofill.lua +++ b/bench/gc/test_TableCreate_zerofill.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_select.lua b/bench/gc/test_TableMarshal_select.lua index 9869da60..df5ebf78 100644 --- a/bench/gc/test_TableMarshal_select.lua +++ b/bench/gc/test_TableMarshal_select.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_table_pack.lua b/bench/gc/test_TableMarshal_table_pack.lua index 3da855f5..3d0190e7 100644 --- a/bench/gc/test_TableMarshal_table_pack.lua +++ b/bench/gc/test_TableMarshal_table_pack.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/gc/test_TableMarshal_varargs.lua b/bench/gc/test_TableMarshal_varargs.lua index 64b41b43..b88d8213 100644 --- a/bench/gc/test_TableMarshal_varargs.lua +++ b/bench/gc/test_TableMarshal_varargs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_abs.lua b/bench/micro_tests/test_AbsSum_abs.lua index 7e85646e..ea473556 100644 --- a/bench/micro_tests/test_AbsSum_abs.lua +++ b/bench/micro_tests/test_AbsSum_abs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_and_or.lua b/bench/micro_tests/test_AbsSum_and_or.lua index c6ef3dea..6cd5b4d0 100644 --- a/bench/micro_tests/test_AbsSum_and_or.lua +++ b/bench/micro_tests/test_AbsSum_and_or.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_AbsSum_math_abs.lua b/bench/micro_tests/test_AbsSum_math_abs.lua index e95ea674..e02b710a 100644 --- a/bench/micro_tests/test_AbsSum_math_abs.lua +++ b/bench/micro_tests/test_AbsSum_math_abs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Assert.lua b/bench/micro_tests/test_Assert.lua index 014de8dc..750f411b 100644 --- a/bench/micro_tests/test_Assert.lua +++ b/bench/micro_tests/test_Assert.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Factorial.lua b/bench/micro_tests/test_Factorial.lua index 90cff22a..5dc797ce 100644 --- a/bench/micro_tests/test_Factorial.lua +++ b/bench/micro_tests/test_Factorial.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_pcall_a_bar.lua b/bench/micro_tests/test_Failure_pcall_a_bar.lua index 5b6108ba..95887e58 100644 --- a/bench/micro_tests/test_Failure_pcall_a_bar.lua +++ b/bench/micro_tests/test_Failure_pcall_a_bar.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_pcall_game_Foo.lua b/bench/micro_tests/test_Failure_pcall_game_Foo.lua index 6bd209ae..9966262d 100644 --- a/bench/micro_tests/test_Failure_pcall_game_Foo.lua +++ b/bench/micro_tests/test_Failure_pcall_game_Foo.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_xpcall_a_bar.lua b/bench/micro_tests/test_Failure_xpcall_a_bar.lua index e00a3ca6..44534da4 100644 --- a/bench/micro_tests/test_Failure_xpcall_a_bar.lua +++ b/bench/micro_tests/test_Failure_xpcall_a_bar.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Failure_xpcall_game_Foo.lua b/bench/micro_tests/test_Failure_xpcall_game_Foo.lua index 86dadc90..35659598 100644 --- a/bench/micro_tests/test_Failure_xpcall_game_Foo.lua +++ b/bench/micro_tests/test_Failure_xpcall_game_Foo.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableCtor_array.lua b/bench/micro_tests/test_LargeTableCtor_array.lua index 016dfd2d..35b6f449 100644 --- a/bench/micro_tests/test_LargeTableCtor_array.lua +++ b/bench/micro_tests/test_LargeTableCtor_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableCtor_hash.lua b/bench/micro_tests/test_LargeTableCtor_hash.lua index c46a7ab4..e2b11b4b 100644 --- a/bench/micro_tests/test_LargeTableCtor_hash.lua +++ b/bench/micro_tests/test_LargeTableCtor_hash.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_index.lua b/bench/micro_tests/test_LargeTableSum_loop_index.lua index 2aae109e..dd64ca00 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_index.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_index.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua b/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua index 29205e26..54ee888d 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_ipairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_iter.lua b/bench/micro_tests/test_LargeTableSum_loop_iter.lua index ea2b157c..fb69470f 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_iter.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_iter.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_LargeTableSum_loop_pairs.lua b/bench/micro_tests/test_LargeTableSum_loop_pairs.lua index 8d789fcf..ffe19a20 100644 --- a/bench/micro_tests/test_LargeTableSum_loop_pairs.lua +++ b/bench/micro_tests/test_LargeTableSum_loop_pairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_MethodCalls.lua b/bench/micro_tests/test_MethodCalls.lua index f8b44527..016a4798 100644 --- a/bench/micro_tests/test_MethodCalls.lua +++ b/bench/micro_tests/test_MethodCalls.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_constructor.lua b/bench/micro_tests/test_OOP_constructor.lua index 9fec3b67..b1c03dfc 100644 --- a/bench/micro_tests/test_OOP_constructor.lua +++ b/bench/micro_tests/test_OOP_constructor.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_method_call.lua b/bench/micro_tests/test_OOP_method_call.lua index 1e5249c5..09699acb 100644 --- a/bench/micro_tests/test_OOP_method_call.lua +++ b/bench/micro_tests/test_OOP_method_call.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_OOP_virtual_constructor.lua b/bench/micro_tests/test_OOP_virtual_constructor.lua index df99e13b..68dfba61 100644 --- a/bench/micro_tests/test_OOP_virtual_constructor.lua +++ b/bench/micro_tests/test_OOP_virtual_constructor.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_call_return.lua b/bench/micro_tests/test_Pcall_call_return.lua index 2a612175..45d8ca58 100644 --- a/bench/micro_tests/test_Pcall_call_return.lua +++ b/bench/micro_tests/test_Pcall_call_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_pcall_return.lua b/bench/micro_tests/test_Pcall_pcall_return.lua index 16bdfdd3..09a032df 100644 --- a/bench/micro_tests/test_Pcall_pcall_return.lua +++ b/bench/micro_tests/test_Pcall_pcall_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_pcall_yield.lua b/bench/micro_tests/test_Pcall_pcall_yield.lua index ae0a4b46..2ae0baa6 100644 --- a/bench/micro_tests/test_Pcall_pcall_yield.lua +++ b/bench/micro_tests/test_Pcall_pcall_yield.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_Pcall_xpcall_return.lua b/bench/micro_tests/test_Pcall_xpcall_return.lua index 8ac2f0eb..5fb69f1b 100644 --- a/bench/micro_tests/test_Pcall_xpcall_return.lua +++ b/bench/micro_tests/test_Pcall_xpcall_return.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_exponent.lua b/bench/micro_tests/test_SqrtSum_exponent.lua index bfd6fd72..1bb6a7d2 100644 --- a/bench/micro_tests/test_SqrtSum_exponent.lua +++ b/bench/micro_tests/test_SqrtSum_exponent.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_math_sqrt.lua b/bench/micro_tests/test_SqrtSum_math_sqrt.lua index 1e1f42c7..7a280460 100644 --- a/bench/micro_tests/test_SqrtSum_math_sqrt.lua +++ b/bench/micro_tests/test_SqrtSum_math_sqrt.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt.lua b/bench/micro_tests/test_SqrtSum_sqrt.lua index 96880e7b..ddcddb9d 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua index 55f29e2e..1dd29776 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua index bbe48a64..0527ea4d 100644 --- a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua +++ b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_StringInterp.lua b/bench/micro_tests/test_StringInterp.lua index 55430519..d44f5b07 100644 --- a/bench/micro_tests/test_StringInterp.lua +++ b/bench/micro_tests/test_StringInterp.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_TableCreate_nil.lua b/bench/micro_tests/test_TableCreate_nil.lua index 707a2750..546e9d6b 100644 --- a/bench/micro_tests/test_TableCreate_nil.lua +++ b/bench/micro_tests/test_TableCreate_nil.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableCreate_number.lua b/bench/micro_tests/test_TableCreate_number.lua index 3e4305bd..fe8437b7 100644 --- a/bench/micro_tests/test_TableCreate_number.lua +++ b/bench/micro_tests/test_TableCreate_number.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableCreate_zerofill.lua b/bench/micro_tests/test_TableCreate_zerofill.lua index fed439b4..e2cfda30 100644 --- a/bench/micro_tests/test_TableCreate_zerofill.lua +++ b/bench/micro_tests/test_TableCreate_zerofill.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableFind_loop_ipairs.lua b/bench/micro_tests/test_TableFind_loop_ipairs.lua index 46560274..ef7f4c81 100644 --- a/bench/micro_tests/test_TableFind_loop_ipairs.lua +++ b/bench/micro_tests/test_TableFind_loop_ipairs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableFind_table_find.lua b/bench/micro_tests/test_TableFind_table_find.lua index 3f22122f..05882c50 100644 --- a/bench/micro_tests/test_TableFind_table_find.lua +++ b/bench/micro_tests/test_TableFind_table_find.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_index_cached.lua b/bench/micro_tests/test_TableInsertion_index_cached.lua index 0c34818f..adb40822 100644 --- a/bench/micro_tests/test_TableInsertion_index_cached.lua +++ b/bench/micro_tests/test_TableInsertion_index_cached.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_index_len.lua b/bench/micro_tests/test_TableInsertion_index_len.lua index 120a5e28..797dec80 100644 --- a/bench/micro_tests/test_TableInsertion_index_len.lua +++ b/bench/micro_tests/test_TableInsertion_index_len.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_table_insert.lua b/bench/micro_tests/test_TableInsertion_table_insert.lua index 1ad3fe22..632e9080 100644 --- a/bench/micro_tests/test_TableInsertion_table_insert.lua +++ b/bench/micro_tests/test_TableInsertion_table_insert.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableInsertion_table_insert_index.lua b/bench/micro_tests/test_TableInsertion_table_insert_index.lua index 41747139..7b35fe39 100644 --- a/bench/micro_tests/test_TableInsertion_table_insert_index.lua +++ b/bench/micro_tests/test_TableInsertion_table_insert_index.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableIteration.lua b/bench/micro_tests/test_TableIteration.lua index 5f78a48b..2c44f43c 100644 --- a/bench/micro_tests/test_TableIteration.lua +++ b/bench/micro_tests/test_TableIteration.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_select.lua b/bench/micro_tests/test_TableMarshal_select.lua index 9869da60..df5ebf78 100644 --- a/bench/micro_tests/test_TableMarshal_select.lua +++ b/bench/micro_tests/test_TableMarshal_select.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_pack.lua b/bench/micro_tests/test_TableMarshal_table_pack.lua index 3da855f5..3d0190e7 100644 --- a/bench/micro_tests/test_TableMarshal_table_pack.lua +++ b/bench/micro_tests/test_TableMarshal_table_pack.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua index 13d1d1c3..32f2eb9a 100644 --- a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua +++ b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua index e3aa68be..fa53a31c 100644 --- a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua +++ b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMarshal_varargs.lua b/bench/micro_tests/test_TableMarshal_varargs.lua index 64b41b43..b88d8213 100644 --- a/bench/micro_tests/test_TableMarshal_varargs.lua +++ b/bench/micro_tests/test_TableMarshal_varargs.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_empty_table.lua b/bench/micro_tests/test_TableMove_empty_table.lua index 39335564..18737f74 100644 --- a/bench/micro_tests/test_TableMove_empty_table.lua +++ b/bench/micro_tests/test_TableMove_empty_table.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_same_table.lua b/bench/micro_tests/test_TableMove_same_table.lua index f62022b1..8fc9fa03 100644 --- a/bench/micro_tests/test_TableMove_same_table.lua +++ b/bench/micro_tests/test_TableMove_same_table.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableMove_table_create.lua b/bench/micro_tests/test_TableMove_table_create.lua index f03c4de7..3c0cb9e9 100644 --- a/bench/micro_tests/test_TableMove_table_create.lua +++ b/bench/micro_tests/test_TableMove_table_create.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableRemoval_table_remove.lua b/bench/micro_tests/test_TableRemoval_table_remove.lua index 13410116..3ba3e503 100644 --- a/bench/micro_tests/test_TableRemoval_table_remove.lua +++ b/bench/micro_tests/test_TableRemoval_table_remove.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_TableSort.lua b/bench/micro_tests/test_TableSort.lua index 502cb2a5..e3276845 100644 --- a/bench/micro_tests/test_TableSort.lua +++ b/bench/micro_tests/test_TableSort.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local arr_months = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"} diff --git a/bench/micro_tests/test_ToNumberString.lua b/bench/micro_tests/test_ToNumberString.lua index 842b7c22..cda886c0 100644 --- a/bench/micro_tests/test_ToNumberString.lua +++ b/bench/micro_tests/test_ToNumberString.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_UpvalueCapture.lua b/bench/micro_tests/test_UpvalueCapture.lua index 4a2608c4..6c2f2616 100644 --- a/bench/micro_tests/test_UpvalueCapture.lua +++ b/bench/micro_tests/test_UpvalueCapture.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_VariadicSelect.lua b/bench/micro_tests/test_VariadicSelect.lua index 5a62f2d8..9710e237 100644 --- a/bench/micro_tests/test_VariadicSelect.lua +++ b/bench/micro_tests/test_VariadicSelect.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/micro_tests/test_string_lib.lua b/bench/micro_tests/test_string_lib.lua index 041f5b15..5f180151 100644 --- a/bench/micro_tests/test_string_lib.lua +++ b/bench/micro_tests/test_string_lib.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_table_concat.lua b/bench/micro_tests/test_table_concat.lua index 590b7d4a..879b63fe 100644 --- a/bench/micro_tests/test_table_concat.lua +++ b/bench/micro_tests/test_table_concat.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") bench.runCode(function() diff --git a/bench/micro_tests/test_vector_lib.lua b/bench/micro_tests/test_vector_lib.lua new file mode 100644 index 00000000..59bddc04 --- /dev/null +++ b/bench/micro_tests/test_vector_lib.lua @@ -0,0 +1,14 @@ +local function prequire(name) local success, result = pcall(require, name); return success and result end +local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") + +bench.runCode(function() + for i=1,1000000 do + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + vector.create(i, 2, 3) + end +end, "vector: create") + +-- TODO: add more tests \ No newline at end of file diff --git a/bench/tests/base64.lua b/bench/tests/base64.lua index e580c595..13bfd070 100644 --- a/bench/tests/base64.lua +++ b/bench/tests/base64.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua index f551139e..7e6c9c0c 100644 --- a/bench/tests/chess.lua +++ b/bench/tests/chess.lua @@ -1,5 +1,5 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local RANKS = "12345678" diff --git a/bench/tests/life.lua b/bench/tests/life.lua index d050b013..a61730aa 100644 --- a/bench/tests/life.lua +++ b/bench/tests/life.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/matrixmult.lua b/bench/tests/matrixmult.lua index af38cb64..fa04b864 100644 --- a/bench/tests/matrixmult.lua +++ b/bench/tests/matrixmult.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local function mmul(matrix1, matrix2) diff --git a/bench/tests/mesh-normal-scalar.lua b/bench/tests/mesh-normal-scalar.lua index 05bef373..509e1e62 100644 --- a/bench/tests/mesh-normal-scalar.lua +++ b/bench/tests/mesh-normal-scalar.lua @@ -1,5 +1,5 @@ --!strict -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/mesh-normal-vector.lua b/bench/tests/mesh-normal-vector.lua new file mode 100644 index 00000000..ff4f2b46 --- /dev/null +++ b/bench/tests/mesh-normal-vector.lua @@ -0,0 +1,166 @@ +--!strict +local function prequire(name) local success, result = pcall(require, name); return success and result end +local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") + +function test() + + type Vertex = { p: vector, uv: vector, n: vector, t: vector, b: vector, h: number } + + local grid_size = 100 + + local mesh: { + vertices: {Vertex}, + indices: {number}, + triangle_cone_p: {vector}, + triangle_cone_n: {vector} + } = { + vertices = table.create(grid_size * grid_size), + indices = table.create((grid_size - 1) * (grid_size - 1) * 6), + triangle_cone_p = table.create((grid_size - 1) * (grid_size - 1) * 2), + triangle_cone_n = table.create((grid_size - 1) * (grid_size - 1) * 2) + } + + function init_vertices() + local i = 1 + for y = 1,grid_size do + for x = 1,grid_size do + local v: Vertex = {} + + v.p = vector.create(x, y, math.cos(x) + math.sin(y)) + v.uv = vector.create((x-1)/(grid_size-1), (y-1)/(grid_size-1), 0) + v.n = vector.create(0, 0, 0) + v.b = vector.create(0, 0, 0) + v.t = vector.create(0, 0, 0) + v.h = 0 + + mesh.vertices[i] = v + i += 1 + end + end + end + + function init_indices() + local i = 1 + for y = 1,grid_size-1 do + for x = 1,grid_size-1 do + mesh.indices[i] = x + (y-1)*grid_size + i += 1 + mesh.indices[i] = x + y*grid_size + i += 1 + mesh.indices[i] = (x+1) + (y-1)*grid_size + i += 1 + mesh.indices[i] = (x+1) + (y-1)*grid_size + i += 1 + mesh.indices[i] = x + y*grid_size + i += 1 + mesh.indices[i] = (x+1) + y*grid_size + i += 1 + end + end + end + + function calculate_normals() + local norm_sum = 0 + + for i = 1,#mesh.indices,3 do + local a = mesh.vertices[mesh.indices[i]] + local b = mesh.vertices[mesh.indices[i + 1]] + local c = mesh.vertices[mesh.indices[i + 2]] + + local n = vector.cross(a.p - b.p, a.p - c.p) + + a.n += n + b.n += n + c.n += n + end + + for _,v in ipairs(mesh.vertices) do + v.n = vector.normalize(v.n) + + norm_sum += vector.dot(v.n, v.n) + end + + return norm_sum + end + + function compute_triangle_cones() + local mesh_area = 0 + + local pos = 1 + + for i = 1,#mesh.indices,3 do + local p0 = mesh.vertices[mesh.indices[i]] + local p1 = mesh.vertices[mesh.indices[i + 1]] + local p2 = mesh.vertices[mesh.indices[i + 2]] + + local p10 = p1.p - p0.p + local p20 = p2.p - p0.p + + local normal = vector.cross(p10, p20) + + local area = vector.magnitude(normal) + local invarea = (area == 0) and 0 or 1 / area; + + mesh.triangle_cone_p[pos] = (p0.p + p1.p + p2.p) / 3 + mesh.triangle_cone_n[pos] = normal * invarea + pos += 1 + + mesh_area += area + end + + return mesh_area + end + + function compute_tangent_space() + local checksum = 0 + + for i = 1,#mesh.indices,3 do + local a = mesh.vertices[mesh.indices[i]] + local b = mesh.vertices[mesh.indices[i + 1]] + local c = mesh.vertices[mesh.indices[i + 2]] + + local vba = b.p - a.p + local vca = c.p - a.p + + local uvba = b.uv - a.uv + local uvca = c.uv - a.uv + + local r = 1.0 / (uvba.X * uvca.Y - uvca.X * uvba.Y); + + local sdir = (uvca.Y * vba - uvba.Y * vca) * r + local tdir = (uvba.X * vca - uvca.X * vba) * r + + a.t += sdir + b.t += sdir + c.t += sdir + + a.b += tdir + b.b += tdir + c.b += tdir + end + + for _,v in ipairs(mesh.vertices) do + local t = v.t + + -- Gram-Schmidt orthogonalize + v.t = vector.normalize(t - v.n * vector.dot(v.n, t)) + + local ht = vector.dot(vector.cross(v.n, t), v.b) + + v.h = ht < 0 and -1 or 1 + + checksum += v.t.X + v.h + end + + return checksum + end + + + init_vertices() + init_indices() + calculate_normals() + compute_triangle_cones() + compute_tangent_space() +end + +bench.runCode(test, "mesh-normal-vector") diff --git a/bench/tests/pcmmix.lua b/bench/tests/pcmmix.lua index c98cee2c..1e8e27a5 100644 --- a/bench/tests/pcmmix.lua +++ b/bench/tests/pcmmix.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") local samples = 100_000 diff --git a/bench/tests/qsort.lua b/bench/tests/qsort.lua index 566c1b98..37413fa2 100644 --- a/bench/tests/qsort.lua +++ b/bench/tests/qsort.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/sha256.lua b/bench/tests/sha256.lua index 2ac0ab33..e478e763 100644 --- a/bench/tests/sha256.lua +++ b/bench/tests/sha256.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/shootout/ack.lua b/bench/tests/shootout/ack.lua index f7fd43a8..ca8913ac 100644 --- a/bench/tests/shootout/ack.lua +++ b/bench/tests/shootout/ack.lua @@ -23,7 +23,7 @@ SOFTWARE. ]] -- http://www.bagley.org/~doug/shootout/ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/binary-trees.lua b/bench/tests/shootout/binary-trees.lua index 89c5933c..50d40597 100644 --- a/bench/tests/shootout/binary-trees.lua +++ b/bench/tests/shootout/binary-trees.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/fannkuch-redux.lua b/bench/tests/shootout/fannkuch-redux.lua index 43bc9e41..60f7c3c0 100644 --- a/bench/tests/shootout/fannkuch-redux.lua +++ b/bench/tests/shootout/fannkuch-redux.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/fixpoint-fact.lua b/bench/tests/shootout/fixpoint-fact.lua index 112acb4a..226c78a8 100644 --- a/bench/tests/shootout/fixpoint-fact.lua +++ b/bench/tests/shootout/fixpoint-fact.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/heapsort.lua b/bench/tests/shootout/heapsort.lua index 0daf97ab..69c1b885 100644 --- a/bench/tests/shootout/heapsort.lua +++ b/bench/tests/shootout/heapsort.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index a3bbb7e5..547741e6 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/n-body.lua b/bench/tests/shootout/n-body.lua index e0f9c63c..082b7fa0 100644 --- a/bench/tests/shootout/n-body.lua +++ b/bench/tests/shootout/n-body.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index d9b4a517..c15accd0 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -23,7 +23,7 @@ SOFTWARE. ]] -- Julia sets via interval cell-mapping (quadtree version) -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/queen.lua b/bench/tests/shootout/queen.lua index c3508d60..8f27e06f 100644 --- a/bench/tests/shootout/queen.lua +++ b/bench/tests/shootout/queen.lua @@ -21,7 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua index 1b66df53..dd7cae53 100644 --- a/bench/tests/shootout/scimark.lua +++ b/bench/tests/shootout/scimark.lua @@ -33,7 +33,7 @@ -- Modification to be compatible with Lua 5.3 ------------------------------------------------------------------------------ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/shootout/spectral-norm.lua b/bench/tests/shootout/spectral-norm.lua index b5116612..f1acd34c 100644 --- a/bench/tests/shootout/spectral-norm.lua +++ b/bench/tests/shootout/spectral-norm.lua @@ -25,7 +25,7 @@ SOFTWARE. -- http://benchmarksgame.alioth.debian.org/ -- contributed by Mike Pall -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sieve.lua b/bench/tests/sieve.lua index 1bb45d99..8d8cf82a 100644 --- a/bench/tests/sieve.lua +++ b/bench/tests/sieve.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index aac7a156..ea132463 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -2,7 +2,7 @@ -- http://www.speich.net/computer/moztesting/3d.htm -- Created by Simon Speich -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua index 8263f015..0dbf1c63 100644 --- a/bench/tests/sunspider/3d-morph.lua +++ b/bench/tests/sunspider/3d-morph.lua @@ -23,7 +23,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 33d464b8..83ca7bd9 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -22,7 +22,7 @@ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua index 1c78a3c2..67c77293 100644 --- a/bench/tests/sunspider/controlflow-recursive.lua +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua index 9692cf52..6b23719b 100644 --- a/bench/tests/sunspider/crypto-aes.lua +++ b/bench/tests/sunspider/crypto-aes.lua @@ -9,7 +9,7 @@ * returns byte-array encrypted value (16 bytes) */]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") -- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1] diff --git a/bench/tests/sunspider/fannkuch.lua b/bench/tests/sunspider/fannkuch.lua index 08cdcc24..24098740 100644 --- a/bench/tests/sunspider/fannkuch.lua +++ b/bench/tests/sunspider/fannkuch.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua index 2b622377..861cc51a 100644 --- a/bench/tests/sunspider/math-cordic.lua +++ b/bench/tests/sunspider/math-cordic.lua @@ -23,7 +23,7 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ]] - local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end + local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua index f0b4b0b7..21f63295 100644 --- a/bench/tests/sunspider/math-partial-sums.lua +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -3,7 +3,7 @@ http://shootout.alioth.debian.org/ contributed by Isaac Gouy ]] -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") function test() diff --git a/bench/tests/sunspider/n-body-oop.lua b/bench/tests/sunspider/n-body-oop.lua index e04286c8..469e22c1 100644 --- a/bench/tests/sunspider/n-body-oop.lua +++ b/bench/tests/sunspider/n-body-oop.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../../bench_support") local PI = 3.141592653589793 diff --git a/bench/tests/tictactoe.lua b/bench/tests/tictactoe.lua index 673dcd48..bc3282a0 100644 --- a/bench/tests/tictactoe.lua +++ b/bench/tests/tictactoe.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/trig.lua b/bench/tests/trig.lua index 64bf611c..269fd610 100644 --- a/bench/tests/trig.lua +++ b/bench/tests/trig.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") function test() diff --git a/bench/tests/vector-math.lua b/bench/tests/vector-math.lua new file mode 100644 index 00000000..11c37a2d --- /dev/null +++ b/bench/tests/vector-math.lua @@ -0,0 +1,39 @@ +local function prequire(name) local success, result = pcall(require, name); if success then return result end return nil end +local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") + +function fma(a: vector, b: vector, c: vector) + return a * b + c +end + +function approx(a: vector): vector + local r = vector.create(1, 1, 1) + local aa = a + r += aa * 0.123 + aa *= a + r += aa * 0.123 + aa *= a + r += aa * 0.123 + aa *= a + r += aa * 0.123 + aa *= a + r += aa * 0.123 + aa *= a + r += aa * 0.123 + return r +end + +function test() + local A = vector.create(1, 2, 3) + local B = vector.create(4, 5, 6) + local C = vector.create(7, 8, 9) + local fma = fma + local approx = approx + + for i=1,100000 do + fma(A, B, C) + + approx(A) + end +end + +bench.runCode(test, "vector-math") diff --git a/bench/tests/voxelgen.lua b/bench/tests/voxelgen.lua index b50a4592..813838c1 100644 --- a/bench/tests/voxelgen.lua +++ b/bench/tests/voxelgen.lua @@ -1,4 +1,4 @@ -local function prequire(name) local success, result = pcall(require, name); return if success then result else nil end +local function prequire(name) local success, result = pcall(require, name); return success and result end local bench = script and require(script.Parent.bench_support) or prequire("bench_support") or require("../bench_support") -- Based on voxel terrain generator by Stickmasterluke diff --git a/fuzz/basic.lua b/fuzz/basic.luau similarity index 100% rename from fuzz/basic.lua rename to fuzz/basic.luau diff --git a/tests/AnyTypeSummary.test.cpp b/tests/AnyTypeSummary.test.cpp index 5c3b4aa3..471c4bb1 100644 --- a/tests/AnyTypeSummary.test.cpp +++ b/tests/AnyTypeSummary.test.cpp @@ -18,6 +18,11 @@ LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTFLAG(DebugLuauMagicTypes) LUAU_FASTFLAG(StudioReportLuauAny2) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauAstTypeGroup2) +LUAU_FASTFLAG(LuauDeferBidirectionalInferenceForTableAssignment) +LUAU_FASTFLAG(LuauSkipNoRefineDuringRefinement) struct ATSFixture : BuiltinsFixture @@ -71,7 +76,22 @@ export type t8 = t0 &((true | any)->('')) LUAU_ASSERT(module->ats.typeInfo.size() == 1); LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::Alias); - LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); + if (FFlag::LuauStoreCSTData && FFlag::LuauAstTypeGroup2) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0& (( true | any)->(''))"); + } + else if (FFlag::LuauStoreCSTData) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &(( true | any)->(''))"); + } + else if (FFlag::LuauAstTypeGroup2) + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0& ((true | any)->(''))"); + } + else + { + LUAU_ASSERT(module->ats.typeInfo[0].node == "export type t8 = t0 &((true | any)->(''))"); + } } TEST_CASE_FIXTURE(ATSFixture, "typepacks") @@ -97,7 +117,10 @@ end LUAU_ASSERT(module->ats.typeInfo.size() == 3); LUAU_ASSERT(module->ats.typeInfo[1].code == Pattern::TypePk); - LUAU_ASSERT(module->ats.typeInfo[0].node == "local function fallible(t: number): ...any\n if t > 0 then\n return true, t\n end\n return false, 'must be positive'\nend"); + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function fallible(t: number): ...any\n if t > 0 then\n return true, t\n end\n return false, 'must be positive'\nend" + ); } TEST_CASE_FIXTURE(ATSFixture, "typepacks_no_ret") @@ -111,7 +134,7 @@ TEST_CASE_FIXTURE(ATSFixture, "typepacks_no_ret") -- TODO: if partially typed, we'd want to know too local function fallible(t: number) if t > 0 then - return true, t + return true, t end return false, "must be positive" end @@ -421,11 +444,18 @@ end )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(3, result1); + LUAU_REQUIRE_ERROR_COUNT(1, result1); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); - LUAU_ASSERT(module->ats.typeInfo.size() == 0); + if (FFlag::LuauSkipNoRefineDuringRefinement) + { + REQUIRE_EQ(module->ats.typeInfo.size(), 1); + CHECK_EQ(module->ats.typeInfo[0].code, Pattern::Assign); + CHECK_EQ(module->ats.typeInfo[0].node, "descendant.CollisionGroup = CAR_COLLISION_GROUP"); + } + else + LUAU_ASSERT(module->ats.typeInfo.size() == 0); } TEST_CASE_FIXTURE(ATSFixture, "unknown_symbol") @@ -561,19 +591,35 @@ initialize() )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(5, result1); + LUAU_REQUIRE_ERROR_COUNT(3, result1); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); - LUAU_ASSERT(module->ats.typeInfo.size() == 11); + if (FFlag::LuauSkipNoRefineDuringRefinement) + CHECK_EQ(module->ats.typeInfo.size(), 12); + else + LUAU_ASSERT(module->ats.typeInfo.size() == 11); LUAU_ASSERT(module->ats.typeInfo[0].code == Pattern::FuncArg); - LUAU_ASSERT( - module->ats.typeInfo[0].node == - "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " - "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " - "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " - "end\nend" - ); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ( + module->ats.typeInfo[0].node, + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart') then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants() do\n if descendant:IsA('BasePart') then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); + } + else + { + LUAU_ASSERT( + module->ats.typeInfo[0].node == + "local function onCharacterAdded(character: Model)\n\n character.DescendantAdded:Connect(function(descendant)\n if " + "descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n end)\n\n\n for _, descendant in " + "character:GetDescendants()do\n if descendant:IsA('BasePart')then\n descendant.CollisionGroup = CHARACTER_COLLISION_GROUP\n end\n " + "end\nend" + ); + } } TEST_CASE_FIXTURE(ATSFixture, "racing_spawning_1") @@ -581,6 +627,9 @@ TEST_CASE_FIXTURE(ATSFixture, "racing_spawning_1") ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, + // Previously we'd report an error because number <: 'a is not a + // supertype. + {FFlag::LuauTrackInteriorFreeTypesOnScope, true} }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -632,7 +681,7 @@ initialize() )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(5, result1); + LUAU_REQUIRE_ERROR_COUNT(4, result1); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); @@ -653,6 +702,7 @@ TEST_CASE_FIXTURE(ATSFixture, "mutually_recursive_generic") ScopedFastFlag sff[] = { {FFlag::LuauSolverV2, true}, {FFlag::StudioReportLuauAny2, true}, + {FFlag::LuauDeferBidirectionalInferenceForTableAssignment, true} }; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -665,8 +715,7 @@ TEST_CASE_FIXTURE(ATSFixture, "mutually_recursive_generic") y.g.i = y )"; - CheckResult result1 = frontend.check("game/Gui/Modules/A"); - LUAU_REQUIRE_ERROR_COUNT(2, result1); + LUAU_REQUIRE_NO_ERRORS(frontend.check("game/Gui/Modules/A")); ModulePtr module = frontend.moduleResolver.getModule("game/Gui/Modules/A"); @@ -911,7 +960,7 @@ TEST_CASE_FIXTURE(ATSFixture, "type_alias_any") fileResolver.source["game/Gui/Modules/A"] = R"( type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; CheckResult result1 = frontend.check("game/Gui/Modules/A"); @@ -938,7 +987,7 @@ TEST_CASE_FIXTURE(ATSFixture, "multi_module_any") fileResolver.source["game/B"] = R"( local MyFunc = require(script.Parent.A) type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; fileResolver.source["game/Gui/Modules/A"] = R"( @@ -972,7 +1021,7 @@ TEST_CASE_FIXTURE(ATSFixture, "cast_on_cyclic_req") fileResolver.source["game/B"] = R"( local MyFunc = require(script.Parent.A) :: any type Clear = any - local z: Clear = "zip" + local z: Clear = "zip" )"; CheckResult result = frontend.check("game/B"); diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index 2cd821b5..b730cb1e 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -10,6 +10,8 @@ using namespace Luau::CodeGen; using namespace Luau::CodeGen::A64; +LUAU_FASTFLAG(LuauVectorLibNativeDot); + static std::string bytecodeAsArray(const std::vector& bytecode) { std::string result = "{"; @@ -387,6 +389,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPBasic") TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") { + ScopedFastFlag sff{FFlag::LuauVectorLibNativeDot, true}; + SINGLE_COMPARE(fabs(d1, d2), 0x1E60C041); SINGLE_COMPARE(fadd(d1, d2, d3), 0x1E632841); SINGLE_COMPARE(fadd(s29, s29, s28), 0x1E3C2BBD); @@ -400,6 +404,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath") SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841); SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD); + SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D); + SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D); + SINGLE_COMPARE(frinta(d1, d2), 0x1E664041); SINGLE_COMPARE(frintm(d1, d2), 0x1E654041); SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041); diff --git a/tests/AssemblyBuilderX64.test.cpp b/tests/AssemblyBuilderX64.test.cpp index 655fa8f1..fd1deccf 100644 --- a/tests/AssemblyBuilderX64.test.cpp +++ b/tests/AssemblyBuilderX64.test.cpp @@ -10,6 +10,8 @@ using namespace Luau::CodeGen; using namespace Luau::CodeGen::X64; +LUAU_FASTFLAG(LuauVectorLibNativeDot); + static std::string bytecodeAsArray(const std::vector& bytecode) { std::string result = "{"; @@ -504,6 +506,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXBinaryInstructionForms") SINGLE_COMPARE(vmaxsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5f, 0xc6); SINGLE_COMPARE(vminsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0x5d, 0xc6); + SINGLE_COMPARE(vcmpeqsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x00); SINGLE_COMPARE(vcmpltsd(xmm8, xmm10, xmm14), 0xc4, 0x41, 0x2b, 0xc2, 0xc6, 0x01); } @@ -568,6 +571,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXConversionInstructionForms") TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") { + ScopedFastFlag sff{FFlag::LuauVectorLibNativeDot, true}; + SINGLE_COMPARE(vroundsd(xmm7, xmm12, xmm3, RoundingModeX64::RoundToNegativeInfinity), 0xc4, 0xe3, 0x19, 0x0b, 0xfb, 0x09); SINGLE_COMPARE( vroundsd(xmm8, xmm13, xmmword[r13 + rdx], RoundingModeX64::RoundToPositiveInfinity), 0xc4, 0x43, 0x11, 0x0b, 0x44, 0x15, 0x00, 0x0a @@ -577,6 +582,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms") SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4); SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02); + + SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02); } TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions") diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 76538cf1..7dff66d7 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -11,6 +11,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauAstTypeGroup2) + struct JsonEncoderFixture { Allocator allocator; @@ -67,7 +69,7 @@ TEST_CASE("encode_constants") charString.data = const_cast("a\x1d\0\\\"b"); charString.size = 6; - AstExprConstantString needsEscaping{Location(), charString}; + AstExprConstantString needsEscaping{Location(), charString, AstExprConstantString::QuotedSimple}; CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); @@ -83,7 +85,7 @@ TEST_CASE("basic_escaping") { std::string s = "hello \"world\""; AstArray theString{s.data(), s.size()}; - AstExprConstantString str{Location(), theString}; + AstExprConstantString str{Location(), theString, AstExprConstantString::QuotedSimple}; std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; CHECK_EQ(expected, toJson(&str)); @@ -138,7 +140,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_array") CHECK( json == - R"({"type":"AstStatBlock","location":"0,0 - 0,17","hasEnd":true,"body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","nameLocation":"0,10 - 0,16","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","nameLocation":"0,10 - 0,16","parameters":[]}}},"exported":false}]})" + R"({"type":"AstStatBlock","location":"0,0 - 0,17","hasEnd":true,"body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","nameLocation":"0,10 - 0,16","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","nameLocation":"0,10 - 0,16","parameters":[]}}},"exported":false}]})" ); } @@ -151,7 +153,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_table_indexer") CHECK( json == - R"({"type":"AstStatBlock","location":"0,0 - 0,17","hasEnd":true,"body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","nameLocation":"0,10 - 0,16","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","nameLocation":"0,10 - 0,16","parameters":[]}}},"exported":false}]})" + R"({"type":"AstStatBlock","location":"0,0 - 0,17","hasEnd":true,"body":[{"type":"AstStatTypeAlias","location":"0,0 - 0,17","name":"X","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,9 - 0,17","props":[],"indexer":{"location":"0,10 - 0,16","indexType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"number","nameLocation":"0,10 - 0,16","parameters":[]},"resultType":{"type":"AstTypeReference","location":"0,10 - 0,16","name":"string","nameLocation":"0,10 - 0,16","parameters":[]}}},"exported":false}]})" ); } @@ -250,7 +252,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprFunction") AstExpr* expr = expectParseExpr("function (a) return a end"); std::string_view expected = - R"({"type":"AstExprFunction","location":"0,4 - 0,29","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":""})"; + R"({"type":"AstExprFunction","location":"0,4 - 0,29","attributes":[],"generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,16 - 0,26","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,17 - 0,25","list":[{"type":"AstExprLocal","location":"0,24 - 0,25","local":{"luauType":null,"name":"a","type":"AstLocal","location":"0,14 - 0,15"}}]}]},"functionDepth":1,"debugname":""})"; CHECK(toJson(expr) == expected); } @@ -398,7 +400,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatLocalFunction") AstStat* statement = expectParseStatement("local function a(b) return end"); std::string_view expected = - R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"luauType":null,"name":"a","type":"AstLocal","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a"}})"; + R"({"type":"AstStatLocalFunction","location":"0,0 - 0,30","name":{"luauType":null,"name":"a","type":"AstLocal","location":"0,15 - 0,16"},"func":{"type":"AstExprFunction","location":"0,0 - 0,30","attributes":[],"generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,17 - 0,18"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,19 - 0,27","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,20 - 0,26","list":[]}]},"functionDepth":1,"debugname":"a"}})"; CHECK(toJson(statement) == expected); } @@ -408,7 +410,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatTypeAlias") AstStat* statement = expectParseStatement("type A = B"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"type":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","nameLocation":"0,9 - 0,10","parameters":[]},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,10","name":"A","generics":[],"genericPacks":[],"value":{"type":"AstTypeReference","location":"0,9 - 0,10","name":"B","nameLocation":"0,9 - 0,10","parameters":[]},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -418,7 +420,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction") AstStat* statement = expectParseStatement("declare function foo(x: number): string"); std::string_view expected = - R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}]},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":false,"varargLocation":"0,0 - 0,0","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,39","attributes":[],"name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}]},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":false,"varargLocation":"0,0 - 0,0","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,33 - 0,39","name":"string","nameLocation":"0,33 - 0,39","parameters":[]}]},"generics":[],"genericPacks":[]})"; CHECK(toJson(statement) == expected); } @@ -428,11 +430,21 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareFunction2") AstStat* statement = expectParseStatement("declare function foo(x: number, ...: string): string"); std::string_view expected = - R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,52","name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}],"tailType":{"type":"AstTypePackVariadic","location":"0,37 - 0,43","variadicType":{"type":"AstTypeReference","location":"0,37 - 0,43","name":"string","nameLocation":"0,37 - 0,43","parameters":[]}}},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":true,"varargLocation":"0,32 - 0,35","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,46 - 0,52","name":"string","nameLocation":"0,46 - 0,52","parameters":[]}]},"generics":[],"genericPacks":[]})"; + R"({"type":"AstStatDeclareFunction","location":"0,0 - 0,52","attributes":[],"name":"foo","nameLocation":"0,17 - 0,20","params":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,24 - 0,30","name":"number","nameLocation":"0,24 - 0,30","parameters":[]}],"tailType":{"type":"AstTypePackVariadic","location":"0,37 - 0,43","variadicType":{"type":"AstTypeReference","location":"0,37 - 0,43","name":"string","nameLocation":"0,37 - 0,43","parameters":[]}}},"paramNames":[{"type":"AstArgumentName","name":"x","location":"0,21 - 0,22"}],"vararg":true,"varargLocation":"0,32 - 0,35","retTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,46 - 0,52","name":"string","nameLocation":"0,46 - 0,52","parameters":[]}]},"generics":[],"genericPacks":[]})"; CHECK(toJson(statement) == expected); } +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstAttr") +{ + AstStat* expr = expectParseStatement("@checked function a(b) return c end"); + + std::string_view expected = + R"({"type":"AstStatFunction","location":"0,9 - 0,35","name":{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"a"},"func":{"type":"AstExprFunction","location":"0,9 - 0,35","attributes":[{"type":"AstAttr","location":"0,0 - 0,8","name":"checked"}],"generics":[],"genericPacks":[],"args":[{"luauType":null,"name":"b","type":"AstLocal","location":"0,20 - 0,21"}],"vararg":false,"varargLocation":"0,0 - 0,0","body":{"type":"AstStatBlock","location":"0,22 - 0,32","hasEnd":true,"body":[{"type":"AstStatReturn","location":"0,23 - 0,31","list":[{"type":"AstExprGlobal","location":"0,30 - 0,31","global":"c"}]}]},"functionDepth":1,"debugname":"a"}})"; + + CHECK(toJson(expr) == expected); +} + TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") { AstStatBlock* root = expectParse(R"( @@ -449,7 +461,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstStatDeclareClass") REQUIRE(2 == root->body.size); std::string_view expected1 = - R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","nameLocation":"2,12 - 2,16","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]},"location":"2,12 - 2,24"},{"name":"method","nameLocation":"3,21 - 3,27","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,12 - 3,54","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}},"location":"3,12 - 3,54"}],"indexer":null})"; + R"({"type":"AstStatDeclareClass","location":"1,22 - 4,11","name":"Foo","props":[{"name":"prop","nameLocation":"2,12 - 2,16","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeReference","location":"2,18 - 2,24","name":"number","nameLocation":"2,18 - 2,24","parameters":[]},"location":"2,12 - 2,24"},{"name":"method","nameLocation":"3,21 - 3,27","type":"AstDeclaredClassProp","luauType":{"type":"AstTypeFunction","location":"3,12 - 3,54","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,39 - 3,45","name":"number","nameLocation":"3,39 - 3,45","parameters":[]}]},"argNames":[{"type":"AstArgumentName","name":"foo","location":"3,34 - 3,37"}],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"3,48 - 3,54","name":"string","nameLocation":"3,48 - 3,54","parameters":[]}]}},"location":"3,12 - 3,54"}],"indexer":null})"; CHECK(toJson(root->body.data[0]) == expected1); std::string_view expected2 = @@ -461,10 +473,18 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_annotation") { AstStat* statement = expectParseStatement("type T = ((number) -> (string | nil)) & ((string) -> ())"); - std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; - - CHECK(toJson(statement) == expected); + if (FFlag::LuauAstTypeGroup2) + { + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,56","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,56","types":[{"type":"AstTypeGroup","location":"0,9 - 0,37","inner":{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeGroup","location":"0,22 - 0,36","inner":{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}}]}}},{"type":"AstTypeGroup","location":"0,40 - 0,56","inner":{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}}]},"exported":false})"; + CHECK(toJson(statement) == expected); + } + else + { + std::string_view expected = + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,55","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeIntersection","location":"0,9 - 0,55","types":[{"type":"AstTypeFunction","location":"0,10 - 0,36","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,11 - 0,17","name":"number","nameLocation":"0,11 - 0,17","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[{"type":"AstTypeUnion","location":"0,23 - 0,35","types":[{"type":"AstTypeReference","location":"0,23 - 0,29","name":"string","nameLocation":"0,23 - 0,29","parameters":[]},{"type":"AstTypeReference","location":"0,32 - 0,35","name":"nil","nameLocation":"0,32 - 0,35","parameters":[]}]}]}},{"type":"AstTypeFunction","location":"0,41 - 0,55","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,42 - 0,48","name":"string","nameLocation":"0,42 - 0,48","parameters":[]}]},"argNames":[],"returnTypes":{"type":"AstTypeList","types":[]}}]},"exported":false})"; + CHECK(toJson(statement) == expected); + } } TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_type_literal") @@ -474,7 +494,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_type_literal") auto json = toJson(statement); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,73","name":"Action","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,14 - 0,73","props":[{"name":"strings","type":"AstTableProp","location":"0,16 - 0,23","propType":{"type":"AstTypeUnion","location":"0,25 - 0,40","types":[{"type":"AstTypeSingletonString","location":"0,25 - 0,28","value":"A"},{"type":"AstTypeSingletonString","location":"0,31 - 0,34","value":"B"},{"type":"AstTypeSingletonString","location":"0,37 - 0,40","value":"C"}]}},{"name":"mixed","type":"AstTableProp","location":"0,42 - 0,47","propType":{"type":"AstTypeUnion","location":"0,49 - 0,71","types":[{"type":"AstTypeSingletonString","location":"0,49 - 0,55","value":"This"},{"type":"AstTypeSingletonString","location":"0,58 - 0,64","value":"That"},{"type":"AstTypeSingletonBool","location":"0,67 - 0,71","value":true}]}}],"indexer":null},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,73","name":"Action","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,14 - 0,73","props":[{"name":"strings","type":"AstTableProp","location":"0,16 - 0,23","propType":{"type":"AstTypeUnion","location":"0,25 - 0,40","types":[{"type":"AstTypeSingletonString","location":"0,25 - 0,28","value":"A"},{"type":"AstTypeSingletonString","location":"0,31 - 0,34","value":"B"},{"type":"AstTypeSingletonString","location":"0,37 - 0,40","value":"C"}]}},{"name":"mixed","type":"AstTableProp","location":"0,42 - 0,47","propType":{"type":"AstTypeUnion","location":"0,49 - 0,71","types":[{"type":"AstTypeSingletonString","location":"0,49 - 0,55","value":"This"},{"type":"AstTypeSingletonString","location":"0,58 - 0,64","value":"That"},{"type":"AstTypeSingletonBool","location":"0,67 - 0,71","value":true}]}}],"indexer":null},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -484,7 +504,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_indexed_type_literal") AstStat* statement = expectParseStatement(R"(type StringSet = { [string]: true })"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,35","name":"StringSet","generics":[],"genericPacks":[],"type":{"type":"AstTypeTable","location":"0,17 - 0,35","props":[],"indexer":{"location":"0,19 - 0,33","indexType":{"type":"AstTypeReference","location":"0,20 - 0,26","name":"string","nameLocation":"0,20 - 0,26","parameters":[]},"resultType":{"type":"AstTypeSingletonBool","location":"0,29 - 0,33","value":true}}},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,35","name":"StringSet","generics":[],"genericPacks":[],"value":{"type":"AstTypeTable","location":"0,17 - 0,35","props":[],"indexer":{"location":"0,19 - 0,33","indexType":{"type":"AstTypeReference","location":"0,20 - 0,26","name":"string","nameLocation":"0,20 - 0,26","parameters":[]},"resultType":{"type":"AstTypeSingletonBool","location":"0,29 - 0,33","value":true}}},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -494,7 +514,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeFunction") AstStat* statement = expectParseStatement(R"(type fun = (string, bool, named: number) -> ())"); std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,46","name":"fun","generics":[],"genericPacks":[],"type":{"type":"AstTypeFunction","location":"0,11 - 0,46","generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,12 - 0,18","name":"string","nameLocation":"0,12 - 0,18","parameters":[]},{"type":"AstTypeReference","location":"0,20 - 0,24","name":"bool","nameLocation":"0,20 - 0,24","parameters":[]},{"type":"AstTypeReference","location":"0,33 - 0,39","name":"number","nameLocation":"0,33 - 0,39","parameters":[]}]},"argNames":[null,null,{"type":"AstArgumentName","name":"named","location":"0,26 - 0,31"}],"returnTypes":{"type":"AstTypeList","types":[]}},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,46","name":"fun","generics":[],"genericPacks":[],"value":{"type":"AstTypeFunction","location":"0,11 - 0,46","attributes":[],"generics":[],"genericPacks":[],"argTypes":{"type":"AstTypeList","types":[{"type":"AstTypeReference","location":"0,12 - 0,18","name":"string","nameLocation":"0,12 - 0,18","parameters":[]},{"type":"AstTypeReference","location":"0,20 - 0,24","name":"bool","nameLocation":"0,20 - 0,24","parameters":[]},{"type":"AstTypeReference","location":"0,33 - 0,39","name":"number","nameLocation":"0,33 - 0,39","parameters":[]}]},"argNames":[null,null,{"type":"AstArgumentName","name":"named","location":"0,26 - 0,31"}],"returnTypes":{"type":"AstTypeList","types":[]}},"exported":false})"; CHECK(toJson(statement) == expected); } @@ -507,7 +527,7 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstTypeError") AstStat* statement = parseResult.root->body.data[0]; std::string_view expected = - R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"type":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; + R"({"type":"AstStatTypeAlias","location":"0,0 - 0,9","name":"T","generics":[],"genericPacks":[],"value":{"type":"AstTypeError","location":"0,8 - 0,9","types":[],"messageIndex":0},"exported":false})"; CHECK(toJson(statement) == expected); } diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 6822ce6d..702be46b 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -163,13 +163,49 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "table_overloaded_function_prop") CHECK_EQ(symbol, "@test/global/Foo.new/overload/(string) -> number"); } +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "string_metatable_method") +{ + std::optional symbol = getDocSymbol( + R"( + local x: string = "Foo" + x:rep(2) + )", + Position(2, 12) + ); + + CHECK_EQ(symbol, "@luau/global/string.rep"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "parent_class_method") +{ + loadDefinition(R"( + declare class Foo + function bar(self, x: string): number + end + + declare class Bar extends Foo + function notbar(self, x: string): number + end + )"); + + std::optional symbol = getDocSymbol( + R"( + local x: Bar = Bar.new() + x:bar("asdf") + )", + Position(2, 11) + ); + + CHECK_EQ(symbol, "@test/globaltype/Foo.bar"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("AstQuery"); TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); check(R"( local function foo() return 2 end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 7f020b18..6a8bca05 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTINT(LuauTypeInferRecursionLimit) using namespace Luau; @@ -27,7 +28,7 @@ template struct ACFixtureImpl : BaseType { ACFixtureImpl() - : BaseType(true, true) + : BaseType(true) { } @@ -150,40 +151,6 @@ struct ACBuiltinsFixture : ACFixtureImpl { }; -#define LUAU_CHECK_HAS_KEY(map, key) \ - do \ - { \ - auto&& _m = (map); \ - auto&& _k = (key); \ - const size_t count = _m.count(_k); \ - CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ - if (!count) \ - { \ - MESSAGE("Keys: (count " << _m.size() << ")"); \ - for (const auto& [k, v] : _m) \ - { \ - MESSAGE("\tkey: " << k); \ - } \ - } \ - } while (false) - -#define LUAU_CHECK_HAS_NO_KEY(map, key) \ - do \ - { \ - auto&& _m = (map); \ - auto&& _k = (key); \ - const size_t count = _m.count(_k); \ - CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ - if (count) \ - { \ - MESSAGE("Keys: (count " << _m.size() << ")"); \ - for (const auto& [k, v] : _m) \ - { \ - MESSAGE("\tkey: " << k); \ - } \ - } \ - } while (false) - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -2265,7 +2232,7 @@ local ec = e(f@5) TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_for_overloads") { if (FFlag::LuauSolverV2) // CLI-116814 Autocomplete needs to populate expected types for function arguments correctly - // (overloads and singletons) + // (overloads and singletons) return; check(R"( local target: ((number) -> string) & ((string) -> number)) @@ -2615,7 +2582,7 @@ end TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { if (FFlag::LuauSolverV2) // CLI-116812 AutocompleteTest.suggest_table_keys needs to populate expected types for nested - // tables without an annotation + // tables without an annotation return; check(R"( @@ -3102,7 +3069,7 @@ TEST_CASE_FIXTURE(ACBuiltinsFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singletons") { if (FFlag::LuauSolverV2) // CLI-116814 Autocomplete needs to populate expected types for function arguments correctly - // (overloads and singletons) + // (overloads and singletons) return; check(R"( @@ -3815,6 +3782,39 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_response_perf1" * doctest::timeout(0. CHECK(ac.entryMap.count("Instance")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_subtyping_recursion_limit") +{ + // TODO: in old solver, type resolve can't handle the type in this test without a stack overflow + if (!FFlag::LuauSolverV2) + return; + + ScopedFastInt luauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 10}; + + const int parts = 100; + std::string source; + + source += "function f()\n"; + + std::string prefix; + for (int i = 0; i < parts; i++) + formatAppend(prefix, "(nil|({a%d:number}&", i); + formatAppend(prefix, "(nil|{a%d:number})", parts); + for (int i = 0; i < parts; i++) + formatAppend(prefix, "))"); + + source += "local x1 : " + prefix + "\n"; + source += "local y : {a1:number} = x@1\n"; + + source += "end\n"; + + check(source); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("x1")); +} + TEST_CASE_FIXTURE(ACFixture, "strict_mode_force") { check(R"( @@ -4293,8 +4293,7 @@ end foo(@1) )"); - const std::optional EXPECTED_INSERT = - FFlag::LuauSolverV2 ? "function(...: number): number end" : "function(...): number end"; + const std::optional EXPECTED_INSERT = FFlag::LuauSolverV2 ? "function(...: number): number end" : "function(...): number end"; auto ac = autocomplete('1'); @@ -4305,4 +4304,46 @@ foo(@1) CHECK_EQ(EXPECTED_INSERT, *ac.entryMap[kGeneratedAnonymousFunctionEntryName].insertText); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_at_end_of_stmt_should_continue_as_part_of_stmt") +{ + check(R"( +local data = { x = 1 } +local var = data.@1 + )"); + auto ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("x")); + CHECK_EQ(ac.context, AutocompleteContext::Property); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_after_semicolon_should_complete_a_new_statement") +{ + check(R"( +local data = { x = 1 } +local var = data;@1 + )"); + auto ac = autocomplete('1'); + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); +} + +TEST_CASE_FIXTURE(ACBuiltinsFixture, "require_tracing") +{ + fileResolver.source["Module/A"] = R"( +return { x = 0 } + )"; + + fileResolver.source["Module/B"] = R"( +local result = require(script.Parent.A) +local x = 1 + result. + )"; + + auto ac = autocomplete("Module/B", Position{2, 21}); + + CHECK(ac.entryMap.size() == 1); + CHECK(ac.entryMap.count("x")); +} + TEST_SUITE_END(); diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index a9bf9596..40d06c85 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -132,6 +132,15 @@ ClassFixture::ClassFixture() // IndexableNumericKeyClass has a table indexer with a key type of 'number' and a return type of 'number' addIndexableClass("IndexableNumericKeyClass", numberType, numberType); + // Add a confusing derived class which shares the same name internally, but has a unique alias + TypeId duplicateBaseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test", {}}); + + getMutable(duplicateBaseClassInstanceType)->props = { + {"Method", {makeFunction(arena, duplicateBaseClassInstanceType, {}, {stringType})}}, + }; + + addGlobalBinding(globals, "confusingBaseClassInstance", duplicateBaseClassInstanceType, "@test"); + for (const auto& [name, tf] : globals.globalScope->exportedTypeBindings) persist(tf.type); diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index 058a1100..c236b49c 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -3,6 +3,7 @@ #include "Luau/AssemblyBuilderA64.h" #include "Luau/CodeAllocator.h" #include "Luau/CodeBlockUnwind.h" +#include "Luau/CodeGen.h" #include "Luau/UnwindBuilder.h" #include "Luau/UnwindBuilderDwarf2.h" #include "Luau/UnwindBuilderWin.h" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 48bd45d7..56ba1a4f 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -3,6 +3,8 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/StringUtils.h" +#include "luacode.h" + #include "ScopedFlags.h" #include "doctest.h" @@ -24,17 +26,61 @@ LUAU_FASTINT(LuauRecursionLimit) using namespace Luau; -static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, bool enableVectors = false) +static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) +{ + // While 'vector' is built-in, because of LUA_VECTOR_SIZE VM configuration, compiler cannot provide the right default by itself + if (strcmp(library, "vector") == 0) + { + if (strcmp(member, "zero") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + } + + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + + if (strcmp(member, "xAxis") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 0.0f, 0.0f, 0.0f); + } + + if (strcmp(library, "test") == 0) + { + if (strcmp(member, "some_nil") == 0) + return Luau::setCompileConstantNil(constant); + + if (strcmp(member, "some_boolean") == 0) + return Luau::setCompileConstantBoolean(constant, true); + + if (strcmp(member, "some_number") == 0) + return Luau::setCompileConstantNumber(constant, 4.75); + + if (strcmp(member, "some_string") == 0) + return Luau::setCompileConstantString(constant, "test", 4); + } +} + +static std::string compileFunction(const char* source, uint32_t id, int optimizationLevel = 1, int typeInfoLevel = 0, bool enableVectors = false) { Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::CompileOptions options; options.optimizationLevel = optimizationLevel; + options.typeInfoLevel = typeInfoLevel; if (enableVectors) { options.vectorLib = "Vector3"; options.vectorCtor = "new"; } + + static const char* kLibrariesWithConstants[] = {"vector", "Vector3", "test", nullptr}; + options.librariesWithKnownMembers = kLibrariesWithConstants; + + options.libraryMemberConstantCb = luauLibraryConstantLookup; + Luau::compileOrThrow(bcb, source, options); return bcb.dumpFunction(id); @@ -93,6 +139,8 @@ TEST_CASE("BytecodeIsStable") // Note: these aren't strictly bound to specific bytecode versions, but must monotonically increase to keep backwards compat CHECK(LBF_VECTOR == 54); CHECK(LBF_TOSTRING == 63); + CHECK(LBF_BUFFER_WRITEF64 == 77); + CHECK(LBF_VECTOR_MAX == 88); // Bytecode capture type (serialized & in-memory) CHECK(LCT_UPVAL == 2); // bytecode v1 @@ -1436,6 +1484,125 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldVectorArith") +{ + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a + b", 0, 2), R"( +LOADK R0 K0 [3, 6, 11] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a - b", 0, 2), R"( +LOADK R0 K0 [-1, -2, -5] +RETURN R0 1 +)"); + + // Multiplication by infinity cannot be folded as it creates a non-zero value in W + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a * n, a * b, n * b, a * math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [2, 4, 6] +LOADK R1 K1 [2, 8, 24] +LOADK R2 K2 [4, 8, 16] +LOADK R4 K4 [1, 2, 3] +MULK R3 R4 K3 [inf] +RETURN R0 4 +)" + ); + + // Divisions creating an infinity in W cannot be constant-folded + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a / n, a / b, n / b, a / math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [0.5, 1, 1.5] +LOADK R2 K1 [1, 2, 3] +LOADK R3 K2 [2, 4, 8] +DIV R1 R2 R3 +LOADK R3 K2 [2, 4, 8] +DIVRK R2 K3 [2] R3 +LOADK R3 K4 [0, 0, 0] +RETURN R0 4 +)" + ); + + // Divisions creating an infinity in W cannot be constant-folded + CHECK_EQ( + "\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3), vector.create(2, 4, 8); return a // n, a // b, n // b", 0, 2), + R"( +LOADK R0 K0 [0, 1, 1] +LOADK R2 K1 [1, 2, 3] +LOADK R3 K2 [2, 4, 8] +IDIV R1 R2 R3 +LOADN R3 2 +LOADK R4 K2 [2, 4, 8] +IDIV R2 R3 R4 +RETURN R0 3 +)" + ); + + CHECK_EQ("\n" + compileFunction("local a = vector.create(1, 2, 3); return -a", 0, 2), R"( +LOADK R0 K0 [-1, -2, -3] +RETURN R0 1 +)"); +} + +TEST_CASE("ConstantFoldVectorArith4Wide") +{ + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a + b", 0, 2), R"( +LOADK R0 K0 [3, 6, 11, 5] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a - b", 0, 2), R"( +LOADK R0 K0 [-1, -2, -5, 3] +RETURN R0 1 +)"); + + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a * n, a * b, n * b, a * math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [2, 4, 6, 8] +LOADK R1 K1 [2, 8, 24, 4] +LOADK R2 K2 [4, 8, 16, 2] +LOADK R3 K3 [inf, inf, inf, inf] +RETURN R0 4 +)" + ); + + CHECK_EQ( + "\n" + compileFunction( + "local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a / n, a / b, n / b, a / math.huge", 0, 2 + ), + R"( +LOADK R0 K0 [0.5, 1, 1.5, 2] +LOADK R1 K1 [0.5, 0.5, 0.375, 4] +LOADK R2 K2 [1, 0.5, 0.25, 2] +LOADK R3 K3 [0, 0, 0] +RETURN R0 4 +)" + ); + + CHECK_EQ( + "\n" + compileFunction("local n = 2; local a, b = vector.create(1, 2, 3, 4), vector.create(2, 4, 8, 1); return a // n, a // b, n // b", 0, 2), + R"( +LOADK R0 K0 [0, 1, 1, 2] +LOADK R1 K1 [0, 0, 0, 4] +LOADK R2 K2 [1, 0, 0, 2] +RETURN R0 3 +)" + ); + + CHECK_EQ("\n" + compileFunction("local a = vector.create(1, 2, 3, 4); return -a", 0, 2), R"( +LOADK R0 K0 [-1, -2, -3, -4] +RETURN R0 1 +)"); +} + TEST_CASE("ConstantFoldStringLen") { CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( @@ -2796,6 +2963,14 @@ TEST_CASE("TypeAliasing") CHECK_NOTHROW(Luau::compileOrThrow(bcb, "type A = number local a: A = 1", options, parseOptions)); } +TEST_CASE("TypeFunction") +{ + Luau::BytecodeBuilder bcb; + Luau::CompileOptions options; + Luau::ParseOptions parseOptions; + CHECK_NOTHROW(Luau::compileOrThrow(bcb, "type function a() return types.any end", options, parseOptions)); +} + TEST_CASE("DebugLineInfo") { Luau::BytecodeBuilder bcb; @@ -4915,36 +5090,78 @@ L0: RETURN R3 -1 )"); } -TEST_CASE("VectorLiterals") +TEST_CASE("VectorConstants") { - CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return vector.create(1, 2)", 0, 2, 0), R"( +LOADK R0 K0 [1, 2, 0] +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("return vector.create(1, 2, 3)", 0, 2, 0), R"( LOADK R0 K0 [1, 2, 3] RETURN R0 1 )"); - CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3))", 0, 2, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3))", 0, 2, 0), R"( GETIMPORT R0 1 [print] LOADK R1 K2 [1, 2, 3] CALL R0 1 0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction("print(Vector3.new(1, 2, 3, 4))", 0, 2, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("print(vector.create(1, 2, 3, 4))", 0, 2, 0), R"( GETIMPORT R0 1 [print] LOADK R1 K2 [1, 2, 3, 4] CALL R0 1 0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction("return Vector3.new(0, 0, 0), Vector3.new(-0, 0, 0)", 0, 2, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return vector.create(0, 0, 0), vector.create(-0, 0, 0)", 0, 2, 0), R"( LOADK R0 K0 [0, 0, 0] LOADK R1 K1 [-0, 0, 0] RETURN R0 2 )"); - CHECK_EQ("\n" + compileFunction("return type(Vector3.new(0, 0, 0))", 0, 2, /*enableVectors*/ true), R"( + CHECK_EQ("\n" + compileFunction("return type(vector.create(0, 0, 0))", 0, 2, 0), R"( LOADK R0 K0 ['vector'] RETURN R0 1 +)"); + + // test legacy constructor + CHECK_EQ("\n" + compileFunction("return Vector3.new(1, 2, 3)", 0, 2, 0, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 2, 3] +RETURN R0 1 +)"); +} + +TEST_CASE("VectorConstantFields") +{ + CHECK_EQ("\n" + compileFunction("return vector.one, vector.zero", 0, 2), R"( +LOADK R0 K0 [1, 1, 1] +LOADK R1 K1 [0, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return Vector3.one, Vector3.xAxis", 0, 2, 0, /*enableVectors*/ true), R"( +LOADK R0 K0 [1, 1, 1] +LOADK R1 K1 [1, 0, 0] +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction("return vector.one == vector.create(1, 1, 1)", 0, 2), R"( +LOADB R0 1 +RETURN R0 1 +)"); +} + +TEST_CASE("CustomConstantFields") +{ + CHECK_EQ("\n" + compileFunction("return test.some_nil, test.some_boolean, test.some_number, test.some_string", 0, 2), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 [4.75] +LOADK R3 K1 ['test'] +RETURN R0 4 )"); } @@ -7670,6 +7887,39 @@ RETURN R0 1 ); } +TEST_CASE("BuiltinFoldingProhibitedInOptions") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 2; + + // math.floor from the test is excluded in this list on purpose + static const char* kDisabledBuiltins[] = {"tostring", "math.abs", "math.sqrt", nullptr}; + options.disabledBuiltins = kDisabledBuiltins; + + Luau::compileOrThrow(bcb, "return math.abs(-42), math.floor(-1.5), math.sqrt(9), (tostring(2))", options); + + std::string result = bcb.dumpFunction(0); + + CHECK_EQ( + "\n" + result, + R"( +GETIMPORT R0 2 [math.abs] +LOADN R1 -42 +CALL R0 1 1 +LOADN R1 -2 +GETIMPORT R2 4 [math.sqrt] +LOADN R3 9 +CALL R2 1 1 +GETIMPORT R3 6 [tostring] +LOADN R4 2 +CALL R3 1 1 +RETURN R0 4 +)" + ); +} + TEST_CASE("LocalReassign") { // locals can be re-assigned and the register gets reused @@ -8406,6 +8656,19 @@ end ); } +TEST_CASE("BuiltinTypeVector") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: Instance, pos: vector) +end +)"), + R"( +0: function(userdata, vector) +)" + ); +} + TEST_CASE("TypeAliasScoping") { CHECK_EQ( @@ -8491,6 +8754,23 @@ end ); } +TEST_CASE("TypeGroup") +{ + CHECK_EQ( + "\n" + compileTypeTable(R"( +function myfunc(test: (string), foo: nil) +end + +function myfunc2(test: (string | nil), foo: nil) +end +)"), + R"( +0: function(string, nil) +1: function(string?, nil) +)" + ); +} + TEST_CASE("BuiltinFoldMathK") { // we can fold math.pi at optimization level 2 @@ -8816,8 +9096,7 @@ RETURN R0 1 TEST_CASE("ArithRevK") { - // - and / have special optimized form for reverse constants; in the future, + and * will likely get compiled to ADDK/MULK - // other operators are not important enough to optimize reverse constant forms for + // - and / have special optimized form for reverse constants; in absence of type information, we can't optimize other ops CHECK_EQ( "\n" + compileFunction0(R"( local x: number = unknown @@ -8838,6 +9117,34 @@ IDIV R6 R7 R0 LOADN R8 2 POW R7 R8 R0 RETURN R1 7 +)" + ); + + // the same code with type information can optimize commutative operators (+ and *) as well + // other operators are not important enough to optimize reverse constant forms for + CHECK_EQ( + "\n" + compileFunction( + R"( +local x: number = unknown +return 2 + x, 2 - x, 2 * x, 2 / x, 2 % x, 2 // x, 2 ^ x +)", + 0, + 2, + 1 + ), + R"( +GETIMPORT R0 1 [unknown] +ADDK R1 R0 K2 [2] +SUBRK R2 K2 [2] R0 +MULK R3 R0 K2 [2] +DIVRK R4 K2 [2] R0 +LOADN R6 2 +MOD R5 R6 R0 +LOADN R7 2 +IDIV R6 R7 R0 +LOADN R8 2 +POW R7 R8 R0 +RETURN R1 7 )" ); } diff --git a/tests/Config.test.cpp b/tests/Config.test.cpp index 70d6d6d7..690c4c37 100644 --- a/tests/Config.test.cpp +++ b/tests/Config.test.cpp @@ -58,7 +58,11 @@ TEST_CASE("report_a_syntax_error") TEST_CASE("noinfer_is_still_allowed") { Config config; - auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true); + + ConfigOptions opts; + opts.compat = true; + + auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, opts); REQUIRE(!err); CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode)); @@ -147,6 +151,10 @@ TEST_CASE("extra_globals") TEST_CASE("lint_rules_compat") { Config config; + + ConfigOptions opts; + opts.compat = true; + auto err = parseConfig( R"( {"lint": { @@ -156,7 +164,7 @@ TEST_CASE("lint_rules_compat") }} )", config, - true + opts ); REQUIRE(!err); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 376caa44..073fcf35 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -31,11 +31,12 @@ extern int optimizationLevel; void luaC_fullgc(lua_State* L); void luaC_validate(lua_State* L); +LUAU_FASTFLAG(LuauLibWhereErrorAutoreserve) LUAU_FASTFLAG(DebugLuauAbortingChecks) LUAU_FASTINT(CodegenHeuristicsInstructionLimit) -LUAU_FASTFLAG(LuauNativeAttribute) -LUAU_FASTFLAG(LuauPreserveLudataRenaming) -LUAU_FASTFLAG(LuauCodegenArmNumToVecFix) +LUAU_DYNAMIC_FASTFLAG(LuauStackLimit) +LUAU_FASTFLAG(LuauVectorLibNativeDot) +LUAU_DYNAMIC_FASTFLAG(LuauStringFormatFixC) static lua_CompileOptions defaultOptions() { @@ -639,28 +640,28 @@ TEST_CASE("CodegenSupported") TEST_CASE("Assert") { - runConformance("assert.lua"); + runConformance("assert.luau"); } TEST_CASE("Basic") { - runConformance("basic.lua"); + runConformance("basic.luau"); } TEST_CASE("Buffers") { - runConformance("buffers.lua"); + runConformance("buffers.luau"); } TEST_CASE("Math") { - runConformance("math.lua"); + runConformance("math.luau"); } TEST_CASE("Tables") { runConformance( - "tables.lua", + "tables.luau", [](lua_State* L) { lua_pushcfunction( @@ -689,97 +690,101 @@ TEST_CASE("Tables") TEST_CASE("PatternMatch") { - runConformance("pm.lua"); + runConformance("pm.luau"); } TEST_CASE("Sort") { - runConformance("sort.lua"); + runConformance("sort.luau"); } TEST_CASE("Move") { - runConformance("move.lua"); + runConformance("move.luau"); } TEST_CASE("Clear") { - runConformance("clear.lua"); + runConformance("clear.luau"); } TEST_CASE("Strings") { - runConformance("strings.lua"); + ScopedFastFlag luauStringFormatFixC{DFFlag::LuauStringFormatFixC, true}; + + runConformance("strings.luau"); } TEST_CASE("StringInterp") { - runConformance("stringinterp.lua"); + runConformance("stringinterp.luau"); } TEST_CASE("VarArg") { - runConformance("vararg.lua"); + runConformance("vararg.luau"); } TEST_CASE("Locals") { - runConformance("locals.lua"); + runConformance("locals.luau"); } TEST_CASE("Literals") { - runConformance("literals.lua"); + runConformance("literals.luau"); } TEST_CASE("Errors") { - runConformance("errors.lua"); + runConformance("errors.luau"); } TEST_CASE("Events") { - runConformance("events.lua"); + runConformance("events.luau"); } TEST_CASE("Constructs") { - runConformance("constructs.lua"); + runConformance("constructs.luau"); } TEST_CASE("Closure") { - runConformance("closure.lua"); + runConformance("closure.luau"); } TEST_CASE("Calls") { - runConformance("calls.lua"); + ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; + + runConformance("calls.luau"); } TEST_CASE("Attrib") { - runConformance("attrib.lua"); + runConformance("attrib.luau"); } TEST_CASE("GC") { - runConformance("gc.lua"); + runConformance("gc.luau"); } TEST_CASE("Bitwise") { - runConformance("bitwise.lua"); + runConformance("bitwise.luau"); } TEST_CASE("UTF8") { - runConformance("utf8.lua"); + runConformance("utf8.luau"); } TEST_CASE("Coroutine") { - runConformance("coroutine.lua"); + runConformance("coroutine.luau"); } static int cxxthrow(lua_State* L) @@ -793,8 +798,10 @@ static int cxxthrow(lua_State* L) TEST_CASE("PCall") { + ScopedFastFlag LuauStackLimit{DFFlag::LuauStackLimit, true}; + runConformance( - "pcall.lua", + "pcall.luau", [](lua_State* L) { lua_pushcfunction(L, cxxthrow, "cxxthrow"); @@ -820,13 +827,11 @@ TEST_CASE("PCall") TEST_CASE("Pack") { - runConformance("tpack.lua"); + runConformance("tpack.luau"); } TEST_CASE("Vector") { - ScopedFastFlag luauCodegenArmNumToVecFix{FFlag::LuauCodegenArmNumToVecFix, true}; - lua_CompileOptions copts = defaultOptions(); Luau::CodeGen::CompilationOptions nativeOpts = defaultCodegenOptions(); @@ -867,7 +872,7 @@ TEST_CASE("Vector") } runConformance( - "vector.lua", + "vector.luau", [](lua_State* L) { setupVectorHelpers(L); @@ -880,6 +885,28 @@ TEST_CASE("Vector") ); } +TEST_CASE("VectorLibrary") +{ + ScopedFastFlag luauVectorLibNativeDot{FFlag::LuauVectorLibNativeDot, true}; + + lua_CompileOptions copts = defaultOptions(); + + SUBCASE("O0") + { + copts.optimizationLevel = 0; + } + SUBCASE("O1") + { + copts.optimizationLevel = 1; + } + SUBCASE("O2") + { + copts.optimizationLevel = 2; + } + + runConformance("vector_library.luau", [](lua_State* L) {}, nullptr, nullptr, &copts); +} + static void populateRTTI(lua_State* L, Luau::TypeId type) { if (auto p = Luau::get(type)) @@ -939,6 +966,10 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) lua_pushstring(L, "function"); } + else if (auto c = Luau::get(type)) + { + lua_pushstring(L, c->name.c_str()); + } else { LUAU_ASSERT(!"Unknown type"); @@ -948,7 +979,7 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { runConformance( - "types.lua", + "types.luau", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; @@ -973,12 +1004,12 @@ TEST_CASE("Types") TEST_CASE("DateTime") { - runConformance("datetime.lua"); + runConformance("datetime.luau"); } TEST_CASE("Debug") { - runConformance("debug.lua"); + runConformance("debug.luau"); } TEST_CASE("Debugger") @@ -1005,7 +1036,7 @@ TEST_CASE("Debugger") copts.debugLevel = 2; runConformance( - "debugger.lua", + "debugger.luau", [](lua_State* L) { lua_Callbacks* cb = lua_callbacks(L); @@ -1180,7 +1211,7 @@ TEST_CASE("NDebugGetUpValue") copts.optimizationLevel = 0; runConformance( - "ndebug_upvalues.lua", + "ndebug_upvalues.luau", nullptr, [](lua_State* L) { @@ -1344,6 +1375,25 @@ TEST_CASE("ApiTables") CHECK(strcmp(lua_tostring(L, -1), "test") == 0); lua_pop(L, 1); + // lua_clonetable + lua_clonetable(L, -1); + + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // modify clone + lua_pushnumber(L, 456.0); + lua_rawsetfield(L, -2, "key"); + + // remove clone + lua_pop(L, 1); + + // check original + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + // lua_cleartable lua_cleartable(L, -1); lua_pushnil(L); @@ -1392,7 +1442,7 @@ TEST_CASE("ApiIter") TEST_CASE("ApiCalls") { - StateRef globalState = runConformance("apicalls.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); + StateRef globalState = runConformance("apicalls.luau", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); // lua_call @@ -1657,7 +1707,31 @@ TEST_CASE("ApiBuffer") lua_pop(L, 1); } -TEST_CASE("AllocApi") +int slowlyOverflowStack(lua_State* L) +{ + for (int i = 0; i < LUAI_MAXCSTACK * 2; i++) + { + luaL_checkstack(L, 1, "test"); + lua_pushnumber(L, 1.0); + } + + return 0; +} + +TEST_CASE("ApiStack") +{ + ScopedFastFlag luauLibWhereErrorAutoreserve{FFlag::LuauLibWhereErrorAutoreserve, true}; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_pushcfunction(L, slowlyOverflowStack, "foo"); + int result = lua_pcall(L, 0, 0, 0); + REQUIRE(result == LUA_ERRRUN); + CHECK(strcmp(luaL_checkstring(L, -1), "stack overflow (test)") == 0); +} + +TEST_CASE("ApiAlloc") { int ud = 0; StateRef globalState(lua_newstate(limitedRealloc, &ud), lua_close); @@ -1695,7 +1769,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); + StateRef globalState = runConformance("exceptions.luau", nullptr, nullptr, lua_newstate(limitedRealloc, nullptr)); lua_State* L = globalState.get(); { @@ -1734,7 +1808,7 @@ TEST_CASE("ExceptionObject") TEST_CASE("IfElseExpression") { - runConformance("ifelseexpr.lua"); + runConformance("ifelseexpr.luau"); } // Optionally returns debug info for the first Luau stack frame that is encountered on the callstack. @@ -1772,7 +1846,7 @@ TEST_CASE("TagMethodError") auto yieldCallback = [](lua_State* L) {}; runConformance( - "tmerror.lua", + "tmerror.luau", [](lua_State* L) { auto* cb = lua_callbacks(L); @@ -1810,7 +1884,7 @@ TEST_CASE("Coverage") copts.coverageLevel = 2; runConformance( - "coverage.lua", + "coverage.luau", [](lua_State* L) { lua_pushcfunction( @@ -1864,7 +1938,7 @@ TEST_CASE("Coverage") TEST_CASE("StringConversion") { - runConformance("strconv.lua"); + runConformance("strconv.luau"); } TEST_CASE("GCDump") @@ -1977,7 +2051,7 @@ TEST_CASE("Interrupt") static int index; - StateRef globalState = runConformance("interrupt.lua", nullptr, nullptr, nullptr, &copts); + StateRef globalState = runConformance("interrupt.luau", nullptr, nullptr, nullptr, &copts); lua_State* L = globalState.get(); @@ -2192,9 +2266,7 @@ TEST_CASE("UserdataApi") lua_getuserdatametatable(L, 50); lua_setmetatable(L, -2); - void* ud8 = lua_newuserdatatagged(L, 16, 51); - lua_getuserdatametatable(L, 51); - lua_setmetatable(L, -2); + void* ud8 = lua_newuserdatataggedwithmetatable(L, 16, 51); CHECK(luaL_checkudata(L, -2, "udata3") == ud7); CHECK(luaL_checkudata(L, -1, "udata4") == ud8); @@ -2251,20 +2323,17 @@ TEST_CASE("LightuserdataApi") lua_pop(L, 1); - if (FFlag::LuauPreserveLudataRenaming) - { - // Still possible to rename the global lightuserdata name using a metatable - lua_pushlightuserdata(L, value); - CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); + // Still possible to rename the global lightuserdata name using a metatable + lua_pushlightuserdata(L, value); + CHECK(strcmp(luaL_typename(L, -1), "userdata") == 0); - lua_createtable(L, 0, 1); - lua_pushstring(L, "luserdata"); - lua_setfield(L, -2, "__type"); - lua_setmetatable(L, -2); + lua_createtable(L, 0, 1); + lua_pushstring(L, "luserdata"); + lua_setfield(L, -2, "__type"); + lua_setmetatable(L, -2); - CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); - lua_pop(L, 1); - } + CHECK(strcmp(luaL_typename(L, -1), "luserdata") == 0); + lua_pop(L, 1); globalState.reset(); } @@ -2283,7 +2352,7 @@ TEST_CASE("DebugApi") TEST_CASE("Iter") { - runConformance("iter.lua"); + runConformance("iter.luau"); } const int kInt64Tag = 1; @@ -2312,7 +2381,7 @@ static void pushInt64(lua_State* L, int64_t value) TEST_CASE("Userdata") { runConformance( - "userdata.lua", + "userdata.luau", [](lua_State* L) { // create metatable with all the metamethods @@ -2534,7 +2603,7 @@ TEST_CASE("Userdata") TEST_CASE("SafeEnv") { - runConformance("safeenv.lua"); + runConformance("safeenv.luau"); } TEST_CASE("Native") @@ -2554,7 +2623,7 @@ TEST_CASE("Native") } runConformance( - "native.lua", + "native.luau", [](lua_State* L) { setupNativeHelpers(L); @@ -2569,7 +2638,7 @@ TEST_CASE("NativeTypeAnnotations") return; runConformance( - "native_types.lua", + "native_types.luau", [](lua_State* L) { setupNativeHelpers(L); @@ -2632,7 +2701,7 @@ TEST_CASE("NativeUserdata") } runConformance( - "native_userdata.lua", + "native_userdata.luau", [](lua_State* L) { Luau::CodeGen::setUserdataRemapper( @@ -2917,8 +2986,6 @@ TEST_CASE("NativeAttribute") if (!codegen || !luau_codegen_supported()) return; - ScopedFastFlag sffs[] = {{FFlag::LuauNativeAttribute, true}}; - std::string source = R"R( @native local function sum(x, y) diff --git a/tests/ConstraintGeneratorFixture.cpp b/tests/ConstraintGeneratorFixture.cpp index 7f168465..e10e60d4 100644 --- a/tests/ConstraintGeneratorFixture.cpp +++ b/tests/ConstraintGeneratorFixture.cpp @@ -10,6 +10,7 @@ namespace Luau ConstraintGeneratorFixture::ConstraintGeneratorFixture() : Fixture() , mainModule(new Module) + , simplifier(newSimplifier(NotNull{&arena}, builtinTypes)) , forceTheFlag{FFlag::LuauSolverV2, true} { mainModule->name = "MainModule"; @@ -21,10 +22,14 @@ ConstraintGeneratorFixture::ConstraintGeneratorFixture() void ConstraintGeneratorFixture::generateConstraints(const std::string& code) { AstStatBlock* root = parse(code); - dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); + dfg = std::make_unique( + DataFlowGraphBuilder::build(root, NotNull{&mainModule->defArena}, NotNull{&mainModule->keyArena}, NotNull{&ice}) + ); cg = std::make_unique( mainModule, NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), @@ -42,7 +47,21 @@ void ConstraintGeneratorFixture::generateConstraints(const std::string& code) void ConstraintGeneratorFixture::solve(const std::string& code) { generateConstraints(code); - ConstraintSolver cs{NotNull{&normalizer}, NotNull{rootScope}, constraints, "MainModule", NotNull(&moduleResolver), {}, &logger, {}}; + ConstraintSolver cs{ + NotNull{&normalizer}, + NotNull{simplifier.get()}, + NotNull{&typeFunctionRuntime}, + NotNull{rootScope}, + constraints, + NotNull{&cg->scopeToFunction}, + "MainModule", + NotNull(&moduleResolver), + {}, + &logger, + NotNull{dfg.get()}, + {} + }; + cs.run(); } diff --git a/tests/ConstraintGeneratorFixture.h b/tests/ConstraintGeneratorFixture.h index ff362be1..800bf873 100644 --- a/tests/ConstraintGeneratorFixture.h +++ b/tests/ConstraintGeneratorFixture.h @@ -4,8 +4,9 @@ #include "Luau/ConstraintGenerator.h" #include "Luau/ConstraintSolver.h" #include "Luau/DcrLogger.h" -#include "Luau/TypeArena.h" +#include "Luau/EqSatSimplification.h" #include "Luau/Module.h" +#include "Luau/TypeArena.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -20,6 +21,9 @@ struct ConstraintGeneratorFixture : Fixture DcrLogger logger; UnifierSharedState sharedState{&ice}; Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + SimplifierPtr simplifier; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{NotNull{&ice}, NotNull{&limits}}; std::unique_ptr dfg; std::unique_ptr cg; diff --git a/tests/DataFlowGraph.test.cpp b/tests/DataFlowGraph.test.cpp index 4ea656ee..1b7e243c 100644 --- a/tests/DataFlowGraph.test.cpp +++ b/tests/DataFlowGraph.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/DataFlowGraph.h" #include "Fixture.h" +#include "Luau/Def.h" #include "Luau/Error.h" #include "Luau/Parser.h" @@ -18,6 +19,8 @@ struct DataFlowGraphFixture // Only needed to fix the operator== reflexivity of an empty Symbol. ScopedFastFlag dcr{FFlag::LuauSolverV2, true}; + DefArena defArena; + RefinementKeyArena keyArena; InternalErrorReporter handle; Allocator allocator; @@ -32,7 +35,7 @@ struct DataFlowGraphFixture if (!parseResult.errors.empty()) throw ParseErrors(std::move(parseResult.errors)); module = parseResult.root; - graph = DataFlowGraphBuilder::build(module, NotNull{&handle}); + graph = DataFlowGraphBuilder::build(module, NotNull{&defArena}, NotNull{&keyArena}, NotNull{&handle}); } template diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index a2b2280b..8050974e 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -234,7 +234,7 @@ TEST_CASE_FIXTURE(DifferFixture, "right_cyclic_table_left_table_property_wrong") TEST_CASE_FIXTURE(DifferFixture, "equal_table_two_cyclic_tables_are_not_different") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function id(x: a): a @@ -1473,7 +1473,7 @@ TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "equal_metatable") TEST_CASE_FIXTURE(DifferFixtureWithBuiltins, "metatable_normal") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local metaFoo = { diff --git a/tests/EqSat.language.test.cpp b/tests/EqSat.language.test.cpp index 282d4ad2..fd1bde57 100644 --- a/tests/EqSat.language.test.cpp +++ b/tests/EqSat.language.test.cpp @@ -11,9 +11,7 @@ LUAU_EQSAT_ATOM(I32, int); LUAU_EQSAT_ATOM(Bool, bool); LUAU_EQSAT_ATOM(Str, std::string); -LUAU_EQSAT_FIELD(Left); -LUAU_EQSAT_FIELD(Right); -LUAU_EQSAT_NODE_FIELDS(Add, Left, Right); +LUAU_EQSAT_NODE_ARRAY(Add, 2); using namespace Luau; @@ -117,8 +115,8 @@ TEST_CASE("node_field") Add add{left, right}; - EqSat::Id left2 = add.field(); - EqSat::Id right2 = add.field(); + EqSat::Id left2 = add.operands()[0]; + EqSat::Id right2 = add.operands()[1]; CHECK(left == left2); CHECK(left != right2); @@ -135,10 +133,10 @@ TEST_CASE("language_operands") const Add* add = v2.get(); REQUIRE(add); - EqSat::Slice actual = v2.operands(); + EqSat::Slice actual = v2.operands(); CHECK(actual.size() == 2); - CHECK(actual[0] == add->field()); - CHECK(actual[1] == add->field()); + CHECK(actual[0] == add->operands()[0]); + CHECK(actual[1] == add->operands()[1]); } TEST_SUITE_END(); diff --git a/tests/EqSatSimplification.test.cpp b/tests/EqSatSimplification.test.cpp new file mode 100644 index 00000000..6fe2660f --- /dev/null +++ b/tests/EqSatSimplification.test.cpp @@ -0,0 +1,714 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "Luau/EqSatSimplification.h" +#include "Luau/Type.h" + +using namespace Luau; + +struct ESFixture : Fixture +{ + ScopedFastFlag newSolverOnly{FFlag::LuauSolverV2, true}; + + TypeArena arena_; + const NotNull arena{&arena_}; + + SimplifierPtr simplifier; + + TypeId parentClass; + TypeId childClass; + TypeId anotherChild; + TypeId unrelatedClass; + + TypeId genericT = arena_.addType(GenericType{"T"}); + TypeId genericU = arena_.addType(GenericType{"U"}); + + TypeId numberToString = + arena_.addType(FunctionType{arena_.addTypePack({builtinTypes->numberType}), arena_.addTypePack({builtinTypes->stringType})}); + + TypeId stringToNumber = + arena_.addType(FunctionType{arena_.addTypePack({builtinTypes->stringType}), arena_.addTypePack({builtinTypes->numberType})}); + + ESFixture() + : simplifier(newSimplifier(arena, builtinTypes)) + { + createSomeClasses(&frontend); + + ScopePtr moduleScope = frontend.globals.globalScope; + + parentClass = moduleScope->linearSearchForBinding("Parent")->typeId; + childClass = moduleScope->linearSearchForBinding("Child")->typeId; + anotherChild = moduleScope->linearSearchForBinding("AnotherChild")->typeId; + unrelatedClass = moduleScope->linearSearchForBinding("Unrelated")->typeId; + } + + std::optional simplifyStr(TypeId ty) + { + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + LUAU_ASSERT(res); + return toString(res->result); + } + + TypeId tbl(TableType::Props props) + { + return arena->addType(TableType{std::move(props), std::nullopt, TypeLevel{}, TableState::Sealed}); + } +}; + +TEST_SUITE_BEGIN("EqSatSimplification"); + +TEST_CASE_FIXTURE(ESFixture, "primitive") +{ + CHECK("number" == simplifyStr(builtinTypes->numberType)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | number") +{ + TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->numberType}}); + + CHECK("number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | string") +{ + CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = number | t1") +{ + TypeId ty = arena->freshType(builtinTypes, nullptr); + asMutable(ty)->ty.emplace(std::vector{builtinTypes->numberType, ty}); + + CHECK("number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "number | string | number") +{ + TypeId ty = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->numberType}}); + + CHECK("number | string" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string | (number | string) | number") +{ + TypeId u1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + TypeId u2 = arena->addType(UnionType{{builtinTypes->stringType, u1, builtinTypes->numberType}}); + + CHECK("number | string" == simplifyStr(u2)); +} + +TEST_CASE_FIXTURE(ESFixture, "string | any") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->anyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "any | string") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "any | never") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->anyType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | unknown") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "unknown | string") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "unknown | never") +{ + CHECK("unknown" == simplifyStr(arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | never") +{ + CHECK("string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string | never | number") +{ + CHECK("number | string" == simplifyStr(arena->addType(UnionType{{builtinTypes->stringType, builtinTypes->neverType, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & string") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & number") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & unknown") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never & string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & (unknown | never)") +{ + CHECK( + "string" == simplifyStr(arena->addType( + IntersectionType{{builtinTypes->stringType, arena->addType(UnionType{{builtinTypes->unknownType, builtinTypes->neverType}})}} + )) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "true | false") +{ + CHECK("boolean" == simplifyStr(arena->addType(UnionType{{builtinTypes->trueType, builtinTypes->falseType}}))); +} + +/* + * Intuitively, if we have a type like + * + * x where x = A & B & (C | D | x) + * + * We know that x is certainly not larger than A & B. + * We also know that the union (C | D | x) can be rewritten `(C | D | (A & B & (C | D | x))) + * This tells us that the union part is not smaller than A & B. + * We can therefore discard the union entirely and simplify this type to A & B + */ +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (number | t1)") +{ + TypeId intersectionTy = arena->addType(BlockedType{}); + TypeId unionTy = arena->addType(UnionType{{builtinTypes->numberType, intersectionTy}}); + + asMutable(intersectionTy)->ty.emplace(std::vector{builtinTypes->stringType, unionTy}); + + CHECK("string" == simplifyStr(intersectionTy)); +} + +TEST_CASE_FIXTURE(ESFixture, "t1 where t1 = string & (unknown | t1)") +{ + TypeId intersectionTy = arena->addType(BlockedType{}); + TypeId unionTy = arena->addType(UnionType{{builtinTypes->unknownType, intersectionTy}}); + + asMutable(intersectionTy)->ty.emplace(std::vector{builtinTypes->stringType, unionTy}); + + CHECK("string" == simplifyStr(intersectionTy)); +} + +TEST_CASE_FIXTURE(ESFixture, "error | unknown") +{ + CHECK("any" == simplifyStr(arena->addType(UnionType{{builtinTypes->errorType, builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "\"hello\" | string") +{ + CHECK("string" == simplifyStr(arena->addType(UnionType{{arena->addType(SingletonType{StringSingleton{"hello"}}), builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "\"hello\" | \"world\" | \"hello\"") +{ + CHECK( + "\"hello\" | \"world\"" == simplifyStr(arena->addType(UnionType{{ + arena->addType(SingletonType{StringSingleton{"hello"}}), + arena->addType(SingletonType{StringSingleton{"world"}}), + arena->addType(SingletonType{StringSingleton{"hello"}}), + }})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "nil | boolean | number | string | thread | function | table | class | buffer") +{ + CHECK( + "unknown" == simplifyStr(arena->addType(UnionType{{ + builtinTypes->nilType, + builtinTypes->booleanType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent & number") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{parentClass, builtinTypes->numberType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & Parent") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{childClass, parentClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & Unrelated") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{childClass, unrelatedClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child | Parent") +{ + CHECK("Parent" == simplifyStr(arena->addType(UnionType{{childClass, parentClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "class | Child") +{ + CHECK("class" == simplifyStr(arena->addType(UnionType{{builtinTypes->classType, childClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent | class | Child") +{ + CHECK("class" == simplifyStr(arena->addType(UnionType{{parentClass, builtinTypes->classType, childClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Parent | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{parentClass, unrelatedClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never | Parent | Unrelated") +{ + CHECK("Parent | Unrelated" == simplifyStr(arena->addType(UnionType{{builtinTypes->neverType, parentClass, unrelatedClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "never | Parent | (number & string) | Unrelated") +{ + CHECK( + "Parent | Unrelated" == simplifyStr(arena->addType(UnionType{ + {builtinTypes->neverType, + parentClass, + arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}), + unrelatedClass} + })) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "T & U") +{ + CHECK("T & U" == simplifyStr(arena->addType(IntersectionType{{genericT, genericU}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & true") +{ + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->booleanType, builtinTypes->trueType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & (true | number | string | thread | function | table | class | buffer)") +{ + TypeId truthy = arena->addType(UnionType{{ + builtinTypes->trueType, + builtinTypes->numberType, + builtinTypes->stringType, + builtinTypes->threadType, + builtinTypes->functionType, + builtinTypes->tableType, + builtinTypes->classType, + builtinTypes->bufferType, + }}); + + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->booleanType, truthy}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "boolean & ~(false?)") +{ + CHECK("true" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->booleanType, builtinTypes->truthyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "false & ~(false?)") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->falseType, builtinTypes->truthyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (number) -> string") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, numberToString}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (number) -> string") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(UnionType{{numberToString, numberToString}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & function") +{ + CHECK("(number) -> string" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->functionType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & boolean") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->booleanType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & string") +{ + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & ~function") +{ + TypeId notFunction = arena->addType(NegationType{builtinTypes->functionType}); + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{numberToString, notFunction}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | function") +{ + CHECK("function" == simplifyStr(arena->addType(UnionType{{numberToString, builtinTypes->functionType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string & (string) -> number") +{ + CHECK("((number) -> string) & ((string) -> number)" == simplifyStr(arena->addType(IntersectionType{{numberToString, stringToNumber}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number) -> string | (string) -> number") +{ + CHECK("((number) -> string) | ((string) -> number)" == simplifyStr(arena->addType(UnionType{{numberToString, stringToNumber}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "add") +{ + CHECK( + "number" == + simplifyStr(arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {builtinTypes->numberType, builtinTypes->numberType}})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "union") +{ + CHECK( + "number" == + simplifyStr(arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().unionFunc, {builtinTypes->numberType, builtinTypes->numberType}})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "never & ~string") +{ + CHECK( + "never" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->neverType, arena->addType(NegationType{builtinTypes->stringType})}})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "blocked & never") +{ + const TypeId blocked = arena->addType(BlockedType{}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{blocked, builtinTypes->neverType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "blocked & ~number & function") +{ + const TypeId blocked = arena->addType(BlockedType{}); + const TypeId notNumber = arena->addType(NegationType{builtinTypes->numberType}); + + const TypeId ty = arena->addType(IntersectionType{{blocked, notNumber, builtinTypes->functionType}}); + + std::string expected = toString(blocked) + " & function"; + + CHECK(expected == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(number | boolean | string | nil | table) & (false | nil)") +{ + const TypeId t1 = arena->addType( + UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->stringType, builtinTypes->nilType, builtinTypes->tableType}} + ); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(number | boolean | nil) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->numberType, builtinTypes->booleanType, builtinTypes->nilType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(boolean | nil) & (false | nil)") +{ + const TypeId t1 = arena->addType(UnionType{{builtinTypes->booleanType, builtinTypes->nilType}}); + + CHECK("false?" == simplifyStr(arena->addType(IntersectionType{{t1, builtinTypes->falsyType}}))); +} + +// (('a & false) | ('a & nil)) | number + +// Child & ~Parent +// ~Parent & Child +// ~Child & Parent +// Parent & ~Child +// ~Child & ~Parent +// ~Parent & ~Child + +TEST_CASE_FIXTURE(ESFixture, "free & string & number") +{ + Scope scope{builtinTypes->anyTypePack}; + const TypeId freeTy = arena->freshType(builtinTypes, &scope); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{freeTy, builtinTypes->numberType, builtinTypes->stringType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(blocked & number) | (blocked & number)") +{ + const TypeId blocked = arena->addType(BlockedType{}); + const TypeId u = arena->addType(IntersectionType{{blocked, builtinTypes->numberType}}); + const TypeId ty = arena->addType(UnionType{{u, u}}); + + const std::string blockedStr = toString(blocked); + + CHECK(blockedStr + " & number" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & unknown") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{tbl({}), builtinTypes->unknownType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & table") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{tbl({}), builtinTypes->tableType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{} & ~(false?)") +{ + CHECK("{ }" == simplifyStr(arena->addType(IntersectionType{{tbl({}), builtinTypes->truthyType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: number}") +{ + const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId hasX = tbl({{"x", builtinTypes->numberType}}); + + const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}}); + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + + CHECK("{ x: number }" == toString(res->result)); + + // Also assert that we don't allocate a fresh TableType in this case. + CHECK(follow(res->result) == hasX); +} + +TEST_CASE_FIXTURE(ESFixture, "{x: number?} & {x: ~(false?)}") +{ + const TypeId hasOptionalX = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId hasX = tbl({{"x", builtinTypes->truthyType}}); + + const TypeId ty = arena->addType(IntersectionType{{hasOptionalX, hasX}}); + auto res = eqSatSimplify(NotNull{simplifier.get()}, ty); + + CHECK("{ x: number }" == toString(res->result)); +} + +TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) }") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + const TypeId ty = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "never | (({ x: number? }?) & { x: ~(false?) })") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + // ({x: number?}?) & {x: ~(false?)} + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy}}); + + const TypeId ty = arena->addType(UnionType{{builtinTypes->neverType, intersectionTy}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "({ x: number? }?) & { x: ~(false?) } & ~(false?)") +{ + // {x: number?}? + const TypeId xWithOptionalNumber = arena->addType(UnionType{{tbl({{"x", builtinTypes->optionalNumberType}}), builtinTypes->nilType}}); + + // {x: ~(false?)} + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + + // ({x: number?}?) & {x: ~(false?)} & ~(false?) + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}}); + + CHECK("{ x: number }" == simplifyStr(intersectionTy)); +} + +#if 0 +// TODO +TEST_CASE_FIXTURE(ESFixture, "(({ x: number? }?) & { x: ~(false?) } & ~(false?)) | number") +{ + // ({ x: number? }?) & { x: ~(false?) } & ~(false?) + const TypeId xWithOptionalNumber = tbl({{"x", builtinTypes->optionalNumberType}}); + const TypeId xWithTruthy = tbl({{"x", builtinTypes->truthyType}}); + const TypeId intersectionTy = arena->addType(IntersectionType{{xWithOptionalNumber, xWithTruthy, builtinTypes->truthyType}}); + const TypeId ty = arena->addType(UnionType{{intersectionTy, builtinTypes->numberType}}); + + CHECK("{ x: number } | number" == simplifyStr(ty)); +} +#endif + +TEST_CASE_FIXTURE(ESFixture, "number & no-refine") +{ + CHECK("number" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->numberType, builtinTypes->noRefineType}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{ x: number } & ~boolean") +{ + const TypeId tblTy = tbl(TableType::Props{{"x", builtinTypes->numberType}}); + + const TypeId ty = arena->addType(IntersectionType{{tblTy, arena->addType(NegationType{builtinTypes->booleanType})}}); + + CHECK("{ x: number }" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(nil & string)?") +{ + const TypeId nilAndString = arena->addType(IntersectionType{{builtinTypes->nilType, builtinTypes->stringType}}); + const TypeId ty = arena->addType(UnionType{{nilAndString, builtinTypes->nilType}}); + + CHECK("nil" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string & \"hi\"") +{ + const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); + + CHECK("\"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, hi}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "string & (\"hi\" | \"bye\")") +{ + const TypeId hi = arena->addType(SingletonType{StringSingleton{"hi"}}); + const TypeId bye = arena->addType(SingletonType{StringSingleton{"bye"}}); + + CHECK("\"bye\" | \"hi\"" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, arena->addType(UnionType{{hi, bye}})}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(\"err\" | \"ok\") & ~\"ok\"") +{ + TypeId err = arena->addType(SingletonType{StringSingleton{"err"}}); + TypeId ok1 = arena->addType(SingletonType{StringSingleton{"ok"}}); + TypeId ok2 = arena->addType(SingletonType{StringSingleton{"ok"}}); + + TypeId ty = arena->addType(IntersectionType{{arena->addType(UnionType{{err, ok1}}), arena->addType(NegationType{ok2})}}); + + CHECK("\"err\"" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & ~Child") +{ + const TypeId ty = + arena->addType(IntersectionType{{arena->addType(UnionType{{childClass, unrelatedClass}}), arena->addType(NegationType{childClass})}}); + + CHECK("Unrelated" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "string & ~Child") +{ + CHECK("string" == simplifyStr(arena->addType(IntersectionType{{builtinTypes->stringType, arena->addType(NegationType{childClass})}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | Unrelated) & Child") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{arena->addType(UnionType{{childClass, unrelatedClass}}), childClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "(Child | AnotherChild) & ~Child") +{ + CHECK("Child" == simplifyStr(arena->addType(IntersectionType{{arena->addType(UnionType{{childClass, anotherChild}}), childClass}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: never }") +{ + const TypeId ty = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->neverType}}); + + CHECK("never" == simplifyStr(ty)); +} + +TEST_CASE_FIXTURE(ESFixture, "{ tag: \"Part\", x: number? } & { x: string }") +{ + const TypeId leftTable = tbl({{"tag", arena->addType(SingletonType{StringSingleton{"Part"}})}, {"x", builtinTypes->optionalNumberType}}); + const TypeId rightTable = tbl({{"x", builtinTypes->stringType}}); + + CHECK("never" == simplifyStr(arena->addType(IntersectionType{{leftTable, rightTable}}))); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & add") +{ + const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().addFunc, {u, parentClass}, {}}); + + const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); + + CHECK("Child & add" == simplifyStr(intersection)); +} + +TEST_CASE_FIXTURE(ESFixture, "Child & intersect") +{ + const TypeId u = arena->addType(UnionType{{childClass, anotherChild, builtinTypes->stringType}}); + const TypeId intersectTf = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().intersectFunc, {u, parentClass}, {}}); + + const TypeId intersection = arena->addType(IntersectionType{{childClass, intersectTf}}); + + CHECK("Child" == simplifyStr(intersection)); +} + +TEST_CASE_FIXTURE(ESFixture, "lt == boolean") +{ + std::vector> cases{ + {builtinTypes->numberType, arena->addType(BlockedType{})}, + {builtinTypes->stringType, arena->addType(BlockedType{})}, + {arena->addType(BlockedType{}), builtinTypes->numberType}, + {arena->addType(BlockedType{}), builtinTypes->stringType}, + }; + + for (const auto& [lhs, rhs] : cases) + { + const TypeId tfun = arena->addType(TypeFunctionInstanceType{builtinTypeFunctions().ltFunc, {lhs, rhs}}); + CHECK("boolean" == simplifyStr(tfun)); + } +} + +TEST_CASE_FIXTURE(ESFixture, "unknown & ~string") +{ + CHECK_EQ( + "~string", simplifyStr(arena->addType(IntersectionType{{builtinTypes->unknownType, arena->addType(NegationType{builtinTypes->stringType})}})) + ); +} + +TEST_CASE_FIXTURE(ESFixture, "string & ~\"foo\"") +{ + CHECK_EQ( + "string & ~\"foo\"", + simplifyStr(arena->addType( + IntersectionType{{builtinTypes->stringType, arena->addType(NegationType{arena->addType(SingletonType{StringSingleton{"foo"}})})}} + )) + ); +} + +// {someKey: ~any} +// +// Maybe something we could do here is to try to reduce the key, get the +// class->node mapping, and skip the extraction process if the class corresponds +// to TNever. + +// t1 where t1 = add, number> + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 0b0e1b7c..dcb228a3 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -16,6 +16,7 @@ #include "doctest.h" #include +#include #include #include #include @@ -24,9 +25,9 @@ static const char* mainModuleName = "MainModule"; LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTFLAG(DebugLuauLogSolverToJsonFile) -LUAU_FASTFLAG(LuauDCRMagicFunctionTypeChecker); + +LUAU_FASTFLAGVARIABLE(DebugLuauForceAllNewSolverTests); extern std::optional randomSeed; // tests/main.cpp @@ -150,11 +151,8 @@ const Config& TestConfigResolver::getConfig(const ModuleName& name) const return defaultConfig; } -Fixture::Fixture(bool freeze, bool prepareAutocomplete) - : sff_DebugLuauFreezeArena(FFlag::DebugLuauFreezeArena, freeze) - // In tests, we *always* want to register the extra magic functions for typechecking `string.format`. - , sff_LuauDCRMagicFunctionTypeChecker(FFlag::LuauDCRMagicFunctionTypeChecker, true) - , frontend( +Fixture::Fixture(bool prepareAutocomplete) + : frontend( &fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* runLintChecks */ false, /* randomConstraintResolutionSeed */ randomSeed} @@ -244,21 +242,21 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars return result.root; } -CheckResult Fixture::check(Mode mode, const std::string& source) +CheckResult Fixture::check(Mode mode, const std::string& source, std::optional options) { ModuleName mm = fromString(mainModuleName); configResolver.defaultConfig.mode = mode; fileResolver.source[mm] = std::move(source); frontend.markDirty(mm); - CheckResult result = frontend.check(mm); + CheckResult result = frontend.check(mm, options); return result; } -CheckResult Fixture::check(const std::string& source) +CheckResult Fixture::check(const std::string& source, std::optional options) { - return check(Mode::Strict, source); + return check(Mode::Strict, source, options); } LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) @@ -343,8 +341,11 @@ ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std: return result; } -ModulePtr Fixture::getMainModule() +ModulePtr Fixture::getMainModule(bool forAutocomplete) { + if (forAutocomplete && !FFlag::LuauSolverV2) + return frontend.moduleResolverForAutocomplete.getModule(fromString(mainModuleName)); + return frontend.moduleResolver.getModule(fromString(mainModuleName)); } @@ -367,9 +368,9 @@ std::optional Fixture::getPrimitiveType(TypeId ty) return std::nullopt; } -std::optional Fixture::getType(const std::string& name) +std::optional Fixture::getType(const std::string& name, bool forAutocomplete) { - ModulePtr module = getMainModule(); + ModulePtr module = getMainModule(forAutocomplete); REQUIRE(module); if (!module->hasModuleScope()) @@ -521,6 +522,9 @@ void Fixture::registerTestTypes() void Fixture::dumpErrors(const CheckResult& cr) { + if (hasDumpedErrors) + return; + hasDumpedErrors = true; std::string error = getErrors(cr); if (!error.empty()) MESSAGE(error); @@ -528,6 +532,9 @@ void Fixture::dumpErrors(const CheckResult& cr) void Fixture::dumpErrors(const ModulePtr& module) { + if (hasDumpedErrors) + return; + hasDumpedErrors = true; std::stringstream ss; dumpErrors(ss, module->errors); if (!ss.str().empty()) @@ -536,6 +543,9 @@ void Fixture::dumpErrors(const ModulePtr& module) void Fixture::dumpErrors(const Module& module) { + if (hasDumpedErrors) + return; + hasDumpedErrors = true; std::stringstream ss; dumpErrors(ss, module.errors); if (!ss.str().empty()) @@ -564,12 +574,14 @@ void Fixture::validateErrors(const std::vector& errors) } } -LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) +LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source, bool forAutocomplete) { - unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = - frontend.loadDefinitionFile(frontend.globals, frontend.globals.globalScope, source, "@test", /* captureComments */ false); - freeze(frontend.globals.globalTypes); + GlobalTypes& globals = forAutocomplete ? frontend.globalsForAutocomplete : frontend.globals; + unfreeze(globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile( + globals, globals.globalScope, source, "@test", /* captureComments */ false, /* typecheckForAutocomplete */ forAutocomplete + ); + freeze(globals.globalTypes); if (result.module) dumpErrors(result.module); @@ -577,8 +589,8 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) return result; } -BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) - : Fixture(freeze, prepareAutocomplete) +BuiltinsFixture::BuiltinsFixture(bool prepareAutocomplete) + : Fixture(prepareAutocomplete) { Luau::unfreeze(frontend.globals.globalTypes); Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); @@ -592,6 +604,72 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) Luau::freeze(frontend.globalsForAutocomplete.globalTypes); } +static std::vector parsePathExpr(const AstExpr& pathExpr) +{ + const AstExprIndexName* indexName = pathExpr.as(); + if (!indexName) + return {}; + + std::vector segments{indexName->index.value}; + + while (true) + { + if (AstExprIndexName* in = indexName->expr->as()) + { + segments.push_back(in->index.value); + indexName = in; + continue; + } + else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as()) + { + segments.push_back(indexNameAsGlobal->name.value); + break; + } + else if (AstExprLocal* indexNameAsLocal = indexName->expr->as()) + { + segments.push_back(indexNameAsLocal->local->name.value); + break; + } + else + return {}; + } + + std::reverse(segments.begin(), segments.end()); + return segments; +} + +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments) +{ + if (segments.empty()) + return std::nullopt; + + std::vector result; + + auto it = segments.begin(); + + if (*it == "script" && !currentModuleName.empty()) + { + result = split(currentModuleName, '/'); + ++it; + } + + for (; it != segments.end(); ++it) + { + if (result.size() > 1 && *it == "Parent") + result.pop_back(); + else + result.push_back(*it); + } + + return join(result, "/"); +} + +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr) +{ + std::vector segments = parsePathExpr(pathExpr); + return pathExprToModuleName(currentModuleName, segments); +} + ModuleName fromString(std::string_view name) { return ModuleName(name); diff --git a/tests/Fixture.h b/tests/Fixture.h index f50431f3..60643839 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -20,8 +20,17 @@ #include "doctest.h" #include +#include #include #include +#include + +LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAG(DebugLuauForceAllNewSolverTests) + +#define DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(line) ScopedFastFlag sff_##line{FFlag::LuauSolverV2, FFlag::DebugLuauForceAllNewSolverTests}; + +#define DOES_NOT_PASS_NEW_SOLVER_GUARD() DOES_NOT_PASS_NEW_SOLVER_GUARD_IMPL(__LINE__) namespace Luau { @@ -61,13 +70,13 @@ struct TestConfigResolver : ConfigResolver struct Fixture { - explicit Fixture(bool freeze = true, bool prepareAutocomplete = false); + explicit Fixture(bool prepareAutocomplete = false); ~Fixture(); // Throws Luau::ParseErrors if the parse fails. AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {}); - CheckResult check(Mode mode, const std::string& source); - CheckResult check(const std::string& source); + CheckResult check(Mode mode, const std::string& source, std::optional = std::nullopt); + CheckResult check(const std::string& source, std::optional = std::nullopt); LintResult lint(const std::string& source, const std::optional& lintOptions = {}); LintResult lintModule(const ModuleName& moduleName, const std::optional& lintOptions = {}); @@ -79,11 +88,11 @@ struct Fixture // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); - ModulePtr getMainModule(); + ModulePtr getMainModule(bool forAutocomplete = false); SourceModule* getMainSourceModule(); std::optional getPrimitiveType(TypeId ty); - std::optional getType(const std::string& name); + std::optional getType(const std::string& name, bool forAutocomplete = false); TypeId requireType(const std::string& name); TypeId requireType(const ModuleName& moduleName, const std::string& name); TypeId requireType(const ModulePtr& module, const std::string& name); @@ -98,8 +107,14 @@ struct Fixture TypeId requireTypeAlias(const std::string& name); TypeId requireExportedType(const ModuleName& moduleName, const std::string& name); - ScopedFastFlag sff_DebugLuauFreezeArena; - ScopedFastFlag sff_LuauDCRMagicFunctionTypeChecker; + // While most flags can be flipped inside the unit test, some code changes affect the state that is part of Fixture initialization + // Most often those are changes related to builtin type definitions. + // In that case, flag can be forced to 'true' using the example below: + // ScopedFastFlag sff_LuauExampleFlagDefinition{FFlag::LuauExampleFlagDefinition, true}; + + // Arena freezing marks the `TypeArena`'s underlying memory as read-only, raising an access violation whenever you mutate it. + // This is useful for tracking down violations of Luau's memory model. + ScopedFastFlag sff_DebugLuauFreezeArena{FFlag::DebugLuauFreezeArena, true}; TestFileResolver fileResolver; TestConfigResolver configResolver; @@ -123,14 +138,20 @@ struct Fixture void registerTestTypes(); - LoadDefinitionFileResult loadDefinition(const std::string& source); + LoadDefinitionFileResult loadDefinition(const std::string& source, bool forAutocomplete = false); + +private: + bool hasDumpedErrors = false; }; struct BuiltinsFixture : Fixture { - BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false); + explicit BuiltinsFixture(bool prepareAutocomplete = false); }; +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments); +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr); + ModuleName fromString(std::string_view name); template @@ -159,6 +180,18 @@ std::optional linearSearchForBinding(Scope* scope, const char* name); void registerHiddenTypes(Frontend* frontend); void createSomeClasses(Frontend* frontend); +template +const E* findError(const CheckResult& result) +{ + for (const auto& e : result.errors) + { + if (auto p = get(e)) + return p; + } + + return nullptr; +} + template struct DifferFixtureGeneric : BaseFixture { @@ -274,3 +307,85 @@ using DifferFixtureWithBuiltins = DifferFixtureGeneric; } while (false) #define LUAU_CHECK_NO_ERRORS(result) LUAU_CHECK_ERROR_COUNT(0, result) + +#define LUAU_CHECK_HAS_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(count, "Map should have key \"" << _k << "\""); \ + if (!count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + +#define LUAU_CHECK_HAS_NO_KEY(map, key) \ + do \ + { \ + auto&& _m = (map); \ + auto&& _k = (key); \ + const size_t count = _m.count(_k); \ + CHECK_MESSAGE(!count, "Map should not have key \"" << _k << "\""); \ + if (count) \ + { \ + MESSAGE("Keys: (count " << _m.size() << ")"); \ + for (const auto& [k, v] : _m) \ + { \ + MESSAGE("\tkey: " << k); \ + } \ + } \ + } while (false) + +#define LUAU_REQUIRE_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (!findError(res)) \ + { \ + dumpErrors(res); \ + REQUIRE_MESSAGE(false, "Expected to find " #Type " error"); \ + } \ + } while (false) + +#define LUAU_CHECK_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (!findError(res)) \ + { \ + dumpErrors(res); \ + CHECK_MESSAGE(false, "Expected to find " #Type " error"); \ + } \ + } while (false) + +#define LUAU_REQUIRE_NO_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (findError(res)) \ + { \ + dumpErrors(res); \ + REQUIRE_MESSAGE(false, "Expected to find no " #Type " error"); \ + } \ + } while (false) + +#define LUAU_CHECK_NO_ERROR(result, Type) \ + do \ + { \ + using T = Type; \ + const auto& res = (result); \ + if (findError(res)) \ + { \ + dumpErrors(res); \ + CHECK_MESSAGE(false, "Expected to find no " #Type " error"); \ + } \ + } while (false) diff --git a/tests/FragmentAutocomplete.test.cpp b/tests/FragmentAutocomplete.test.cpp new file mode 100644 index 00000000..9f7a5261 --- /dev/null +++ b/tests/FragmentAutocomplete.test.cpp @@ -0,0 +1,2105 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/FragmentAutocomplete.h" +#include "Fixture.h" +#include "Luau/Ast.h" +#include "Luau/AstQuery.h" +#include "Luau/Autocomplete.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/Frontend.h" +#include "Luau/AutocompleteTypes.h" +#include "Luau/Type.h" +#include "ScopedFlags.h" + +#include +#include +#include +#include +#include +#include + + +using namespace Luau; + +LUAU_FASTFLAG(LuauAutocompleteRefactorsForIncrementalAutocomplete) +LUAU_FASTFLAG(LuauSymbolEquality); +LUAU_FASTFLAG(LuauStoreSolverTypeOnModule); +LUAU_FASTFLAG(LexerResumesFromPosition2) +LUAU_FASTFLAG(LuauIncrementalAutocompleteCommentDetection) +LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauCloneIncrementalModule) + +LUAU_FASTFLAG(LuauIncrementalAutocompleteBugfixes) +LUAU_FASTFLAG(LuauMixedModeDefFinderTraversesTypeOf) +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds) + +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking) +LUAU_FASTFLAG(LuauAutocompleteUsesModuleForTypeCompatibility) + +static std::optional nullCallback(std::string tag, std::optional ptr, std::optional contents) +{ + return std::nullopt; +} + +static FrontendOptions getOptions() +{ + FrontendOptions options; + options.retainFullTypeGraphs = true; + + if (!FFlag::LuauSolverV2) + options.forAutocomplete = true; + + options.runLintChecks = false; + + return options; +} + +static ModuleResolver& getModuleResolver(Luau::Frontend& frontend) +{ + return FFlag::LuauSolverV2 ? frontend.moduleResolver : frontend.moduleResolverForAutocomplete; +} + +template +struct FragmentAutocompleteFixtureImpl : BaseType +{ + static_assert(std::is_base_of_v, "BaseType must be a descendant of Fixture"); + + ScopedFastFlag sffs[6] = { + {FFlag::LuauAutocompleteRefactorsForIncrementalAutocomplete, true}, + {FFlag::LuauStoreSolverTypeOnModule, true}, + {FFlag::LuauSymbolEquality, true}, + {FFlag::LexerResumesFromPosition2, true}, + {FFlag::LuauIncrementalAutocompleteBugfixes, true}, + {FFlag::LuauBetterReverseDependencyTracking, true}, + }; + + FragmentAutocompleteFixtureImpl() + : BaseType(true) + { + } + + FragmentAutocompleteAncestryResult runAutocompleteVisitor(const std::string& source, const Position& cursorPos) + { + ParseResult p = this->tryParse(source); // We don't care about parsing incomplete asts + REQUIRE(p.root); + return findAncestryForFragmentParse(p.root, cursorPos); + } + + + std::optional parseFragment( + const std::string& document, + const Position& cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + SourceModule* srcModule = this->getMainSourceModule(); + std::string_view srcString = document; + return Luau::parseFragment(*srcModule, srcString, cursorPos, fragmentEndPosition); + } + + CheckResult checkOldSolver(const std::string& source) + { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + return this->check(Mode::Strict, source, getOptions()); + } + + FragmentTypeCheckResult checkFragment( + const std::string& document, + const Position& cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + auto [_, result] = Luau::typecheckFragment(this->frontend, "MainModule", cursorPos, getOptions(), document, fragmentEndPosition); + return result; + } + + FragmentAutocompleteResult autocompleteFragment( + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + FrontendOptions options; + return Luau::fragmentAutocomplete(this->frontend, document, "MainModule", cursorPos, getOptions(), nullCallback, fragmentEndPosition); + } + + + void autocompleteFragmentInBothSolvers( + const std::string& document, + const std::string& updated, + Position cursorPos, + std::function assertions, + std::optional fragmentEndPosition = std::nullopt + ) + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + this->check(document); + + FragmentAutocompleteResult result = autocompleteFragment(updated, cursorPos, fragmentEndPosition); + assertions(result); + + ScopedFastFlag _{FFlag::LuauSolverV2, false}; + this->check(document, getOptions()); + + result = autocompleteFragment(updated, cursorPos, fragmentEndPosition); + assertions(result); + } + + std::pair typecheckFragmentForModule( + const ModuleName& module, + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + return Luau::typecheckFragment(this->frontend, module, cursorPos, getOptions(), document, fragmentEndPosition); + } + + FragmentAutocompleteResult autocompleteFragmentForModule( + const ModuleName& module, + const std::string& document, + Position cursorPos, + std::optional fragmentEndPosition = std::nullopt + ) + { + return Luau::fragmentAutocomplete(this->frontend, document, module, cursorPos, getOptions(), nullCallback, fragmentEndPosition); + } +}; + +struct FragmentAutocompleteFixture : FragmentAutocompleteFixtureImpl +{ + FragmentAutocompleteFixture() + : FragmentAutocompleteFixtureImpl() + { + addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "math", Binding{builtinTypes->anyType}); + } +}; + +struct FragmentAutocompleteBuiltinsFixture : FragmentAutocompleteFixtureImpl +{ + FragmentAutocompleteBuiltinsFixture() + : FragmentAutocompleteFixtureImpl() + { + const std::string fakeVecDecl = R"( +declare class FakeVec + function dot(self, x: FakeVec) : FakeVec + zero : FakeVec +end +)"; + // The old solver always performs a strict mode check and populates the module resolver and globals + // for autocomplete. + // The new solver just populates the globals and the moduleResolver. + // Because these tests run in both the old solver and the new solver, and the test suite + // now picks the module resolver as appropriate in order to better mimic the studio code path, + // we have to load the definition file into both the 'globals'/'resolver' and the equivalent + // 'for autocomplete'. + loadDefinition(fakeVecDecl); + loadDefinition(fakeVecDecl, /* For Autocomplete Module */ true); + + addGlobalBinding(frontend.globals, "game", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "game", Binding{builtinTypes->anyType}); + } +}; + +// NOLINTBEGIN(bugprone-unchecked-optional-access) +TEST_SUITE_BEGIN("FragmentAutocompleteTraversalTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "just_two_locals") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +)", + {2, 11} + ); + + CHECK_EQ(3, result.ancestry.size()); + CHECK_EQ(1, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("y", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_within_scope_tracks_locals_from_previous_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +)", + {4, 15} + ); + + CHECK_EQ(5, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("e", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cursor_that_comes_later_shouldnt_capture_locals_in_unavailable_scope") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then + local e = y +end +local z = x + x +if y == 5 then + local q = x + y + z +end +)", + {8, 23} + ); + + CHECK_EQ(6, result.ancestry.size()); + CHECK_EQ(3, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("z", std::string(result.localStack.back()->name.value)); + + AstStatLocal* local = result.nearestStatement->as(); + REQUIRE(local); + CHECK(1 == local->vars.size); + CHECK_EQ("q", std::string(local->vars.data[0]->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nearest_enclosing_statement_can_be_non_local") +{ + auto result = runAutocompleteVisitor( + R"( +local x = 4 +local y = 5 +if x == 4 then +)", + {3, 4} + ); + + CHECK_EQ(4, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + REQUIRE(result.nearestStatement); + CHECK_EQ("y", std::string(result.localStack.back()->name.value)); + + AstStatIf* ifS = result.nearestStatement->as(); + CHECK(ifS != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_funcs_show_up_in_local_stack") +{ + auto result = runAutocompleteVisitor( + R"( +local function foo() return 4 end +local x = foo() +local function bar() return x + foo() end +)", + {3, 32} + ); + + CHECK_EQ(8, result.ancestry.size()); + CHECK_EQ(2, result.localStack.size()); + CHECK_EQ(result.localMap.size(), result.localStack.size()); + CHECK_EQ("x", std::string(result.localStack.back()->name.value)); + auto returnSt = result.nearestStatement->as(); + CHECK(returnSt != nullptr); +} + +TEST_SUITE_END(); + + +TEST_SUITE_BEGIN("FragmentAutocompleteParserTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "thrown_parse_error_leads_to_null_root") +{ + check("type A = "); + ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1}; + auto fragment = parseFragment("type A = <>function<> more garbage here", Position(0, 39)); + CHECK(fragment == std::nullopt); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + check("local a ="); + auto fragment = parseFragment("local a =", Position(0, 10)); + + REQUIRE(fragment.has_value()); + CHECK_EQ("local a =", fragment->fragmentToParse); + CHECK_EQ(Location{Position{0, 0}, 9}, fragment->root->location); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "statement_in_empty_fragment_is_non_null") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + auto res = check(R"( + +)"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( + +)", + Position(1, 0) + ); + REQUIRE(fragment.has_value()); + CHECK_EQ("\n", fragment->fragmentToParse); + CHECK_EQ(2, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(0, fragment->root->body.size); + auto statBody = fragment->root->as(); + CHECK(statBody != nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_complete_fragments") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + REQUIRE(fragment.has_value()); + + CHECK_EQ(Location{Position{2, 0}, Position{3, 15}}, fragment->root->location); + + CHECK_EQ("local y = 5\nlocal z = x + y", fragment->fragmentToParse); + CHECK_EQ(5, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(2, fragment->root->body.size); + auto stat = fragment->root->body.data[1]->as(); + REQUIRE(stat); + CHECK_EQ(1, stat->vars.size); + CHECK_EQ(1, stat->values.size); + CHECK_EQ("z", std::string(stat->vars.data[0]->name.value)); + + auto bin = stat->values.data[0]->as(); + REQUIRE(bin); + CHECK_EQ(AstExprBinary::Op::Add, bin->op); + + auto lhs = bin->left->as(); + auto rhs = bin->right->as(); + REQUIRE(lhs); + REQUIRE(rhs); + CHECK_EQ("x", std::string(lhs->local->name.value)); + CHECK_EQ("y", std::string(rhs->local->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_fragments_in_line") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"( +local x = 4 +local z = x + y +local y = 5 +)", + Position{2, 15} + ); + + REQUIRE(fragment.has_value()); + + CHECK_EQ("local z = x + y", fragment->fragmentToParse); + CHECK_EQ(5, fragment->ancestry.size()); + REQUIRE(fragment->root); + CHECK_EQ(Location{Position{2, 0}, Position{2, 15}}, fragment->root->location); + CHECK_EQ(1, fragment->root->body.size); + auto stat = fragment->root->body.data[0]->as(); + REQUIRE(stat); + CHECK_EQ(1, stat->vars.size); + CHECK_EQ(1, stat->values.size); + CHECK_EQ("z", std::string(stat->vars.data[0]->name.value)); + + auto bin = stat->values.data[0]->as(); + REQUIRE(bin); + CHECK_EQ(AstExprBinary::Op::Add, bin->op); + + auto lhs = bin->left->as(); + auto rhs = bin->right->as(); + REQUIRE(lhs); + REQUIRE(rhs); + CHECK_EQ("x", std::string(lhs->local->name.value)); + CHECK_EQ("y", std::string(rhs->name.value)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_in_correct_scope") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + check(R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"); + + auto fragment = parseFragment( + R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )", + Position{6, 0} + ); + + REQUIRE(fragment.has_value()); + + CHECK_EQ("\n ", fragment->fragmentToParse); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_single_line_fragment_override") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + auto res = check("function abc(foo: string) end"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto callFragment = parseFragment( + R"(function abc(foo: string) end +abc("foo") +abc("bar") +)", + Position{1, 6}, + Position{1, 10} + ); + + REQUIRE(callFragment.has_value()); + + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", callFragment->fragmentToParse); + CHECK(callFragment->nearestStatement->is()); + + CHECK_GE(callFragment->ancestry.size(), 2); + + AstNode* back = callFragment->ancestry.back(); + CHECK(back->is()); + CHECK_EQ(Position{1, 4}, back->location.begin); + CHECK_EQ(Position{1, 9}, back->location.end); + + AstNode* parent = callFragment->ancestry.rbegin()[1]; + CHECK(parent->is()); + CHECK_EQ(Position{1, 0}, parent->location.begin); + CHECK_EQ(Position{1, 10}, parent->location.end); + + + auto stringFragment = parseFragment( + R"(function abc(foo: string) end +abc("foo") +abc("bar") +)", + Position{1, 6}, + Position{1, 9} + ); + + REQUIRE(stringFragment.has_value()); + + CHECK_EQ("function abc(foo: string) end\nabc(\"foo\")", stringFragment->fragmentToParse); + CHECK(stringFragment->nearestStatement->is()); + + CHECK_GE(stringFragment->ancestry.size(), 1); + + back = stringFragment->ancestry.back(); + + auto asString = back->as(); + CHECK(asString); + + CHECK_EQ(Position{1, 4}, asString->location.begin); + CHECK_EQ(Position{1, 9}, asString->location.end); + CHECK_EQ("foo", std::string{asString->value.data}); + CHECK_EQ(AstExprConstantString::QuotedSimple, asString->quoteStyle); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_parse_multi_line_fragment_override") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + auto res = check("function abc(foo: string) end"); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = parseFragment( + R"(function abc(foo: string) end +abc( +"foo" +) +abc("bar") +)", + Position{2, 5}, + Position{3, 1} + ); + + REQUIRE(fragment.has_value()); + + CHECK_EQ("function abc(foo: string) end\nabc(\n\"foo\"\n)", fragment->fragmentToParse); + CHECK(fragment->nearestStatement->is()); + + CHECK_GE(fragment->ancestry.size(), 2); + + AstNode* back = fragment->ancestry.back(); + CHECK(back->is()); + CHECK_EQ(Position{2, 0}, back->location.begin); + CHECK_EQ(Position{2, 5}, back->location.end); + + AstNode* parent = fragment->ancestry.rbegin()[1]; + CHECK(parent->is()); + CHECK_EQ(Position{1, 0}, parent->location.begin); + CHECK_EQ(Position{3, 1}, parent->location.end); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "respects_frontend_options") +{ + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + + std::string source = R"( +local tbl = { abc = 1234} +t +)"; + fileResolver.source["game/A"] = source; + + FrontendOptions opts; + opts.forAutocomplete = true; + + frontend.check("game/A", opts); + CHECK_NE(frontend.moduleResolverForAutocomplete.getModule("game/A"), nullptr); + CHECK_EQ(frontend.moduleResolver.getModule("game/A"), nullptr); + + + FragmentAutocompleteResult result = Luau::fragmentAutocomplete(frontend, source, "game/A", Position{2, 1}, opts, nullCallback); + CHECK_EQ("game/A", result.incrementalModule->name); + CHECK_NE(frontend.moduleResolverForAutocomplete.getModule("game/A"), nullptr); + CHECK_EQ(frontend.moduleResolver.getModule("game/A"), nullptr); +} + +TEST_SUITE_END(); +// NOLINTEND(bugprone-unchecked-optional-access) + +TEST_SUITE_BEGIN("FragmentAutocompleteTypeCheckerTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_simple_fragment") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = checkFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + auto opt = linearSearchForBinding(fragment.freshScope.get(), "z"); + REQUIRE(opt); + CHECK_EQ("number", toString(*opt)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_typecheck_fragment_inserted_inline") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + auto res = check( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + auto fragment = checkFragment( + R"( +local x = 4 +local z = x +local y = 5 +)", + Position{2, 11} + ); + + auto correct = linearSearchForBinding(fragment.freshScope.get(), "z"); + REQUIRE(correct); + CHECK_EQ("number", toString(*correct)); +} + +TEST_SUITE_END(); + + +TEST_SUITE_BEGIN("MixedModeTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "mixed_mode_basic_example_append") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + auto res = checkOldSolver( + R"( +local x = 4 +local y = 5 +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = checkFragment( + R"( +local x = 4 +local y = 5 +local z = x + y +)", + Position{3, 15} + ); + + auto opt = linearSearchForBinding(fragment.freshScope.get(), "z"); + REQUIRE(opt); + CHECK_EQ("number", toString(*opt)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "mixed_mode_basic_example_inlined") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + auto res = checkOldSolver( + R"( +local x = 4 +local y = 5 +)" + ); + + auto fragment = checkFragment( + R"( +local x = 4 +local z = x +local y = 5 +)", + Position{2, 11} + ); + + auto correct = linearSearchForBinding(fragment.freshScope.get(), "z"); + REQUIRE(correct); + CHECK_EQ("number", toString(*correct)); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "mixed_mode_can_autocomplete_simple_property_access") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + auto res = checkOldSolver( + R"( +local tbl = { abc = 1234} +)" + ); + + LUAU_REQUIRE_NO_ERRORS(res); + + auto fragment = autocompleteFragment( + R"( +local tbl = { abc = 1234} +tbl. +)", + Position{2, 5} + ); + + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(1, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("abc")); + CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "typecheck_fragment_handles_stale_module") +{ + const std::string sourceName = "MainModule"; + fileResolver.source[sourceName] = "local x = 5"; + + CheckResult checkResult = frontend.check(sourceName, getOptions()); + LUAU_REQUIRE_NO_ERRORS(checkResult); + + auto [result, _] = typecheckFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK_EQ(result, FragmentTypeCheckStatus::Success); + + frontend.markDirty(sourceName); + frontend.parse(sourceName); + + CHECK_NE(frontend.getSourceModule(sourceName), nullptr); + + auto [result2, __] = typecheckFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK_EQ(result2, FragmentTypeCheckStatus::SkipAutocomplete); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "typecheck_fragment_handles_unusable_module") +{ + const std::string sourceA = "MainModule"; + fileResolver.source[sourceA] = R"( +local Modules = game:GetService('Gui').Modules +local B = require(Modules.B) +return { hello = B } +)"; + + const std::string sourceB = "game/Gui/Modules/B"; + fileResolver.source[sourceB] = R"(return {hello = "hello"})"; + + CheckResult result = frontend.check(sourceA, getOptions()); + CHECK(!frontend.isDirty(sourceA, getOptions().forAutocomplete)); + + std::weak_ptr weakModule = getModuleResolver(frontend).getModule(sourceB); + REQUIRE(!weakModule.expired()); + + frontend.markDirty(sourceB); + CHECK(frontend.isDirty(sourceA, getOptions().forAutocomplete)); + + frontend.check(sourceB, getOptions()); + CHECK(weakModule.expired()); + + auto [status, _] = typecheckFragmentForModule(sourceA, fileResolver.source[sourceA], Luau::Position(0, 0)); + CHECK_EQ(status, FragmentTypeCheckStatus::SkipAutocomplete); + + auto [status2, _2] = typecheckFragmentForModule(sourceB, fileResolver.source[sourceB], Luau::Position(3, 20)); + CHECK_EQ(status2, FragmentTypeCheckStatus::Success); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("FragmentAutocompleteTests"); + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "multiple_fragment_autocomplete") +{ + ToStringOptions opt; + opt.exhaustive = true; + opt.exhaustive = true; + opt.functionTypeArguments = true; + opt.maxTableLength = 0; + opt.maxTypeLength = 0; + + auto checkAndExamine = [&](const std::string& src, const std::string& idName, const std::string& idString) + { + check(src, getOptions()); + auto id = getType(idName, true); + LUAU_ASSERT(id); + CHECK_EQ(Luau::toString(*id, opt), idString); + }; + + auto getTypeFromModule = [](ModulePtr module, const std::string& name) -> std::optional + { + if (!module->hasModuleScope()) + return std::nullopt; + return lookupName(module->getModuleScope(), name); + }; + + auto fragmentACAndCheck = [&](const std::string& updated, + const Position& pos, + const std::string& idName, + const std::string& srcIdString, + const std::string& fragIdString) + { + FragmentAutocompleteResult result = autocompleteFragment(updated, pos, std::nullopt); + auto fragId = getTypeFromModule(result.incrementalModule, idName); + LUAU_ASSERT(fragId); + CHECK_EQ(Luau::toString(*fragId, opt), fragIdString); + + auto srcId = getType(idName, true); + LUAU_ASSERT(srcId); + CHECK_EQ(Luau::toString(*srcId, opt), srcIdString); + }; + + const std::string source = R"(local module = {} +f +return module)"; + + const std::string updated1 = R"(local module = {} +function module.a +return module)"; + + const std::string updated2 = R"(local module = {} +function module.ab +return module)"; + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + ScopedFastFlag sff2{FFlag::LuauCloneIncrementalModule, true}; + ScopedFastFlag sff3{FFlag::LuauFreeTypesMustHaveBounds, true}; + checkAndExamine(source, "module", "{ }"); + fragmentACAndCheck(updated1, Position{1, 17}, "module", "{ }", "{ a: (%error-id%: unknown) -> () }"); + fragmentACAndCheck(updated2, Position{1, 18}, "module", "{ }", "{ ab: (%error-id%: unknown) -> () }"); + } + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 Fragment autocomplete still doesn't return correct result when LuauSolverV2 is on + return; + fragmentACAndCheck(updated1, Position{1, 17}, "module", "{ }", "{ a: (%error-id%: unknown) -> () }"); + fragmentACAndCheck(updated2, Position{1, 18}, "module", "{ }", "{ ab: (%error-id%: unknown) -> () }"); + } +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_simple_property_access") +{ + + const std::string source = R"( +local tbl = { abc = 1234} +)"; + const std::string updated = R"( +local tbl = { abc = 1234} +tbl. +)"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{2, 5}, + [](FragmentAutocompleteResult& fragment) + { + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(1, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("abc")); + CHECK_EQ(AutocompleteContext::Property, fragment.acResults.context); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "can_autocomplete_nested_property_access") +{ + const std::string source = R"( +local tbl = { abc = { def = 1234, egh = false } } +)"; + const std::string updated = R"( +local tbl = { abc = { def = 1234, egh = false } } +tbl.abc. +)"; + autocompleteFragmentInBothSolvers( + source, + updated, + Position{2, 8}, + [](FragmentAutocompleteResult& fragment) + { + LUAU_ASSERT(fragment.freshScope); + + CHECK_EQ(2, fragment.acResults.entryMap.size()); + CHECK(fragment.acResults.entryMap.count("def")); + CHECK(fragment.acResults.entryMap.count("egh")); + CHECK_EQ(fragment.acResults.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "multiple_functions_complex") +{ + const std::string text = R"( local function f1(a1) + local l1 = 1;" + g1 = 1;" +end + +local function f2(a2) + local l2 = 1; + g2 = 1; +end +)"; + + autocompleteFragmentInBothSolvers( + text, + text, + Position{0, 0}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") == 0); + CHECK(strings.count("a1") == 0); + CHECK(strings.count("l1") == 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") == 0); + CHECK(strings.count("a2") == 0); + CHECK(strings.count("l2") == 0); + CHECK(strings.count("g2") != 0); + } + ); + + autocompleteFragmentInBothSolvers( + text, + text, + Position{0, 22}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") != 0); + CHECK(strings.count("a1") != 0); + CHECK(strings.count("l1") == 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") == 0); + CHECK(strings.count("a2") == 0); + CHECK(strings.count("l2") == 0); + CHECK(strings.count("g2") != 0); + } + ); + + autocompleteFragmentInBothSolvers( + text, + text, + Position{1, 17}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") != 0); + CHECK(strings.count("a1") != 0); + CHECK(strings.count("l1") != 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") == 0); + CHECK(strings.count("a2") == 0); + CHECK(strings.count("l2") == 0); + CHECK(strings.count("g2") != 0); + } + ); + + autocompleteFragmentInBothSolvers( + text, + text, + Position{2, 11}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") != 0); + CHECK(strings.count("a1") != 0); + CHECK(strings.count("l1") != 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") == 0); + CHECK(strings.count("a2") == 0); + CHECK(strings.count("l2") == 0); + CHECK(strings.count("g2") != 0); + } + ); + + autocompleteFragmentInBothSolvers( + text, + text, + Position{4, 0}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") != 0); + // FIXME: RIDE-11123: This should be zero counts of `a1`. + CHECK(strings.count("a1") != 0); + CHECK(strings.count("l1") == 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") == 0); + CHECK(strings.count("a2") == 0); + CHECK(strings.count("l2") == 0); + CHECK(strings.count("g2") != 0); + } + ); + + autocompleteFragmentInBothSolvers( + text, + text, + Position{6, 17}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") != 0); + CHECK(strings.count("a1") == 0); + CHECK(strings.count("l1") == 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") != 0); + CHECK(strings.count("a2") != 0); + CHECK(strings.count("l2") != 0); + CHECK(strings.count("g2") != 0); + } + ); + + autocompleteFragmentInBothSolvers( + text, + text, + Position{8, 4}, + [](FragmentAutocompleteResult& fragment) + { + auto strings = fragment.acResults.entryMap; + CHECK(strings.count("f1") != 0); + CHECK(strings.count("a1") == 0); + CHECK(strings.count("l1") == 0); + CHECK(strings.count("g1") != 0); + CHECK(strings.count("f2") != 0); + // FIXME: RIDE-11123: This should be zero counts of `a2`. + CHECK(strings.count("a2") != 0); + CHECK(strings.count("l2") == 0); + CHECK(strings.count("g2") != 0); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "inline_autocomplete_picks_the_right_scope") +{ + const std::string source = R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } +end +)"; + + const std::string updated = R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } + local a : T +end +)"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{4, 15}, + [](FragmentAutocompleteResult& fragment) + { + LUAU_ASSERT(fragment.freshScope); + + REQUIRE(fragment.acResults.entryMap.count("Table")); + REQUIRE(fragment.acResults.entryMap["Table"].type); + const TableType* tv = get(follow(*fragment.acResults.entryMap["Table"].type)); + REQUIRE(tv); + CHECK(tv->props.count("x")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function") +{ + const std::string source = R"( +function foo() +end +)"; + autocompleteFragmentInBothSolvers( + source, + source, + Position{2, 0}, + [](FragmentAutocompleteResult& fragment) + { + CHECK(fragment.acResults.entryMap.count("foo")); + CHECK_EQ(AutocompleteContext::Statement, fragment.acResults.context); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "string_literal_with_override") +{ + const std::string source = R"( +function foo(bar: string) end +foo("abc") +)"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{2, 6}, + [](FragmentAutocompleteResult& fragment) + { + CHECK(fragment.acResults.entryMap.empty()); + CHECK_EQ(AutocompleteContext::String, fragment.acResults.context); + }, + Position{2, 9} + ); +} + +// Start compatibility tests! + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "empty_program") +{ + autocompleteFragmentInBothSolvers( + "", + "", + Position{0, 1}, + [](FragmentAutocompleteResult& frag) + { + auto ac = frag.acResults; + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_initializer") +{ + const std::string source = "local a ="; + autocompleteFragmentInBothSolvers( + source, + source, + Position{0, 9}, + [](FragmentAutocompleteResult& frag) + { + auto ac = frag.acResults; + + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "leave_numbers_alone") +{ + const std::string source = "local a = 3."; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{0, 12}, + [](FragmentAutocompleteResult& frag) + { + auto ac = frag.acResults; + CHECK(ac.entryMap.empty()); + CHECK_EQ(ac.context, AutocompleteContext::Unknown); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_globals") +{ + const std::string source = "local myLocal = 4; "; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{0, 18}, + [](FragmentAutocompleteResult& frag) + { + auto ac = frag.acResults; + + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "dont_suggest_local_before_its_definition") +{ + const std::string source = R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"; + + // autocomplete after abc but before myInnerLocal + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 0}, + [](FragmentAutocompleteResult& fragment) + { + auto ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); + } + ); + // autocomplete after my inner local + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 0}, + [](FragmentAutocompleteResult& fragment) + { + auto ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("myInnerLocal")); + } + ); + + // autocomplete after abc, but don't include myInnerLocal(in the hidden scope) + autocompleteFragmentInBothSolvers( + source, + source, + Position{6, 0}, + [](FragmentAutocompleteResult& fragment) + { + auto ac = fragment.acResults; + CHECK(ac.entryMap.count("myLocal")); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "myInnerLocal"); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "nested_recursive_function") +{ + const std::string source = R"( + local function outer() + local function inner() + end + end + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 0}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK(ac.entryMap.count("inner")); + CHECK(ac.entryMap.count("outer")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "user_defined_local_functions_in_own_definition") +{ + const std::string source = R"( + local function abc() + + end + )"; + // Autocomplete inside of abc + autocompleteFragmentInBothSolvers( + source, + source, + Position{2, 0}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK(ac.entryMap.count("abc")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "global_functions_are_not_scoped_lexically") +{ + const std::string source = R"( + if true then + function abc() + + end + end + )"; + autocompleteFragmentInBothSolvers( + source, + source, + Position{6, 0}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("abc")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "local_functions_fall_out_of_scope") +{ + const std::string source = R"( + if true then + local function abc() + + end + end + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{6, 0}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK_NE(0, ac.entryMap.size()); + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "abc"); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "function_parameters") +{ + const std::string source = R"( + function abc(test) + + end + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 0}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK(ac.entryMap.count("test")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "unsealed_table") +{ + const std::string source = R"( + local tbl = {} + tbl.prop = 5 + tbl. + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 12}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("prop")); + CHECK_EQ(ac.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "unsealed_table_2") +{ + const std::string source = R"( + local tbl = {} + local inner = { prop = 5 } + tbl.inner = inner + tbl.inner. + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 18}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("prop")); + CHECK_EQ(ac.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "cyclic_table") +{ + const std::string source = R"( + local abc = {} + local def = { abc = abc } + abc.def = def + abc.def. + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 16}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK(ac.entryMap.count("abc")); + CHECK_EQ(ac.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "table_union") +{ + const std::string source = R"( + type t1 = { a1 : string, b2 : number } + type t2 = { b2 : string, c3 : string } + function func(abc : t1 | t2) + + end + )"; + const std::string updated = R"( + type t1 = { a1 : string, b2 : number } + type t2 = { b2 : string, c3 : string } + function func(abc : t1 | t2) + abc. + end + )"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{4, 16}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("b2")); + CHECK_EQ(ac.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "table_intersection") +{ + const std::string source = R"( + type t1 = { a1 : string, b2 : number } + type t2 = { b2 : number, c3 : string } + function func(abc : t1 & t2) + + end + )"; + const std::string updated = R"( + type t1 = { a1 : string, b2 : number } + type t2 = { b2 : number, c3 : string } + function func(abc : t1 & t2) + abc. + end + )"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{4, 16}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK_EQ(3, ac.entryMap.size()); + CHECK(ac.entryMap.count("a1")); + CHECK(ac.entryMap.count("b2")); + CHECK(ac.entryMap.count("c3")); + CHECK_EQ(ac.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "get_suggestions_for_the_very_start_of_the_script") +{ + const std::string source = R"( + + function aaa() end + )"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{0, 0}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Statement); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "studio_ice_1") +{ + const std::string source = R"( +--Woop +@native +local function test() + +end +)"; + + const std::string updated = R"( +--Woop +@native +local function test() + +end +function a +)"; + autocompleteFragmentInBothSolvers(source, updated, Position{6, 10}, [](FragmentAutocompleteResult& result) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "method_call_inside_function_body") +{ + const std::string source = R"( + local game = { GetService=function(s) return 'hello' end } + + function a() + + end + )"; + + const std::string updated = R"( + local game = { GetService=function(s) return 'hello' end } + + function a() + game: + end + )"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{4, 17}, + [](FragmentAutocompleteResult& result) + { + auto ac = result.acResults; + CHECK_NE(0, ac.entryMap.size()); + + LUAU_CHECK_HAS_NO_KEY(ac.entryMap, "math"); + CHECK_EQ(ac.context, AutocompleteContext::Property); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "tbl_function_parameter") +{ + const std::string source = R"( +--!strict +type Foo = {x : number, y : number} +local function func(abc : Foo) + abc. +end +)"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 7}, + [](FragmentAutocompleteResult& result) + { + CHECK_EQ(2, result.acResults.entryMap.size()); + CHECK(result.acResults.entryMap.count("x")); + CHECK(result.acResults.entryMap.count("y")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "tbl_local_function_parameter") +{ + const std::string source = R"( +--!strict +type Foo = {x : number, y : number} +local function func(abc : Foo) + abc. +end +)"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 7}, + [](FragmentAutocompleteResult& result) + { + CHECK_EQ(2, result.acResults.entryMap.size()); + CHECK(result.acResults.entryMap.count("x")); + CHECK(result.acResults.entryMap.count("y")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "vec3_function_parameter") +{ + const std::string source = R"( +--!strict +local function func(abc : FakeVec) + abc. +end +)"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 7}, + [](FragmentAutocompleteResult& result) + { + CHECK_EQ(2, result.acResults.entryMap.size()); + CHECK(result.acResults.entryMap.count("zero")); + CHECK(result.acResults.entryMap.count("dot")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "vec3_local_function_parameter") +{ + const std::string source = R"( +--!strict +local function func(abc : FakeVec) + abc. +end +)"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 7}, + [](FragmentAutocompleteResult& result) + { + CHECK_EQ(2, result.acResults.entryMap.size()); + CHECK(result.acResults.entryMap.count("zero")); + CHECK(result.acResults.entryMap.count("dot")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "function_parameter_not_recommending_out_of_scope_argument") +{ + const std::string source = R"( +--!strict +local function foo(abd: FakeVec) +end +local function bar(abc : FakeVec) + a +end +)"; + + autocompleteFragmentInBothSolvers( + source, + source, + Position{5, 5}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.count("abc")); + CHECK(!result.acResults.entryMap.count("abd")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "bad_range") +{ + const std::string source = R"( +l +)"; + const std::string updated = R"( +local t = 1 +t +)"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{2, 1}, + [](FragmentAutocompleteResult& result) + { + auto opt = linearSearchForBinding(result.freshScope, "t"); + REQUIRE(opt); + CHECK_EQ("number", toString(*opt)); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments_simple") +{ + const std::string source = R"( +-- sel +-- retur +-- fo +-- if +-- end +-- the +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments_blocks") +{ + const std::string source = R"( +--[[ +comment 1 +]] local +-- [[ comment 2]] +-- +-- sdfsdfsdf +--[[comment 3]] +--[[ +foo +bar +baz +]] +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 0}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 2}, + [](FragmentAutocompleteResult& result) + { + CHECK(!result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{8, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{10, 0}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments") +{ + const std::string source = R"( +-- sel +-- retur +-- fo +--[[ sel ]] +local -- hello +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + source, + Position{1, 7}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{2, 9}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{3, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{4, 9}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{5, 6}, + [](FragmentAutocompleteResult& result) + { + CHECK(!result.acResults.entryMap.empty()); + } + ); + + autocompleteFragmentInBothSolvers( + source, + source, + Position{5, 14}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "no_recs_for_comments_in_incremental_fragment") +{ + const std::string source = R"( +local x = 5 +if x == 5 +)"; + const std::string updated = R"( +local x = 5 +if x == 5 then -- a comment +)"; + ScopedFastFlag sff{FFlag::LuauIncrementalAutocompleteCommentDetection, true}; + autocompleteFragmentInBothSolvers( + source, + updated, + Position{2, 28}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_parse_errors") +{ + + ScopedFastInt sfi{FInt::LuauParseErrorLimit, 1}; + const std::string source = R"( + +)"; + const std::string updated = R"( +type A = <>random non code text here +)"; + + autocompleteFragmentInBothSolvers( + source, + updated, + Position{1, 38}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.empty()); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_handles_stale_module") +{ + const std::string sourceName = "MainModule"; + fileResolver.source[sourceName] = "local x = 5"; + + frontend.check(sourceName, getOptions()); + frontend.markDirty(sourceName); + frontend.parse(sourceName); + + FragmentAutocompleteResult result = autocompleteFragmentForModule(sourceName, fileResolver.source[sourceName], Luau::Position(0, 0)); + CHECK(result.acResults.entryMap.empty()); + CHECK_EQ(result.incrementalModule, nullptr); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "require_tracing") +{ + fileResolver.source["MainModule/A"] = R"( +return { x = 0 } + )"; + + fileResolver.source["MainModule"] = R"( +local result = require(script.A) +local x = 1 + result. + )"; + + autocompleteFragmentInBothSolvers( + fileResolver.source["MainModule"], + fileResolver.source["MainModule"], + Position{2, 21}, + [](FragmentAutocompleteResult& result) + { + CHECK(result.acResults.entryMap.size() == 1); + CHECK(result.acResults.entryMap.count("x")); + } + ); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "fragment_ac_must_traverse_typeof_and_not_ice") +{ + // This test ensures that we traverse typeof expressions for defs that are being referred to in the fragment + // In this case, we want to ensure we populate the incremental environment with the reference to `m` + // Without this, we would ice as we will refer to the local `m` before it's declaration + ScopedFastFlag sff{FFlag::LuauMixedModeDefFinderTraversesTypeOf, true}; + const std::string source = R"( +--!strict +local m = {} +-- and here +function m:m1() end +type nt = typeof(m) + +return m +)"; + const std::string updated = R"( +--!strict +local m = {} +-- and here +function m:m1() end +type nt = typeof(m) +l +return m +)"; + + autocompleteFragmentInBothSolvers(source, updated, Position{6, 2}, [](auto& _) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "generalization_crash_when_old_solver_freetypes_have_no_bounds_set") +{ + ScopedFastFlag sff{FFlag::LuauFreeTypesMustHaveBounds, true}; + const std::string source = R"( +local UserInputService = game:GetService("UserInputService"); + +local Camera = workspace.CurrentCamera; + +UserInputService.InputBegan:Connect(function(Input) + if (Input.KeyCode == Enum.KeyCode.One) then + local Up = Input.Foo + local Vector = -(Up:Unit) + end +end) +)"; + + const std::string dest = R"( +local UserInputService = game:GetService("UserInputService"); + +local Camera = workspace.CurrentCamera; + +UserInputService.InputBegan:Connect(function(Input) + if (Input.KeyCode == Enum.KeyCode.One) then + local Up = Input.Foo + local Vector = -(Up:Unit()) + end +end) +)"; + + autocompleteFragmentInBothSolvers(source, dest, Position{8, 36}, [](auto& _) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_ensures_memory_isolation") +{ + ScopedFastFlag sff{FFlag::LuauCloneIncrementalModule, true}; + ToStringOptions opt; + opt.exhaustive = true; + opt.exhaustive = true; + opt.functionTypeArguments = true; + opt.maxTableLength = 0; + opt.maxTypeLength = 0; + + auto checkAndExamine = [&](const std::string& src, const std::string& idName, const std::string& idString) + { + check(src, getOptions()); + auto id = getType(idName, true); + LUAU_ASSERT(id); + CHECK_EQ(Luau::toString(*id, opt), idString); + }; + + auto getTypeFromModule = [](ModulePtr module, const std::string& name) -> std::optional + { + if (!module->hasModuleScope()) + return std::nullopt; + return lookupName(module->getModuleScope(), name); + }; + + auto fragmentACAndCheck = [&](const std::string& updated, const Position& pos, const std::string& idName) + { + FragmentAutocompleteResult result = autocompleteFragment(updated, pos, std::nullopt); + auto fragId = getTypeFromModule(result.incrementalModule, idName); + LUAU_ASSERT(fragId); + + auto srcId = getType(idName, true); + LUAU_ASSERT(srcId); + + CHECK((*fragId)->owningArena != (*srcId)->owningArena); + CHECK(&(result.incrementalModule->internalTypes) == (*fragId)->owningArena); + }; + + const std::string source = R"(local module = {} +f +return module)"; + + const std::string updated1 = R"(local module = {} +function module.a +return module)"; + + const std::string updated2 = R"(local module = {} +function module.ab +return module)"; + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + fragmentACAndCheck(updated1, Position{1, 17}, "module"); + fragmentACAndCheck(updated2, Position{1, 18}, "module"); + } + + { + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + checkAndExamine(source, "module", "{ }"); + // [TODO] CLI-140762 we shouldn't mutate stale module in autocompleteFragment + // early return since the following checking will fail, which it shouldn't! + fragmentACAndCheck(updated1, Position{1, 17}, "module"); + fragmentACAndCheck(updated2, Position{1, 18}, "module"); + } +} + +TEST_CASE_FIXTURE(FragmentAutocompleteFixture, "fragment_autocomplete_shouldnt_crash_on_cross_module_mutation") +{ + ScopedFastFlag sff{FFlag::LuauCloneIncrementalModule, true}; + const std::string source = R"(local module = {} +function module. +return module +)"; + + const std::string updated = R"(local module = {} +function module.f +return module +)"; + + autocompleteFragmentInBothSolvers(source, updated, Position{1, 18}, [](FragmentAutocompleteResult& result) {}); +} + +TEST_CASE_FIXTURE(FragmentAutocompleteBuiltinsFixture, "ice_caused_by_mixed_mode_use") +{ + ScopedFastFlag sff{FFlag::LuauAutocompleteUsesModuleForTypeCompatibility, true}; + const std::string source = "--[[\n\tPackage link auto-generated by Rotriever\n]]\nlocal PackageIndex = script.Parent._Index\n\nlocal Package = " + "require(PackageIndex[\"ReactOtter\"][\"ReactOtter\"])\n\nexport type Goal = Package.Goal\nexport type SpringOptions " + "= Package.SpringOptions\n\n\nreturn Pa"; + autocompleteFragmentInBothSolvers(source, source, Position{11,9}, [](auto& _){ + + }); +} + +TEST_SUITE_END(); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 88f91708..9d6cfa74 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/DenseHash.h" #include "Luau/Frontend.h" #include "Luau/RequireTracer.h" @@ -12,9 +13,11 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTFLAG(DebugLuauMagicTypes); +LUAU_FASTFLAG(LuauSelectivelyRetainDFGArena) +LUAU_FASTFLAG(LuauBetterReverseDependencyTracking); namespace { @@ -312,9 +315,9 @@ TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") REQUIRE(bool(cExports)); if (FFlag::LuauSolverV2) - CHECK_EQ("{ a: { hello: any }, b: { hello: any } }", toString(*cExports)); + CHECK("{ a: { hello: any }, b: { hello: any } }" == toString(*cExports)); else - CHECK_EQ("{| a: any, b: any |}", toString(*cExports)); + CHECK("{| a: {| hello: any |}, b: {| hello: any |} |}" == toString(*cExports)); } TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_disabled_in_nocheck") @@ -1375,7 +1378,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "checked_modules_have_the_correct_mode") TEST_CASE_FIXTURE(FrontendFixture, "separate_caches_for_autocomplete") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); fileResolver.source["game/A"] = R"( --!nonstrict @@ -1454,4 +1457,319 @@ TEST_CASE_FIXTURE(Fixture, "exported_tables_have_position_metadata") CHECK(Location{Position{1, 17}, Position{1, 20}} == prop.location); } +TEST_CASE_FIXTURE(FrontendFixture, "get_required_scripts") +{ + fileResolver.source["game/workspace/MyScript"] = R"( + local MyModuleScript = require(game.workspace.MyModuleScript) + local MyModuleScript2 = require(game.workspace.MyModuleScript2) + MyModuleScript.myPrint() + )"; + + fileResolver.source["game/workspace/MyModuleScript"] = R"( + local module = {} + function module.myPrint() + print("Hello World") + end + return module + )"; + + fileResolver.source["game/workspace/MyModuleScript2"] = R"( + local module = {} + return module + )"; + + // isDirty(name) is true, getRequiredScripts should not hit the cache. + frontend.markDirty("game/workspace/MyScript"); + std::vector requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 2); + CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); + CHECK(requiredScripts[1] == "game/workspace/MyModuleScript2"); + + // Call frontend.check first, then getRequiredScripts should hit the cache because isDirty(name) is false. + frontend.check("game/workspace/MyScript"); + requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 2); + CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); + CHECK(requiredScripts[1] == "game/workspace/MyModuleScript2"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "get_required_scripts_dirty") +{ + fileResolver.source["game/workspace/MyScript"] = R"( + print("Hello World") + )"; + + fileResolver.source["game/workspace/MyModuleScript"] = R"( + local module = {} + function module.myPrint() + print("Hello World") + end + return module + )"; + + frontend.check("game/workspace/MyScript"); + std::vector requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 0); + + fileResolver.source["game/workspace/MyScript"] = R"( + local MyModuleScript = require(game.workspace.MyModuleScript) + MyModuleScript.myPrint() + )"; + + requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 0); + + frontend.markDirty("game/workspace/MyScript"); + requiredScripts = frontend.getRequiredScripts("game/workspace/MyScript"); + REQUIRE(requiredScripts.size() == 1); + CHECK(requiredScripts[0] == "game/workspace/MyModuleScript"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "check_module_references_allocator") +{ + fileResolver.source["game/workspace/MyScript"] = R"( + print("Hello World") + )"; + + frontend.check("game/workspace/MyScript"); + + ModulePtr module = frontend.moduleResolver.getModule("game/workspace/MyScript"); + SourceModule* source = frontend.getSourceModule("game/workspace/MyScript"); + CHECK(module); + CHECK(source); + + CHECK_EQ(module->allocator.get(), source->allocator.get()); + CHECK_EQ(module->names.get(), source->names.get()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "dfg_data_cleared_on_retain_type_graphs_unset") +{ + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauSelectivelyRetainDFGArena, true}}; + fileResolver.source["game/A"] = R"( +local a = 1 +local b = 2 +local c = 3 +return {x = a, y = b, z = c} +)"; + + frontend.options.retainFullTypeGraphs = true; + frontend.check("game/A"); + + auto mod = frontend.moduleResolver.getModule("game/A"); + CHECK(!mod->defArena.allocator.empty()); + CHECK(!mod->keyArena.allocator.empty()); + + // We should check that the dfg arena is empty once retainFullTypeGraphs is unset + frontend.options.retainFullTypeGraphs = false; + frontend.markDirty("game/A"); + frontend.check("game/A"); + + mod = frontend.moduleResolver.getModule("game/A"); + CHECK(mod->defArena.allocator.empty()); + CHECK(mod->keyArena.allocator.empty()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_traverse_dependents") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + fileResolver.source["game/Gui/Modules/D"] = R"( + local Modules = game:GetService('Gui').Modules + local C = require(Modules.C) + return {d_value = C.c_value} + )"; + + frontend.check("game/Gui/Modules/D"); + + std::vector visited; + frontend.traverseDependents( + "game/Gui/Modules/B", + [&visited](SourceNode& node) + { + visited.push_back(node.name); + return true; + } + ); + + CHECK_EQ(std::vector{"game/Gui/Modules/B", "game/Gui/Modules/C", "game/Gui/Modules/D"}, visited); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_traverse_dependents_early_exit") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + return require(game:GetService('Gui').Modules.A) + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B.hello} + )"; + + frontend.check("game/Gui/Modules/C"); + + std::vector visited; + frontend.traverseDependents( + "game/Gui/Modules/A", + [&visited](SourceNode& node) + { + visited.push_back(node.name); + return node.name != "game/Gui/Modules/B"; + } + ); + + CHECK_EQ(std::vector{"game/Gui/Modules/A", "game/Gui/Modules/B"}, visited); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_dependents_stored_on_node_as_graph_updates") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + + auto updateSource = [&](const std::string& name, const std::string& source) + { + fileResolver.source[name] = source; + frontend.markDirty(name); + }; + + auto validateMatchesRequireLists = [&](const std::string& message) + { + DenseHashMap> dependents{{}}; + for (const auto& module : frontend.sourceNodes) + { + for (const auto& dep : module.second->requireSet) + dependents[dep].push_back(module.first); + } + + for (const auto& module : frontend.sourceNodes) + { + Set& dependentsForModule = module.second->dependents; + for (const auto& dep : dependents[module.first]) + CHECK_MESSAGE(1 == dependentsForModule.count(dep), "Mismatch in dependents for " << module.first << ": " << message); + } + }; + + auto validateSecondDependsOnFirst = [&](const std::string& from, const std::string& to, bool expected) + { + SourceNode& fromNode = *frontend.sourceNodes[from]; + CHECK_MESSAGE( + fromNode.dependents.count(to) == int(expected), + "Expected " << from << " to " << (expected ? std::string() : std::string("not ")) << "have a reverse dependency on " << to + ); + }; + + // C -> B -> A + { + updateSource("game/Gui/Modules/A", "return {hello=5, world=true}"); + updateSource("game/Gui/Modules/B", R"( + return require(game:GetService('Gui').Modules.A) + )"); + updateSource("game/Gui/Modules/C", R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {c_value = B} + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Initial check"); + + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", true); + validateSecondDependsOnFirst("game/Gui/Modules/B", "game/Gui/Modules/C", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/A", false); + } + + // C -> B, A + { + updateSource("game/Gui/Modules/B", R"( + return 1 + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Removing dependency B->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", false); + } + + // C -> B -> A + { + updateSource("game/Gui/Modules/B", R"( + return require(game:GetService('Gui').Modules.A) + )"); + frontend.check("game/Gui/Modules/C"); + + validateMatchesRequireLists("Adding back B->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/B", true); + } + + // C -> B -> A, D -> (C,B,A) + { + updateSource("game/Gui/Modules/D", R"( + local C = require(game:GetService('Gui').Modules.C) + local B = require(game:GetService('Gui').Modules.B) + local A = require(game:GetService('Gui').Modules.A) + return {d_value = C.c_value} + )"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding D->C, D->B, D->A"); + validateSecondDependsOnFirst("game/Gui/Modules/A", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/B", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", true); + } + + // B -> A, C <-> D + { + updateSource("game/Gui/Modules/D", "return require(game:GetService('Gui').Modules.C)"); + updateSource("game/Gui/Modules/C", "return require(game:GetService('Gui').Modules.D)"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding cycle D->C, C->D"); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", true); + validateSecondDependsOnFirst("game/Gui/Modules/D", "game/Gui/Modules/C", true); + } + + // B -> A, C -> D, D -> error + { + updateSource("game/Gui/Modules/D", "return require(game:GetService('Gui').Modules.C.)"); + frontend.check("game/Gui/Modules/D"); + + validateMatchesRequireLists("Adding error dependency D->C."); + validateSecondDependsOnFirst("game/Gui/Modules/D", "game/Gui/Modules/C", true); + validateSecondDependsOnFirst("game/Gui/Modules/C", "game/Gui/Modules/D", false); + } +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_invalid_dependency_tracking_per_module_resolver") +{ + ScopedFastFlag dependencyTracking{FFlag::LuauBetterReverseDependencyTracking, true}; + ScopedFastFlag newSolver{FFlag::LuauSolverV2, false}; + + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = "return require(game:GetService('Gui').Modules.A)"; + + FrontendOptions opts; + opts.forAutocomplete = false; + + frontend.check("game/Gui/Modules/B", opts); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/B", opts.forAutocomplete)); + CHECK(!frontend.allModuleDependenciesValid("game/Gui/Modules/B", !opts.forAutocomplete)); + + opts.forAutocomplete = true; + frontend.check("game/Gui/Modules/A", opts); + + CHECK(!frontend.allModuleDependenciesValid("game/Gui/Modules/B", opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/B", !opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/A", !opts.forAutocomplete)); + CHECK(frontend.allModuleDependenciesValid("game/Gui/Modules/A", opts.forAutocomplete)); +} + TEST_SUITE_END(); diff --git a/tests/Generalization.test.cpp b/tests/Generalization.test.cpp index 1388b900..b9e4eaf1 100644 --- a/tests/Generalization.test.cpp +++ b/tests/Generalization.test.cpp @@ -179,9 +179,9 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") { // t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i)) - TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); TypeId unionType = arena.addType(UnionType{{h, j}}); getMutable(h)->upperBound = i; getMutable(h)->lowerBound = builtinTypes.neverType; @@ -196,9 +196,9 @@ TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash") TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_crash") { // t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i)) - TypeId i = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId h = arena.addType(FreeType{NotNull{globalScope.get()}}); - TypeId j = arena.addType(FreeType{NotNull{globalScope.get()}}); + TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); + TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get()); TypeId intersectionType = arena.addType(IntersectionType{{h, j}}); getMutable(h)->upperBound = i; diff --git a/tests/Instantiation2.test.cpp b/tests/Instantiation2.test.cpp index fff98e60..fcd136fb 100644 --- a/tests/Instantiation2.test.cpp +++ b/tests/Instantiation2.test.cpp @@ -4,6 +4,7 @@ #include "Fixture.h" #include "ClassFixture.h" +#include "Luau/Type.h" #include "ScopedFlags.h" #include "doctest.h" @@ -29,7 +30,7 @@ TEST_CASE_FIXTURE(Fixture, "weird_cyclic_instantiation") DenseHashMap genericSubstitutions{nullptr}; DenseHashMap genericPackSubstitutions{nullptr}; - TypeId freeTy = arena.freshType(&scope); + TypeId freeTy = arena.freshType(builtinTypes, &scope); FreeType* ft = getMutable(freeTy); REQUIRE(ft); ft->lowerBound = idTy; diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index d02fd9f1..629c3696 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -119,6 +119,7 @@ public: static const int tnil = 0; static const int tboolean = 1; static const int tnumber = 3; + static const int tvector = 4; static const int tstring = 5; static const int ttable = 6; static const int tfunction = 7; @@ -1720,6 +1721,55 @@ bb_fallback_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "NumericSimplifications") +{ + IrOp block = build.block(IrBlockKind::Internal); + + build.beginBlock(block); + IrOp value = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), build.inst(IrCmd::SUB_NUM, value, build.constDouble(0.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), build.inst(IrCmd::ADD_NUM, value, build.constDouble(-0.0))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), build.inst(IrCmd::MUL_NUM, value, build.constDouble(1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(4), build.inst(IrCmd::MUL_NUM, value, build.constDouble(2.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(5), build.inst(IrCmd::MUL_NUM, value, build.constDouble(-1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(6), build.inst(IrCmd::MUL_NUM, value, build.constDouble(3.0))); + + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(7), build.inst(IrCmd::DIV_NUM, value, build.constDouble(1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(8), build.inst(IrCmd::DIV_NUM, value, build.constDouble(-1.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(9), build.inst(IrCmd::DIV_NUM, value, build.constDouble(32.0))); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(10), build.inst(IrCmd::DIV_NUM, value, build.constDouble(6.0))); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(9)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + STORE_DOUBLE R1, %0 + STORE_DOUBLE R2, %0 + STORE_DOUBLE R3, %0 + %7 = ADD_NUM %0, %0 + STORE_DOUBLE R4, %7 + %9 = UNM_NUM %0 + STORE_DOUBLE R5, %9 + %11 = MUL_NUM %0, 3 + STORE_DOUBLE R6, %11 + STORE_DOUBLE R7, %0 + %15 = UNM_NUM %0 + STORE_DOUBLE R8, %15 + %17 = MUL_NUM %0, 0.03125 + STORE_DOUBLE R9, %17 + %19 = DIV_NUM %0, 6 + STORE_DOUBLE R10, %19 + RETURN R1, 9i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); @@ -4416,6 +4466,194 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverNumber") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 1, 2, 4, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverVector") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(4.0), build.constDouble(2.0), build.constDouble(1.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 1, 2, 4, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverVector") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 2 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverNil") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 2 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverNil") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnil)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 1, 2, 4, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "NumberOverCombinedVector") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(3.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_SPLIT_TVALUE R0, tnumber, 3 + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverCombinedVector") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(1.0), build.constDouble(2.0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(8.0), build.constDouble(16.0), build.constDouble(32.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 8, 16, 32, tvector + RETURN R0, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VectorOverCombinedNumber") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(2.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(4.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_VECTOR, build.vmReg(0), build.constDouble(8.0), build.constDouble(16.0), build.constDouble(32.0)); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tvector)); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + markDeadStoresInBlockChains(build); + + CHECK("\n" + toString(build.function, IncludeUseInfo::No) == R"( +bb_0: + STORE_VECTOR R0, 8, 16, 32, tvector + RETURN R0, 1i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Dump"); diff --git a/tests/IrLowering.test.cpp b/tests/IrLowering.test.cpp index 396678468..2a5c23fd 100644 --- a/tests/IrLowering.test.cpp +++ b/tests/IrLowering.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "lua.h" #include "lualib.h" +#include "luacode.h" #include "Luau/BytecodeBuilder.h" #include "Luau/CodeGen.h" @@ -15,10 +16,65 @@ #include #include -static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) +static void luauLibraryConstantLookup(const char* library, const char* member, Luau::CompileConstant* constant) { - Luau::CodeGen::AssemblyOptions options; + // While 'vector' library constants are a Luau built-in, their constant value depends on the embedder LUA_VECTOR_SIZE value + if (strcmp(library, "vector") == 0) + { + if (strcmp(member, "zero") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 0.0f, 0.0f, 0.0f); + if (strcmp(member, "one") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 1.0f, 1.0f, 0.0f); + } + + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "xAxis") == 0) + return Luau::setCompileConstantVector(constant, 1.0f, 0.0f, 0.0f, 0.0f); + + if (strcmp(member, "yAxis") == 0) + return Luau::setCompileConstantVector(constant, 0.0f, 1.0f, 0.0f, 0.0f); + } +} + +static void luauLibraryConstantLookupC(const char* library, const char* member, lua_CompileConstant* constant) +{ + if (strcmp(library, "test") == 0) + { + if (strcmp(member, "some_nil") == 0) + return luau_set_compile_constant_nil(constant); + + if (strcmp(member, "some_boolean") == 0) + return luau_set_compile_constant_boolean(constant, 1); + + if (strcmp(member, "some_number") == 0) + return luau_set_compile_constant_number(constant, 4.75); + + if (strcmp(member, "some_vector") == 0) + return luau_set_compile_constant_vector(constant, 1.0f, 2.0f, 4.0f, 8.0f); + + if (strcmp(member, "some_string") == 0) + return luau_set_compile_constant_string(constant, "test", 4); + } +} + +static int luauLibraryTypeLookup(const char* library, const char* member) +{ + if (strcmp(library, "Vector3") == 0) + { + if (strcmp(member, "xAxis") == 0) + return LuauBytecodeType::LBC_TYPE_VECTOR; + + if (strcmp(member, "yAxis") == 0) + return LuauBytecodeType::LBC_TYPE_VECTOR; + } + + return LuauBytecodeType::LBC_TYPE_ANY; +} + +static void setupAssemblyOptions(Luau::CodeGen::AssemblyOptions& options, bool includeIrTypes) +{ options.compilationOptions.hooks.vectorAccessBytecodeType = vectorAccessBytecodeType; options.compilationOptions.hooks.vectorNamecallBytecodeType = vectorNamecallBytecodeType; options.compilationOptions.hooks.vectorAccess = vectorAccess; @@ -44,35 +100,10 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = options.includeUseInfo = Luau::CodeGen::IncludeUseInfo::No; options.includeCfgInfo = Luau::CodeGen::IncludeCfgInfo::No; options.includeRegFlowInfo = Luau::CodeGen::IncludeRegFlowInfo::No; +} - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); - Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); - - if (!result.errors.empty()) - throw Luau::ParseErrors(result.errors); - - Luau::CompileOptions copts = {}; - - copts.optimizationLevel = 2; - copts.debugLevel = debugLevel; - copts.typeInfoLevel = 1; - copts.vectorCtor = "vector"; - copts.vectorType = "vector"; - - static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; - copts.userdataTypes = kUserdataCompileTypes; - - Luau::BytecodeBuilder bcb; - Luau::compileOrThrow(bcb, result, names, copts); - - std::string bytecode = bcb.getBytecode(); - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - // Runtime mapping is specifically created to NOT match the compilation mapping - options.compilationOptions.userdataTypes = kUserdataRunTypes; - +static void initializeCodegen(lua_State* L) +{ if (Luau::CodeGen::isSupported()) { // Type remapper requires the codegen runtime @@ -101,9 +132,95 @@ static std::string getCodegenAssembly(const char* source, bool includeIrTypes = } ); } +} + +static std::string getCodegenAssembly(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + Luau::ParseResult result = Luau::Parser::parse(source, strlen(source), names, allocator); + + if (!result.errors.empty()) + throw Luau::ParseErrors(result.errors); + + Luau::CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + copts.vectorCtor = "vector"; + copts.vectorType = "vector"; + + static const char* kUserdataCompileTypes[] = {"vec2", "color", "mat3", nullptr}; + copts.userdataTypes = kUserdataCompileTypes; + + static const char* kLibrariesWithConstants[] = {"vector", "Vector3", nullptr}; + copts.librariesWithKnownMembers = kLibrariesWithConstants; + + copts.libraryMemberTypeCb = luauLibraryTypeLookup; + copts.libraryMemberConstantCb = luauLibraryConstantLookup; + + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, result, names, copts); + + std::string bytecode = bcb.getBytecode(); + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + initializeCodegen(L); if (luau_load(L, "name", bytecode.data(), bytecode.size(), 0) == 0) + { + Luau::CodeGen::AssemblyOptions options; + setupAssemblyOptions(options, includeIrTypes); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + } + + FAIL("Failed to load bytecode"); + return ""; +} + +static std::string getCodegenAssemblyUsingCApi(const char* source, bool includeIrTypes = false, int debugLevel = 1) +{ + lua_CompileOptions copts = {}; + + copts.optimizationLevel = 2; + copts.debugLevel = debugLevel; + copts.typeInfoLevel = 1; + + static const char* kLibrariesWithConstants[] = {"test", nullptr}; + copts.librariesWithKnownMembers = kLibrariesWithConstants; + + copts.libraryMemberTypeCb = luauLibraryTypeLookup; + copts.libraryMemberConstantCb = luauLibraryConstantLookupC; + + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source, strlen(source), &copts, &bytecodeSize); + REQUIRE(bytecode); + + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + initializeCodegen(L); + + if (luau_load(L, "name", bytecode, bytecodeSize, 0) == 0) + { + free(bytecode); + + Luau::CodeGen::AssemblyOptions options; + setupAssemblyOptions(options, includeIrTypes); + + // Runtime mapping is specifically created to NOT match the compilation mapping + options.compilationOptions.userdataTypes = kUserdataRunTypes; + + return Luau::CodeGen::getAssembly(L, -1, options, nullptr); + } + + free(bytecode); FAIL("Failed to load bytecode"); return ""; @@ -401,9 +518,9 @@ bb_bytecode_0: JUMP bb_2 bb_2: CHECK_SAFE_ENV exit(3) - JUMP_EQ_TAG K1, tnil, bb_fallback_4, bb_3 + JUMP_EQ_TAG K1 (nil), tnil, bb_fallback_4, bb_3 bb_3: - %9 = LOAD_TVALUE K1 + %9 = LOAD_TVALUE K1 (nil) STORE_TVALUE R1, %9 JUMP bb_5 bb_5: @@ -456,7 +573,7 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - %4 = LOAD_TVALUE K0, 0i, tvector + %4 = LOAD_TVALUE K0 (1, 2, 3), 0i, tvector %11 = LOAD_TVALUE R0 %12 = ADD_VEC %4, %11 %13 = TAG_VECTOR %12 @@ -483,7 +600,7 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_NAMECALL 0u, R1, R0, K0 + FALLBACK_NAMECALL 0u, R1, R0, K0 ('Abs') INTERRUPT 2u SET_SAVEDPC 3u CALL R1, 1i, -1i @@ -509,8 +626,8 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_GETTABLEKS 0u, R3, R0, K0 - FALLBACK_GETTABLEKS 2u, R4, R0, K1 + FALLBACK_GETTABLEKS 0u, R3, R0, K0 ('XX') + FALLBACK_GETTABLEKS 2u, R4, R0, K1 ('YY') CHECK_TAG R3, tnumber, bb_fallback_3 CHECK_TAG R4, tnumber, bb_fallback_3 %14 = LOAD_DOUBLE R3 @@ -520,7 +637,7 @@ bb_bytecode_1: JUMP bb_4 bb_4: CHECK_TAG R0, tvector, exit(5) - FALLBACK_GETTABLEKS 5u, R3, R0, K2 + FALLBACK_GETTABLEKS 5u, R3, R0, K2 ('ZZ') CHECK_TAG R2, tnumber, bb_fallback_5 CHECK_TAG R3, tnumber, bb_fallback_5 %30 = LOAD_DOUBLE R2 @@ -540,7 +657,7 @@ TEST_CASE("VectorCustomAccess") CHECK_EQ( "\n" + getCodegenAssembly(R"( local function vec3magn(a: vector) - return a.Magnitude * 2 + return a.Magnitude * 3 end )"), R"( @@ -560,7 +677,7 @@ bb_bytecode_1: %12 = ADD_NUM %9, %10 %13 = ADD_NUM %12, %11 %14 = SQRT_NUM %13 - %20 = MUL_NUM %14, 2 + %20 = MUL_NUM %14, 3 STORE_DOUBLE R1, %20 STORE_TAG R1, tnumber INTERRUPT 3u @@ -738,8 +855,8 @@ bb_2: JUMP bb_bytecode_1 bb_bytecode_1: %8 = LOAD_POINTER R0 - %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 - CHECK_SLOT_MATCH %9, K1, bb_fallback_3 + %9 = GET_SLOT_NODE_ADDR %8, 0u, K1 ('n') + CHECK_SLOT_MATCH %9, K1 ('n'), bb_fallback_3 %11 = LOAD_TVALUE %9, 0i STORE_TVALUE R3, %11 JUMP bb_4 @@ -766,8 +883,8 @@ bb_4: STORE_VECTOR R3, %30, %33, %36 CHECK_TAG R0, ttable, exit(6) %41 = LOAD_POINTER R0 - %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 - CHECK_SLOT_MATCH %42, K3, bb_fallback_5 + %42 = GET_SLOT_NODE_ADDR %41, 6u, K3 ('b') + CHECK_SLOT_MATCH %42, K3 ('b'), bb_fallback_5 %44 = LOAD_TVALUE %42, 0i STORE_TVALUE R5, %44 JUMP bb_6 @@ -810,8 +927,8 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_GETTABLEKS 0u, R2, R0, K0 - FALLBACK_GETTABLEKS 2u, R3, R0, K1 + FALLBACK_GETTABLEKS 0u, R2, R0, K0 ('x') + FALLBACK_GETTABLEKS 2u, R3, R0, K1 ('y') CHECK_TAG R2, tnumber, bb_fallback_3 CHECK_TAG R3, tnumber, bb_fallback_3 %14 = LOAD_DOUBLE R2 @@ -845,9 +962,9 @@ bb_2: bb_bytecode_1: STORE_DOUBLE R1, 3 STORE_TAG R1, tnumber - FALLBACK_SETTABLEKS 1u, R1, R0, K0 + FALLBACK_SETTABLEKS 1u, R1, R0, K0 ('x') STORE_DOUBLE R1, 4 - FALLBACK_SETTABLEKS 4u, R1, R0, K1 + FALLBACK_SETTABLEKS 4u, R1, R0, K1 ('y') INTERRUPT 6u RETURN R0, 0i )" @@ -870,11 +987,11 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_NAMECALL 0u, R2, R0, K0 + FALLBACK_NAMECALL 0u, R2, R0, K0 ('GetX') INTERRUPT 2u SET_SAVEDPC 3u CALL R2, 1i, 1i - FALLBACK_NAMECALL 3u, R3, R0, K1 + FALLBACK_NAMECALL 3u, R3, R0, K1 ('GetY') INTERRUPT 5u SET_SAVEDPC 6u CALL R3, 1i, 1i @@ -1167,7 +1284,7 @@ local function inl(v: vector, s: number) end local function getsum(x) - return inl(x, 2) + inl(x, 5) + return inl(x, 3) + inl(x, 5) end )", /* includeIrTypes */ true @@ -1195,7 +1312,7 @@ bb_bytecode_1: bb_bytecode_0: CHECK_TAG R0, tvector, exit(0) %2 = LOAD_FLOAT R0, 4i - %8 = MUL_NUM %2, 2 + %8 = MUL_NUM %2, 3 %13 = LOAD_FLOAT R0, 4i %19 = MUL_NUM %13, 5 %28 = ADD_NUM %8, %19 @@ -1248,8 +1365,8 @@ bb_bytecode_1: bb_4: CHECK_TAG R2, ttable, exit(1) %23 = LOAD_POINTER R2 - %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 - CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 ('pos') + CHECK_SLOT_MATCH %24, K0 ('pos'), bb_fallback_5 %26 = LOAD_TVALUE %24, 0i STORE_TVALUE R4, %26 JUMP bb_6 @@ -1357,13 +1474,13 @@ bb_bytecode_1: bb_4: CHECK_TAG R3, ttable, bb_fallback_5 %23 = LOAD_POINTER R3 - %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 - CHECK_SLOT_MATCH %24, K0, bb_fallback_5 + %24 = GET_SLOT_NODE_ADDR %23, 1u, K0 ('normal') + CHECK_SLOT_MATCH %24, K0 ('normal'), bb_fallback_5 %26 = LOAD_TVALUE %24, 0i STORE_TVALUE R2, %26 JUMP bb_6 bb_6: - %31 = LOAD_TVALUE K1, 0i, tvector + %31 = LOAD_TVALUE K1 (0.707000017, 0, 0.707000017), 0i, tvector STORE_TVALUE R4, %31 CHECK_TAG R2, tvector, exit(4) %37 = LOAD_FLOAT R2, 0i @@ -1484,9 +1601,9 @@ bb_bytecode_1: STORE_DOUBLE R1, 0 STORE_TAG R1, tnumber CHECK_SAFE_ENV exit(1) - JUMP_EQ_TAG K1, tnil, bb_fallback_6, bb_5 + JUMP_EQ_TAG K1 (nil), tnil, bb_fallback_6, bb_5 bb_5: - %9 = LOAD_TVALUE K1 + %9 = LOAD_TVALUE K1 (nil) STORE_TVALUE R2, %9 JUMP bb_7 bb_7: @@ -1508,8 +1625,8 @@ bb_9: bb_bytecode_2: CHECK_TAG R6, ttable, exit(6) %35 = LOAD_POINTER R6 - %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 - CHECK_SLOT_MATCH %36, K2, bb_fallback_10 + %36 = GET_SLOT_NODE_ADDR %35, 6u, K2 ('pos') + CHECK_SLOT_MATCH %36, K2 ('pos'), bb_fallback_10 %38 = LOAD_TVALUE %36, 0i STORE_TVALUE R8, %38 JUMP bb_11 @@ -1710,8 +1827,8 @@ bb_0: bb_2: JUMP bb_bytecode_1 bb_bytecode_1: - FALLBACK_GETTABLEKS 0u, R2, R0, K0 - FALLBACK_GETTABLEKS 2u, R3, R0, K1 + FALLBACK_GETTABLEKS 0u, R2, R0, K0 ('Row1') + FALLBACK_GETTABLEKS 2u, R3, R0, K1 ('Row2') CHECK_TAG R2, tvector, exit(4) CHECK_TAG R3, tvector, exit(4) %14 = LOAD_TVALUE R2 @@ -1994,4 +2111,103 @@ bb_bytecode_1: ); } +TEST_CASE("LibraryFieldTypesAndConstants") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vector) + return Vector3.xAxis * a + Vector3.yAxis +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vector [argument] +; R2: vector from 3 to 4 +; R3: vector from 1 to 2 +; R3: vector from 3 to 4 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0 (1, 0, 0), 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = MUL_VEC %4, %11 + %15 = LOAD_TVALUE K1 (0, 1, 0), 0i, tvector + %23 = ADD_VEC %12, %15 + %24 = TAG_VECTOR %23 + STORE_TVALUE R1, %24 + INTERRUPT 4u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("LibraryFieldTypesAndConstants") +{ + CHECK_EQ( + "\n" + getCodegenAssembly( + R"( +local function foo(a: vector) + local x = vector.zero + x += a + return x +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo($arg0) line 2 +; R0: vector [argument] +; R1: vector from 0 to 3 +bb_0: + CHECK_TAG R0, tvector, exit(entry) + JUMP bb_2 +bb_2: + JUMP bb_bytecode_1 +bb_bytecode_1: + %4 = LOAD_TVALUE K0 (0, 0, 0), 0i, tvector + %11 = LOAD_TVALUE R0 + %12 = ADD_VEC %4, %11 + %13 = TAG_VECTOR %12 + STORE_TVALUE R1, %13 + INTERRUPT 2u + RETURN R1, 1i +)" + ); +} + +TEST_CASE("LibraryFieldTypesAndConstantsCApi") +{ + CHECK_EQ( + "\n" + getCodegenAssemblyUsingCApi( + R"( +local function foo() + return test.some_nil, test.some_boolean, test.some_number, test.some_vector, test.some_string +end +)", + /* includeIrTypes */ true + ), + R"( +; function foo() line 2 +bb_bytecode_0: + STORE_TAG R0, tnil + STORE_INT R1, 1i + STORE_TAG R1, tboolean + STORE_DOUBLE R2, 4.75 + STORE_TAG R2, tnumber + %5 = LOAD_TVALUE K1 (1, 2, 4), 0i, tvector + STORE_TVALUE R3, %5 + %7 = LOAD_TVALUE K2 ('test'), 0i, tstring + STORE_TVALUE R4, %7 + INTERRUPT 5u + RETURN R0, 5i +)" + ); +} + TEST_SUITE_END(); diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index e0716e4c..6133305d 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LexerFixInterpStringStart) + TEST_SUITE_BEGIN("LexerTests"); TEST_CASE("broken_string_works") @@ -153,6 +155,8 @@ TEST_CASE("string_interpolation_basic") Lexeme interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); + // The InterpStringEnd should start with }, not `. + CHECK_EQ(interpEnd.location.begin.column, FFlag::LexerFixInterpStringStart ? 11 : 12); } TEST_CASE("string_interpolation_full") @@ -173,6 +177,7 @@ TEST_CASE("string_interpolation_full") Lexeme interpMid = lexer.next(); CHECK_EQ(interpMid.type, Lexeme::InterpStringMid); CHECK_EQ(interpMid.toString(), "} {"); + CHECK_EQ(interpMid.location.begin.column, FFlag::LexerFixInterpStringStart ? 11 : 12); Lexeme quote2 = lexer.next(); CHECK_EQ(quote2.type, Lexeme::QuotedString); @@ -181,6 +186,7 @@ TEST_CASE("string_interpolation_full") Lexeme interpEnd = lexer.next(); CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); CHECK_EQ(interpEnd.toString(), "} end`"); + CHECK_EQ(interpEnd.location.begin.column, FFlag::LexerFixInterpStringStart ? 19 : 20); } TEST_CASE("string_interpolation_double_brace") @@ -242,4 +248,185 @@ TEST_CASE("string_interpolation_with_unicode_escape") CHECK_EQ(lexer.next().type, Lexeme::Eof); } +TEST_CASE("single_quoted_string") +{ + const std::string testInput = "'test'"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::QuotedString); + CHECK_EQ(lexeme.getQuoteStyle(), Lexeme::QuoteStyle::Single); +} + +TEST_CASE("double_quoted_string") +{ + const std::string testInput = R"("test")"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::QuotedString); + CHECK_EQ(lexeme.getQuoteStyle(), Lexeme::QuoteStyle::Double); +} + +TEST_CASE("lexer_determines_string_block_depth_0") +{ + const std::string testInput = "[[ test ]]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_1") +{ + const std::string testInput = R"([[ test + ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_2") +{ + const std::string testInput = R"([[ + test + ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_0_multiline_3") +{ + const std::string testInput = R"([[ + test ]])"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_1") +{ + const std::string testInput = "[=[[%s]]=]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 1); +} + +TEST_CASE("lexer_determines_string_block_depth_2") +{ + const std::string testInput = "[==[ test ]==]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_1") +{ + const std::string testInput = R"([==[ test + ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_2") +{ + const std::string testInput = R"([==[ + test + ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + +TEST_CASE("lexer_determines_string_block_depth_2_multiline_3") +{ + const std::string testInput = R"([==[ + + test ]==])"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::RawString); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + + +TEST_CASE("lexer_determines_comment_block_depth_0") +{ + const std::string testInput = "--[[ test ]]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 0); +} + +TEST_CASE("lexer_determines_string_block_depth_1") +{ + const std::string testInput = "--[=[ μέλλον ]=]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 1); +} + +TEST_CASE("lexer_determines_string_block_depth_2") +{ + const std::string testInput = "--[==[ test ]==]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme lexeme = lexer.next(); + REQUIRE_EQ(lexeme.type, Lexeme::BlockComment); + CHECK_EQ(lexeme.getBlockDepth(), 2); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 8647777a..9162ccf3 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -8,7 +8,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauNativeAttribute); LUAU_FASTFLAG(LintRedundantNativeAttribute); using namespace Luau; @@ -1999,7 +1998,7 @@ local _ = a <= (b == 0) TEST_CASE_FIXTURE(Fixture, "RedundantNativeAttribute") { - ScopedFastFlag sff[] = {{FFlag::LuauNativeAttribute, true}, {FFlag::LintRedundantNativeAttribute, true}}; + ScopedFastFlag sff[] = {{FFlag::LintRedundantNativeAttribute, true}}; LintResult result = lint(R"( --!native diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 4519ba82..08b0bb0d 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -14,7 +14,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(DebugLuauFreezeArena); LUAU_FASTINT(LuauTypeCloneIterationLimit); - +LUAU_FASTFLAG(LuauOldSolverCreatesChildScopePointers) TEST_SUITE_BEGIN("ModuleTests"); TEST_CASE_FIXTURE(Fixture, "is_within_comment") @@ -110,9 +110,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") // breaks this test. I'm not sure if that behaviour change is important or // not, but it's tangental to the core purpose of this test. - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local Cyclic = {} @@ -283,7 +281,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_free_types") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); TypeArena arena; TypeId freeTy = freshType(NotNull{&arena}, builtinTypes, nullptr); @@ -542,4 +540,28 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "clone_a_bound_typepack_to_a_persistent_typep REQUIRE(res == follow(boundTo)); } +TEST_CASE_FIXTURE(Fixture, "old_solver_correctly_populates_child_scopes") +{ + ScopedFastFlag sff{FFlag::LuauOldSolverCreatesChildScopePointers, true}; + check(R"( +--!strict +if true then +end + +if false then +end + +if true then +else +end + +local x = {} +for i,v in x do +end +)"); + + auto& module = frontend.moduleResolver.getModule("MainModule"); + CHECK(module->getModuleScope()->children.size() == 7); +} + TEST_SUITE_END(); diff --git a/tests/NonStrictTypeChecker.test.cpp b/tests/NonStrictTypeChecker.test.cpp index ffb44049..61ebecf3 100644 --- a/tests/NonStrictTypeChecker.test.cpp +++ b/tests/NonStrictTypeChecker.test.cpp @@ -4,7 +4,9 @@ #include "Fixture.h" #include "Luau/Ast.h" +#include "Luau/BuiltinDefinitions.h" #include "Luau/Common.h" +#include "Luau/Error.h" #include "Luau/IostreamHelpers.h" #include "Luau/ModuleResolver.h" #include "Luau/VisitType.h" @@ -13,6 +15,9 @@ #include "doctest.h" #include +LUAU_FASTFLAG(LuauNewNonStrictWarnOnUnknownGlobals) +LUAU_FASTFLAG(LuauNonStrictVisitorImprovements) + using namespace Luau; #define NONSTRICT_REQUIRE_ERR_AT_POS(pos, result, idx) \ @@ -486,6 +491,40 @@ foo.bar("hi") NONSTRICT_REQUIRE_CHECKED_ERR(Position(1, 8), "foo.bar", result); } +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "exprgroup_is_checked") +{ + ScopedFastFlag sff{FFlag::LuauNonStrictVisitorImprovements, true}; + + CheckResult result = checkNonStrict(R"( + local foo = (abs("foo")) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto r1 = get(result.errors[0]); + LUAU_ASSERT(r1); + CHECK_EQ("abs", r1->checkedFunctionName); + CHECK_EQ("number", toString(r1->expected)); + CHECK_EQ("string", toString(r1->passed)); +} + +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "binop_is_checked") +{ + ScopedFastFlag sff{FFlag::LuauNonStrictVisitorImprovements, true}; + + CheckResult result = checkNonStrict(R"( + local foo = 4 + abs("foo") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto r1 = get(result.errors[0]); + LUAU_ASSERT(r1); + CHECK_EQ("abs", r1->checkedFunctionName); + CHECK_EQ("number", toString(r1->expected)); + CHECK_EQ("string", toString(r1->passed)); +} + TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "incorrect_arg_count") { CheckResult result = checkNonStrict(R"( @@ -576,4 +615,38 @@ buffer.readi8(b, 0) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(NonStrictTypeCheckerFixture, "nonstrict_method_calls") +{ + Luau::unfreeze(frontend.globals.globalTypes); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); + + registerBuiltinGlobals(frontend, frontend.globals); + registerTestTypes(); + + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); + + CheckResult result = checkNonStrict(R"( + local test = "test" + test:lower() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unknown_globals_in_non_strict") +{ + ScopedFastFlag flags[] = {{FFlag::LuauNonStrictVisitorImprovements, true}, {FFlag::LuauNewNonStrictWarnOnUnknownGlobals, true}}; + + CheckResult result = check(Mode::Nonstrict, R"( + foo = 5 + local wrong1 = foob + + local x = 12 + local wrong2 = x + foblm + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 3acd3909..7ff6ea37 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -16,7 +16,7 @@ TEST_SUITE_BEGIN("NonstrictModeTests"); TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict function foo(x, y) end @@ -39,7 +39,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict function getMinCardCountForWidth(width) @@ -103,7 +103,7 @@ TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types_are_ok") TEST_CASE_FIXTURE(Fixture, "locals_are_any_by_default") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict local m = 55 @@ -130,7 +130,7 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict local T = {} @@ -148,7 +148,7 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict local T = {} @@ -163,7 +163,7 @@ TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") TEST_CASE_FIXTURE(Fixture, "table_props_are_any") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict local T = {} @@ -185,7 +185,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict local T = { @@ -261,7 +261,7 @@ TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_retu TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index 23b4f133..0e026edf 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -11,8 +11,8 @@ #include "Luau/BuiltinDefinitions.h" LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauNormalizeNotUnknownIntersection) LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauFixNormalizedIntersectionOfNegatedClass) using namespace Luau; namespace @@ -27,7 +27,9 @@ struct IsSubtypeFixture : Fixture if (!module->hasModuleScope()) FAIL("isSubtype: module scope data is not available"); - return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + SimplifierPtr simplifier = newSimplifier(NotNull{&module->internalTypes}, builtinTypes); + + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, NotNull{simplifier.get()}, ice); } }; } // namespace @@ -849,17 +851,17 @@ TEST_CASE_FIXTURE(NormalizeFixture, "crazy_metatable") TEST_CASE_FIXTURE(NormalizeFixture, "negations_of_classes") { + ScopedFastFlag _{FFlag::LuauFixNormalizedIntersectionOfNegatedClass, true}; createSomeClasses(&frontend); CHECK("(Parent & ~Child) | Unrelated" == toString(normal("(Parent & Not) | Unrelated"))); CHECK("((class & ~Child) | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); - CHECK("Child" == toString(normal("Not & Child"))); + CHECK("never" == toString(normal("Not & Child"))); CHECK("((class & ~Parent) | Child | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not | Child"))); CHECK("(boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not"))); CHECK( "(Parent | Unrelated | boolean | buffer | function | number | string | table | thread)?" == toString(normal("Not & Not & Not>")) ); - CHECK("Child" == toString(normal("(Child | Unrelated) & Not"))); } @@ -960,7 +962,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "final_types_are_cached") TEST_CASE_FIXTURE(NormalizeFixture, "non_final_types_can_be_normalized_but_are_not_cached") { - TypeId a = arena.freshType(&globalScope); + TypeId a = arena.freshType(builtinTypes, &globalScope); std::shared_ptr na1 = normalizer.normalize(a); std::shared_ptr na2 = normalizer.normalize(a); @@ -970,8 +972,6 @@ TEST_CASE_FIXTURE(NormalizeFixture, "non_final_types_can_be_normalized_but_are_n TEST_CASE_FIXTURE(NormalizeFixture, "intersect_with_not_unknown") { - ScopedFastFlag sff{FFlag::LuauNormalizeNotUnknownIntersection, true}; - TypeId notUnknown = arena.addType(NegationType{builtinTypes->unknownType}); TypeId type = arena.addType(IntersectionType{{builtinTypes->numberType, notUnknown}}); std::shared_ptr normalized = normalizer.normalize(type); @@ -1029,4 +1029,109 @@ TEST_CASE_FIXTURE(NormalizeFixture, "truthy_table_property_and_optional_table_wi CHECK("{ x: number }" == toString(ty)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "normalizer_should_be_able_to_detect_cyclic_tables_and_not_stack_overflow") +{ + if (!FFlag::LuauSolverV2) + return; + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; + + CheckResult result = check(R"( +--!strict + +type Array = { [number] : T} +type Object = { [number] : any} + +type Set = typeof(setmetatable( + {} :: { + size: number, + -- method definitions + add: (self: Set, T) -> Set, + clear: (self: Set) -> (), + delete: (self: Set, T) -> boolean, + has: (self: Set, T) -> boolean, + ipairs: (self: Set) -> any, + }, + {} :: { + __index: Set, + __iter: (self: Set) -> (({ [K]: V }, K?) -> (K, V), T), + } +)) + +type Map = typeof(setmetatable( + {} :: { + size: number, + -- method definitions + set: (self: Map, K, V) -> Map, + get: (self: Map, K) -> V | nil, + clear: (self: Map) -> (), + delete: (self: Map, K) -> boolean, + [K]: V, + has: (self: Map, K) -> boolean, + keys: (self: Map) -> Array, + values: (self: Map) -> Array, + entries: (self: Map) -> Array>, + ipairs: (self: Map) -> any, + _map: { [K]: V }, + _array: { [number]: K }, + __index: (self: Map, key: K) -> V, + __iter: (self: Map) -> (({ [K]: V }, K?) -> (K?, V), V), + __newindex: (self: Map, key: K, value: V) -> (), + }, + {} :: { + __index: Map, + __iter: (self: Map) -> (({ [K]: V }, K?) -> (K, V), V), + __newindex: (self: Map, key: K, value: V) -> (), + } +)) +type mapFn = (element: T, index: number) -> U +type mapFnWithThisArg = (thisArg: any, element: T, index: number) -> U + +function fromSet( + value: Set, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + + local array : { [number] : string} = {"foo"} + return array +end + +function instanceof(tbl: any, class: any): boolean + return true +end + +function fromArray( + value: Array, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + local array : {[number] : string} = {} + return array +end + +return function( + value: string | Array | Set | Map, + mapFn: (mapFn | mapFnWithThisArg)?, + thisArg: Object? + -- FIXME Luau: need overloading so the return type on this is more sane and doesn't require manual casts +): Array | Array | Array + if value == nil then + error("cannot create array from a nil value") + end + local array: Array | Array | Array + + if instanceof(value, Set) then + array = fromSet(value :: Set, mapFn, thisArg) + else + array = {} + end + + + return array +end +)"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index dfcf0ded..de986cd4 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -3,6 +3,7 @@ #include "AstQueryDsl.h" #include "Fixture.h" +#include "Luau/Common.h" #include "ScopedFlags.h" #include "doctest.h" @@ -11,13 +12,18 @@ using namespace Luau; -LUAU_FASTFLAG(LuauLexerLookaheadRemembersBraceType); -LUAU_FASTINT(LuauRecursionLimit); -LUAU_FASTINT(LuauTypeLengthLimit); -LUAU_FASTINT(LuauParseErrorLimit); -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauAttributeSyntaxFunExpr); -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTINT(LuauTypeLengthLimit) +LUAU_FASTINT(LuauParseErrorLimit) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauAllowComplexTypesInGenericParams) +LUAU_FASTFLAG(LuauErrorRecoveryForTableTypes) +LUAU_FASTFLAG(LuauErrorRecoveryForClassNames) +LUAU_FASTFLAG(LuauFixFunctionNameStartPosition) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType) +LUAU_FASTFLAG(LuauAstTypeGroup2) +LUAU_FASTFLAG(LuauFixDoBlockEndLocation) namespace { @@ -37,24 +43,6 @@ struct Counter int Counter::instanceCount = 0; -// TODO: delete this and replace all other use of this function with matchParseError -std::string getParseError(const std::string& code) -{ - Fixture f; - - try - { - f.parse(code); - } - catch (const Luau::ParseErrors& e) - { - // in general, tests check only the first error - return e.getErrors().front().getMessage(); - } - - throw std::runtime_error("Expected a parse error in '" + code + "'"); -} - } // namespace TEST_SUITE_BEGIN("AllocatorTests"); @@ -384,7 +372,10 @@ TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_ AstTypeIntersection* returnAnnotation = annotation->returnTypes.types.data[0]->as(); REQUIRE(returnAnnotation != nullptr); - CHECK(returnAnnotation->types.data[0]->as()); + if (FFlag::LuauAstTypeGroup2) + CHECK(returnAnnotation->types.data[0]->as()); + else + CHECK(returnAnnotation->types.data[0]->as()); CHECK(returnAnnotation->types.data[1]->as()); } @@ -463,60 +454,60 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_span_is_correct") TEST_CASE_FIXTURE(Fixture, "parse_error_messages") { - CHECK_EQ( - getParseError(R"( - local a: (number, number) -> (string - )"), + matchParseError( + R"( + local a: (number, number) -> (string + )", "Expected ')' (to close '(' at line 2), got " ); - CHECK_EQ( - getParseError(R"( - local a: (number, number) -> ( - string - )"), + matchParseError( + R"( + local a: (number, number) -> ( + string + )", "Expected ')' (to close '(' at line 2), got " ); - CHECK_EQ( - getParseError(R"( - local a: (number, number) - )"), + matchParseError( + R"( + local a: (number, number) + )", "Expected '->' when parsing function type, got " ); - CHECK_EQ( - getParseError(R"( - local a: (number, number - )"), + matchParseError( + R"( + local a: (number, number + )", "Expected ')' (to close '(' at line 2), got " ); - CHECK_EQ( - getParseError(R"( - local a: {foo: string, - )"), + matchParseError( + R"( + local a: {foo: string, + )", "Expected identifier when parsing table field, got " ); - CHECK_EQ( - getParseError(R"( - local a: {foo: string - )"), + matchParseError( + R"( + local a: {foo: string + )", "Expected '}' (to close '{' at line 2), got " ); - CHECK_EQ( - getParseError(R"( - local a: { [string]: number, [number]: string } - )"), + matchParseError( + R"( + local a: { [string]: number, [number]: string } + )", "Cannot have more than one table indexer" ); - CHECK_EQ( - getParseError(R"( - type T = foo - )"), + matchParseError( + R"( + type T = foo + )", "Expected '(' when parsing function parameters, got 'foo'" ); } @@ -546,10 +537,10 @@ TEST_CASE_FIXTURE(Fixture, "cannot_write_multiple_values_in_type_groups") TEST_CASE_FIXTURE(Fixture, "type_alias_error_messages") { - CHECK_EQ(getParseError("type 5 = number"), "Expected identifier when parsing type name, got '5'"); - CHECK_EQ(getParseError("type A"), "Expected '=' when parsing type alias, got "); - CHECK_EQ(getParseError("type A<"), "Expected identifier, got "); - CHECK_EQ(getParseError("type A' (to close '<' at column 7), got "); + matchParseError("type 5 = number", "Expected identifier when parsing type name, got '5'"); + matchParseError("type A", "Expected '=' when parsing type alias, got "); + matchParseError("type A<", "Expected identifier, got "); + matchParseError("type A' (to close '<' at column 7), got "); } TEST_CASE_FIXTURE(Fixture, "type_assertion_expression") @@ -653,10 +644,10 @@ TEST_CASE_FIXTURE(Fixture, "vertical_space") TEST_CASE_FIXTURE(Fixture, "parse_error_type_name") { - CHECK_EQ( - getParseError(R"( - local a: Foo.= - )"), + matchParseError( + R"( + local a: Foo.= + )", "Expected identifier when parsing field name, got '='" ); } @@ -704,26 +695,26 @@ TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") { - CHECK_EQ(getParseError("return 0b123"), "Malformed number"); - CHECK_EQ(getParseError("return 123x"), "Malformed number"); - CHECK_EQ(getParseError("return 0xg"), "Malformed number"); - CHECK_EQ(getParseError("return 0x0x123"), "Malformed number"); - CHECK_EQ(getParseError("return 0xffffffffffffffffffffllllllg"), "Malformed number"); - CHECK_EQ(getParseError("return 0x0xffffffffffffffffffffffffffff"), "Malformed number"); + matchParseError("return 0b123", "Malformed number"); + matchParseError("return 123x", "Malformed number"); + matchParseError("return 0xg", "Malformed number"); + matchParseError("return 0x0x123", "Malformed number"); + matchParseError("return 0xffffffffffffffffffffllllllg", "Malformed number"); + matchParseError("return 0x0xffffffffffffffffffffffffffff", "Malformed number"); } TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") { - CHECK_EQ(getParseError("return 0 print(5)"), "Expected , got 'print'"); - CHECK_EQ(getParseError("while true do break print(5) end"), "Expected 'end' (to close 'do' at column 12), got 'print'"); + matchParseError("return 0 print(5)", "Expected , got 'print'"); + matchParseError("while true do break print(5) end", "Expected 'end' (to close 'do' at column 12), got 'print'"); } TEST_CASE_FIXTURE(Fixture, "error_on_unicode") { - CHECK_EQ( - getParseError(R"( + matchParseError( + R"( local ☃ = 10 - )"), + )", "Expected identifier when parsing variable name, got Unicode character U+2603" ); } @@ -736,10 +727,10 @@ TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") TEST_CASE_FIXTURE(Fixture, "error_on_confusable") { - CHECK_EQ( - getParseError(R"( - local pi = 3․13 - )"), + matchParseError( + R"( + local pi = 3․13 + )", "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)" ); } @@ -748,8 +739,8 @@ TEST_CASE_FIXTURE(Fixture, "error_on_non_utf8_sequence") { const char* expected = "Expected identifier when parsing expression, got invalid UTF-8 sequence"; - CHECK_EQ(getParseError("local pi = \xFF!"), expected); - CHECK_EQ(getParseError("local pi = \xE2!"), expected); + matchParseError("local pi = \xFF!", expected); + matchParseError("local pi = \xE2!", expected); } TEST_CASE_FIXTURE(Fixture, "lex_broken_unicode") @@ -817,7 +808,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_continue") TEST_CASE_FIXTURE(Fixture, "continue_not_last_error") { - CHECK_EQ(getParseError("while true do continue print(5) end"), "Expected 'end' (to close 'do' at column 12), got 'print'"); + matchParseError("while true do continue print(5) end", "Expected 'end' (to close 'do' at column 12), got 'print'"); } TEST_CASE_FIXTURE(Fixture, "parse_export_type") @@ -860,7 +851,7 @@ TEST_CASE_FIXTURE(Fixture, "export_is_an_identifier_only_when_followed_by_type") TEST_CASE_FIXTURE(Fixture, "incomplete_statement_error") { - CHECK_EQ(getParseError("fiddlesticks"), "Incomplete statement: expected assignment or a function call"); + matchParseError("fiddlesticks", "Incomplete statement: expected assignment or a function call"); } TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment") @@ -1191,9 +1182,7 @@ until false TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_local_function") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); try { @@ -1228,9 +1217,7 @@ end TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_failsafe_earlier") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); try { @@ -2138,6 +2125,20 @@ TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") matchParseError("declare class Foo function a(self, ...) end", "All declaration parameters aside from 'self' must be annotated"); } +TEST_CASE_FIXTURE(Fixture, "missing_declaration_prop") +{ + ScopedFastFlag luauErrorRecoveryForClassNames{FFlag::LuauErrorRecoveryForClassNames, true}; + + matchParseError( + R"( + declare class Foo + a: number, + end + )", + "Expected identifier when parsing property name, got ','" + ); +} + TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") { ParseResult result = parseEx(R"( @@ -2380,11 +2381,13 @@ TEST_CASE_FIXTURE(Fixture, "invalid_type_forms") TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; - AstStat* stat = parse(R"( type function foo() - return + return types.number + end + + export type function bar() + return types.string end )"); @@ -2394,6 +2397,152 @@ TEST_CASE_FIXTURE(Fixture, "parse_user_defined_type_functions") REQUIRE(f->name == "foo"); } +TEST_CASE_FIXTURE(Fixture, "parse_nested_type_function") +{ + AstStat* stat = parse(R"( + local v1 = 1 + type function foo() + local v2 = 2 + local function bar() + v2 += 1 + type function inner() end + v2 += 2 + end + local function bar2() + v2 += 3 + end + end + local function bar() v1 += 1 end + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "invalid_user_defined_type_functions") +{ + matchParseError("local foo = 1; type function bar() print(foo) end", "Type function cannot reference outer local 'foo'"); + matchParseError("type function foo() local v1 = 1; type function bar() print(v1) end end", "Type function cannot reference outer local 'v1'"); +} + +TEST_CASE_FIXTURE(Fixture, "leading_union_intersection_with_single_type_preserves_the_union_intersection_ast_node") +{ + ScopedFastFlag _{FFlag::LuauPreserveUnionIntersectionNodeForLeadingTokenSingleType, true}; + AstStatBlock* block = parse(R"( + type Foo = | string + type Bar = & number + )"); + + REQUIRE_EQ(2, block->body.size); + + const auto alias1 = block->body.data[0]->as(); + REQUIRE(alias1); + + const auto unionType = alias1->type->as(); + REQUIRE(unionType); + CHECK_EQ(1, unionType->types.size); + + const auto alias2 = block->body.data[1]->as(); + REQUIRE(alias2); + + const auto intersectionType = alias2->type->as(); + REQUIRE(intersectionType); + CHECK_EQ(1, intersectionType->types.size); +} + +TEST_CASE_FIXTURE(Fixture, "parse_simple_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup2, true}; + + AstStatBlock* stat = parse(R"( + type Foo = (string) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto group1 = alias1->type->as(); + REQUIRE(group1); + CHECK(group1->type->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_nested_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup2, true}; + + AstStatBlock* stat = parse(R"( + type Foo = ((string)) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto group1 = alias1->type->as(); + REQUIRE(group1); + + auto group2 = group1->type->as(); + REQUIRE(group2); + CHECK(group2->type->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_return_type_ast_type_group") +{ + ScopedFastFlag _{FFlag::LuauAstTypeGroup2, true}; + + AstStatBlock* stat = parse(R"( + type Foo = () -> (string) + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto alias1 = stat->body.data[0]->as(); + REQUIRE(alias1); + + auto funcType = alias1->type->as(); + REQUIRE(funcType); + + REQUIRE_EQ(1, funcType->returnTypes.types.size); + REQUIRE(!funcType->returnTypes.tailType); + CHECK(funcType->returnTypes.types.data[0]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "inner_and_outer_scope_of_functions_have_correct_end_position") +{ + + AstStatBlock* stat = parse(R"( + local function foo() + local x = 1 + end + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto func = stat->body.data[0]->as(); + REQUIRE(func); + CHECK_EQ(func->func->body->location, Location{{1, 28}, {3, 8}}); + CHECK_EQ(func->location, Location{{1, 8}, {3, 11}}); +} + +TEST_CASE_FIXTURE(Fixture, "do_block_end_location_is_after_end_token") +{ + ScopedFastFlag _{FFlag::LuauFixDoBlockEndLocation, true}; + + AstStatBlock* stat = parse(R"( + do + local x = 1 + end + )"); + REQUIRE(stat); + REQUIRE_EQ(1, stat->body.size); + + auto block = stat->body.data[0]->as(); + REQUIRE(block); + CHECK_EQ(block->location, Location{{1, 8}, {3, 11}}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2596,9 +2745,7 @@ TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") } }; - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); checkRecovery("function foo(a, b. c) return a + b end", "function foo(a, b) return a + b end", 1); checkRecovery( @@ -2840,9 +2987,7 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); try { @@ -3138,8 +3283,6 @@ TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") { - ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true}; - ParseResult result = tryParse(R"( local x = `{ {y} }` )"); @@ -3149,8 +3292,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2") { - ScopedFastFlag sff{FFlag::LuauLexerLookaheadRemembersBraceType, true}; - ParseResult result = tryParse(R"( local x = `{ { y{} } }` )"); @@ -3325,8 +3466,6 @@ end)"); TEST_CASE_FIXTURE(Fixture, "parse_attribute_for_function_expression") { - ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntaxFunExpr, true}}; - AstStatBlock* stat1 = parse(R"( local function invoker(f) return f(1) @@ -3495,8 +3634,6 @@ function foo1 () @checked return 'a' end TEST_CASE_FIXTURE(Fixture, "dont_parse_attribute_on_argument_non_function") { - ScopedFastFlag sff[] = {{FFlag::LuauAttributeSyntaxFunExpr, true}}; - ParseResult pr = tryParse(R"( local function invoker(f, y) return f(y) @@ -3653,5 +3790,141 @@ TEST_CASE_FIXTURE(Fixture, "mixed_leading_intersection_and_union_not_allowed") matchParseError("type A = | number & string & boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); } +TEST_CASE_FIXTURE(Fixture, "grouped_function_type") +{ + ScopedFastFlag _{FFlag::LuauAllowComplexTypesInGenericParams, true}; + const auto root = parse(R"( + type X = T + local x: X<(() -> ())?> + )"); + LUAU_ASSERT(root); + CHECK_EQ(root->body.size, 2); + auto assignment = root->body.data[1]->as(); + LUAU_ASSERT(assignment); + CHECK_EQ(assignment->vars.size, 1); + CHECK_EQ(assignment->values.size, 0); + auto binding = assignment->vars.data[0]; + CHECK_EQ(binding->name, "x"); + auto genericTy = binding->annotation->as(); + LUAU_ASSERT(genericTy); + CHECK_EQ(genericTy->parameters.size, 1); + auto paramTy = genericTy->parameters.data[0]; + LUAU_ASSERT(paramTy.type); + auto unionTy = paramTy.type->as(); + LUAU_ASSERT(unionTy); + CHECK_EQ(unionTy->types.size, 2); + if (FFlag::LuauAstTypeGroup2) + { + auto groupTy = unionTy->types.data[0]->as(); // (() -> ()) + REQUIRE(groupTy); + CHECK(groupTy->type->is()); // () -> () + } + else + CHECK(unionTy->types.data[0]->is()); // () -> () + CHECK(unionTy->types.data[1]->is()); // nil +} + +TEST_CASE_FIXTURE(Fixture, "complex_union_in_generic_ty") +{ + ScopedFastFlag _{FFlag::LuauAllowComplexTypesInGenericParams, true}; + const auto root = parse(R"( + type X = T + local x: X< + | number + | boolean + | string + > + )"); + LUAU_ASSERT(root); + CHECK_EQ(root->body.size, 2); + auto assignment = root->body.data[1]->as(); + LUAU_ASSERT(assignment); + CHECK_EQ(assignment->vars.size, 1); + CHECK_EQ(assignment->values.size, 0); + auto binding = assignment->vars.data[0]; + CHECK_EQ(binding->name, "x"); + auto genericTy = binding->annotation->as(); + LUAU_ASSERT(genericTy); + CHECK_EQ(genericTy->parameters.size, 1); + auto paramTy = genericTy->parameters.data[0]; + LUAU_ASSERT(paramTy.type); + auto unionTy = paramTy.type->as(); + LUAU_ASSERT(unionTy); + CHECK_EQ(unionTy->types.size, 3); + // NOTE: These are `const char*` so we can compare them to `AstName`s later. + std::vector expectedTypes{"number", "boolean", "string"}; + for (size_t i = 0; i < expectedTypes.size(); i++) + { + auto ty = unionTy->types.data[i]->as(); + LUAU_ASSERT(ty); + CHECK_EQ(ty->name, expectedTypes[i]); + } +} + +TEST_CASE_FIXTURE(Fixture, "recover_from_bad_table_type") +{ + ScopedFastFlag _{FFlag::LuauErrorRecoveryForTableTypes, true}; + ParseOptions opts; + opts.allowDeclarationSyntax = true; + const auto result = tryParse( + R"( + declare class Widget + state: {string: function(string, Widget)} + end + )", + opts + ); + CHECK_EQ(result.errors.size(), 2); +} + +TEST_CASE_FIXTURE(Fixture, "function_name_has_correct_start_location") +{ + ScopedFastFlag _{FFlag::LuauFixFunctionNameStartPosition, true}; + AstStatBlock* block = parse(R"( + function simple() + end + + function T:complex() + end + )"); + + REQUIRE_EQ(2, block->body.size); + + const auto function1 = block->body.data[0]->as(); + LUAU_ASSERT(function1); + CHECK_EQ(Position{1, 17}, function1->name->location.begin); + + const auto function2 = block->body.data[1]->as(); + LUAU_ASSERT(function2); + CHECK_EQ(Position{4, 17}, function2->name->location.begin); +} + +TEST_CASE_FIXTURE(Fixture, "stat_end_includes_semicolon_position") +{ + ScopedFastFlag _{FFlag::LuauExtendStatEndPosWithSemicolon, true}; + AstStatBlock* block = parse(R"( + local x = 1 + local y = 2; + local z = 3 ; + )"); + + REQUIRE_EQ(3, block->body.size); + + const auto stat1 = block->body.data[0]; + LUAU_ASSERT(stat1); + CHECK_FALSE(stat1->hasSemicolon); + CHECK_EQ(Position{1, 19}, stat1->location.end); + + const auto stat2 = block->body.data[1]; + LUAU_ASSERT(stat2); + CHECK(stat2->hasSemicolon); + CHECK_EQ(Position{2, 20}, stat2->location.end); + + const auto stat3 = block->body.data[2]; + LUAU_ASSERT(stat3); + CHECK(stat3->hasSemicolon); + CHECK_EQ(Position{3, 22}, stat3->location.end); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index a0de6f10..85d53390 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -2,7 +2,8 @@ #include "lua.h" #include "lualib.h" -#include "Repl.h" +#include "Luau/Repl.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -172,15 +173,17 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") CHECK(checkCompletion(completions, prefix, "myvariable1")); CHECK(checkCompletion(completions, prefix, "myvariable2")); } + { // Try completing some builtin functions CompletionSet completions = getCompletionSet("math.m"); std::string prefix = "math."; - CHECK(completions.size() == 3); + CHECK(completions.size() == 4); CHECK(checkCompletion(completions, prefix, "max(")); CHECK(checkCompletion(completions, prefix, "min(")); CHECK(checkCompletion(completions, prefix, "modf(")); + CHECK(checkCompletion(completions, prefix, "map(")); } } diff --git a/tests/RequireByString.test.cpp b/tests/RequireByString.test.cpp index f76f9faf..59a1af3b 100644 --- a/tests/RequireByString.test.cpp +++ b/tests/RequireByString.test.cpp @@ -1,17 +1,25 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include "Luau/Config.h" + #include "ScopedFlags.h" #include "lua.h" #include "lualib.h" -#include "Repl.h" -#include "FileUtils.h" +#include "Luau/Repl.h" +#include "Luau/FileUtils.h" #include "doctest.h" #include +#include #include #include +#include +#include +#include +#include +#include #if __APPLE__ #include @@ -112,7 +120,7 @@ public: for (int i = 0; i < 20; ++i) { bool engineTestDir = isDirectory(luauDirAbs + "/Client/Luau/tests"); - bool luauTestDir = isDirectory(luauDirAbs + "/luau/tests/require"); + bool luauTestDir = isDirectory(luauDirAbs + "/tests/require"); if (engineTestDir || luauTestDir) { @@ -121,12 +129,6 @@ public: luauDirRel += "/Client/Luau"; luauDirAbs += "/Client/Luau"; } - else - { - luauDirRel += "/luau"; - luauDirAbs += "/luau"; - } - if (type == PathType::Relative) return luauDirRel; @@ -217,21 +219,43 @@ TEST_CASE("PathResolution") std::string prefix = "/"; #endif - CHECK(resolvePath(prefix + "Users/modules/module.luau", "") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "./a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "/a/string/that/should/be/ignored") == prefix + "Users/modules/module.luau"); - CHECK(resolvePath(prefix + "Users/modules/module.luau", "/Users/modules") == prefix + "Users/modules/module.luau"); + // tuple format: {inputPath, inputBaseFilePath, expected} + std::vector> tests = { + // 1. Basic path resolution + // a. Relative to a relative path that begins with './' + {"./dep", "./src/modules/module.luau", "./src/modules/dep"}, + {"../dep", "./src/modules/module.luau", "./src/dep"}, + {"../../dep", "./src/modules/module.luau", "./dep"}, + {"../../", "./src/modules/module.luau", "./"}, - CHECK(resolvePath("../module", "") == "../module"); - CHECK(resolvePath("../../module", "") == "../../module"); - CHECK(resolvePath("../module/..", "") == ".."); - CHECK(resolvePath("../module/../..", "") == "../.."); + // b. Relative to a relative path that begins with '../' + {"./dep", "../src/modules/module.luau", "../src/modules/dep"}, + {"../dep", "../src/modules/module.luau", "../src/dep"}, + {"../../dep", "../src/modules/module.luau", "../dep"}, + {"../../", "../src/modules/module.luau", "../"}, - CHECK(resolvePath("../dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); - CHECK(resolvePath("../dependency/", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); - CHECK(resolvePath("../../../../../Users/dependency", prefix + "Users/modules/module.luau") == prefix + "Users/dependency"); - CHECK(resolvePath("../..", prefix + "Users/modules/module.luau") == prefix); + // c. Relative to an absolute path + {"./dep", prefix + "src/modules/module.luau", prefix + "src/modules/dep"}, + {"../dep", prefix + "src/modules/module.luau", prefix + "src/dep"}, + {"../../dep", prefix + "src/modules/module.luau", prefix + "dep"}, + {"../../", prefix + "src/modules/module.luau", prefix}, + + + // 2. Check behavior for extraneous ".." + // a. Relative paths retain '..' and append if needed + {"../../../", "./src/modules/module.luau", "../"}, + {"../../../", "../src/modules/module.luau", "../../"}, + + // b. Absolute paths ignore '..' if already at root + {"../../../", prefix + "src/modules/module.luau", prefix}, + }; + + for (const auto& [inputPath, inputBaseFilePath, expected] : tests) + { + std::optional resolved = resolvePath(inputPath, inputBaseFilePath); + CHECK(resolved); + CHECK_EQ(resolved, expected); + } } TEST_CASE("PathNormalization") @@ -242,34 +266,57 @@ TEST_CASE("PathNormalization") std::string prefix = "/"; #endif - // Relative path - std::optional result = normalizePath("../../modules/module"); - CHECK(result); - std::string normalized = *result; - std::vector variants = { - "./.././.././modules/./module/", "placeholder/../../../modules/module", "../placeholder/placeholder2/../../../modules/module" - }; - for (const std::string& variant : variants) - { - result = normalizePath(variant); - CHECK(result); - CHECK(normalized == *result); - } + // pair format: {input, expected} + std::vector> tests = { + // 1. Basic formatting checks + {"", "./"}, + {".", "./"}, + {"..", "../"}, + {"a/relative/path", "./a/relative/path"}, - // Absolute path - result = normalizePath(prefix + "Users/modules/module"); - CHECK(result); - normalized = *result; - variants = { - "Users/Users/Users/.././.././modules/./module/", - "placeholder/../Users/..//Users/modules/module", - "Users/../placeholder/placeholder2/../../Users/modules/module" + + // 2. Paths containing extraneous '.' and '/' symbols + {"./remove/extraneous/symbols/", "./remove/extraneous/symbols"}, + {"./remove/extraneous//symbols", "./remove/extraneous/symbols"}, + {"./remove/extraneous/symbols/.", "./remove/extraneous/symbols"}, + {"./remove/extraneous/./symbols", "./remove/extraneous/symbols"}, + + {"../remove/extraneous/symbols/", "../remove/extraneous/symbols"}, + {"../remove/extraneous//symbols", "../remove/extraneous/symbols"}, + {"../remove/extraneous/symbols/.", "../remove/extraneous/symbols"}, + {"../remove/extraneous/./symbols", "../remove/extraneous/symbols"}, + + {prefix + "remove/extraneous/symbols/", prefix + "remove/extraneous/symbols"}, + {prefix + "remove/extraneous//symbols", prefix + "remove/extraneous/symbols"}, + {prefix + "remove/extraneous/symbols/.", prefix + "remove/extraneous/symbols"}, + {prefix + "remove/extraneous/./symbols", prefix + "remove/extraneous/symbols"}, + + + // 3. Paths containing '..' + // a. '..' removes the erasable component before it + {"./remove/me/..", "./remove"}, + {"./remove/me/../", "./remove"}, + + {"../remove/me/..", "../remove"}, + {"../remove/me/../", "../remove"}, + + {prefix + "remove/me/..", prefix + "remove"}, + {prefix + "remove/me/../", prefix + "remove"}, + + // b. '..' stays if path is relative and component is non-erasable + {"./..", "../"}, + {"./../", "../"}, + + {"../..", "../../"}, + {"../../", "../../"}, + + // c. '..' disappears if path is absolute and component is non-erasable + {prefix + "..", prefix}, }; - for (const std::string& variant : variants) + + for (const auto& [input, expected] : tests) { - result = normalizePath(prefix + variant); - CHECK(result); - CHECK(normalized == *result); + CHECK_EQ(normalizePath(input), expected); } } @@ -308,6 +355,22 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireInitLua") assertOutputContainsAll({"true", "result from init.lua"}); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireWithFileAmbiguity") +{ + std::string ambiguousPath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/ambiguous_file_requirer"; + + runProtectedRequire(ambiguousPath); + assertOutputContainsAll({"false", "require path could not be resolved to a unique file"}); +} + +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireWithDirectoryAmbiguity") +{ + std::string ambiguousPath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/ambiguous_directory_requirer"; + + runProtectedRequire(ambiguousPath); + assertOutputContainsAll({"false", "require path could not be resolved to a unique file"}); +} + TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireLuau") { std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/module"; @@ -384,6 +447,13 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCacheAfterRequireInitLua") REQUIRE_FALSE_MESSAGE(lua_isnil(L, -1), "Cache did not contain module result"); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "CheckCachedResult") +{ + std::string relativePath = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/validate_cache"; + runProtectedRequire(relativePath); + assertOutputContainsAll({"true"}); +} + TEST_CASE_FIXTURE(ReplWithPathFixture, "LoadStringRelative") { runCode(L, "return pcall(function() return loadstring(\"require('a/relative/path')\")() end)"); @@ -401,25 +471,18 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireAbsolutePath") assertOutputContainsAll({"false", "cannot require an absolute path"}); } -TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayRelativePath") +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireUnprefixedPath") { - std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/requirer"; + std::string path = "an/unprefixed/path"; runProtectedRequire(path); - assertOutputContainsAll({"true", "result from library"}); + assertOutputContainsAll({"false", "require path must start with a valid prefix: ./, ../, or @"}); } -TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayExplicitlyRelativePath") +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithExtension") { - std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/fail_requirer"; + std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau"; runProtectedRequire(path); - assertOutputContainsAll({"false", "error requiring module"}); -} - -TEST_CASE_FIXTURE(ReplWithPathFixture, "PathsArrayFromParent") -{ - std::string path = getLuauDirectory(PathType::Relative) + "/tests/require/with_config/src/global_library_requirer"; - runProtectedRequire(path); - assertOutputContainsAll({"true", "result from global_library"}); + assertOutputContainsAll({"false", "error requiring module: consider removing the file extension"}); } TEST_CASE_FIXTURE(ReplWithPathFixture, "RequirePathWithAlias") @@ -475,4 +538,63 @@ TEST_CASE_FIXTURE(ReplWithPathFixture, "AliasHasIllegalFormat") assertOutputContainsAll({"false", " is not a valid alias"}); } +TEST_CASE_FIXTURE(ReplWithPathFixture, "RequireFromLuauBinary") +{ + char executable[] = "luau"; + std::vector paths = { + getLuauDirectory(PathType::Relative) + "/tests/require/without_config/dependency.luau", + getLuauDirectory(PathType::Absolute) + "/tests/require/without_config/dependency.luau" + }; + + for (const std::string& path : paths) + { + std::vector pathStr(path.size() + 1); + strncpy(pathStr.data(), path.c_str(), path.size()); + pathStr[path.size()] = '\0'; + + char* argv[2] = {executable, pathStr.data()}; + CHECK_EQ(replMain(2, argv), 0); + } +} + +TEST_CASE("ParseAliases") +{ + std::string configJson = R"({ + "aliases": { + "MyAlias": "/my/alias/path", + } +})"; + + Luau::Config config; + + Luau::ConfigOptions::AliasOptions aliasOptions; + aliasOptions.configLocation = "/default/location"; + aliasOptions.overwriteAliases = true; + + Luau::ConfigOptions options{false, aliasOptions}; + + std::optional error = Luau::parseConfig(configJson, config, options); + REQUIRE(!error); + + auto checkContents = [](Luau::Config& config) -> void + { + CHECK(config.aliases.size() == 1); + REQUIRE(config.aliases.contains("myalias")); + + Luau::Config::AliasInfo& aliasInfo = config.aliases["myalias"]; + CHECK(aliasInfo.value == "/my/alias/path"); + CHECK(aliasInfo.originalCase == "MyAlias"); + }; + + checkContents(config); + + // Ensure that copied Configs retain the same information + Luau::Config copyConstructedConfig = config; + checkContents(copyConstructedConfig); + + Luau::Config copyAssignedConfig; + copyAssignedConfig = config; + checkContents(copyAssignedConfig); +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index ba03f363..eac9f96b 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -6,6 +6,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauExtendedSimpleRequire) + using namespace Luau; namespace @@ -178,4 +180,59 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); } +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_group") +{ + ScopedFastFlag luauExtendedSimpleRequire{FFlag::LuauExtendedSimpleRequire, true}; + + AstStatBlock* block = parse(R"( + local R = (((game).Test)) + require(R) + )"); + REQUIRE_EQ(2, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + + CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_type_annotation") +{ + ScopedFastFlag luauExtendedSimpleRequire{FFlag::LuauExtendedSimpleRequire, true}; + + AstStatBlock* block = parse(R"( + local R = game.Test :: (typeof(game.Redirect)) + require(R) + )"); + REQUIRE_EQ(2, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + + CHECK_EQ("game/Redirect", result.exprs[local->values.data[0]].name); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_type_annotation_2") +{ + ScopedFastFlag luauExtendedSimpleRequire{FFlag::LuauExtendedSimpleRequire, true}; + + AstStatBlock* block = parse(R"( + local R = game.Test :: (typeof(game.Redirect)) + local N = R.Nested + require(N) + )"); + REQUIRE_EQ(3, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[1]->as(); + REQUIRE(local != nullptr); + + CHECK_EQ("game/Redirect/Nested", result.exprs[local->values.data[0]].name); +} + TEST_SUITE_END(); diff --git a/tests/RuntimeLimits.test.cpp b/tests/RuntimeLimits.test.cpp index b4acf138..dacbb43d 100644 --- a/tests/RuntimeLimits.test.cpp +++ b/tests/RuntimeLimits.test.cpp @@ -45,9 +45,7 @@ TEST_SUITE_BEGIN("RuntimeLimits"); TEST_CASE_FIXTURE(LimitFixture, "typescript_port_of_Result_type") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); constexpr const char* src = R"LUA( --!strict diff --git a/tests/SharedCodeAllocator.test.cpp b/tests/SharedCodeAllocator.test.cpp index bba8daad..13bf9f98 100644 --- a/tests/SharedCodeAllocator.test.cpp +++ b/tests/SharedCodeAllocator.test.cpp @@ -175,6 +175,8 @@ TEST_CASE("NativeModuleRefRefcounting") REQUIRE(modRefA->getRefcount() == 1); REQUIRE(modRefB->getRefcount() == 1); +#if defined(__linux__) && defined(__GNUC__) +#else // NativeModuleRef self move assignment: { NativeModuleRef modRef1{modRefA}; @@ -183,6 +185,8 @@ TEST_CASE("NativeModuleRefRefcounting") REQUIRE(modRefA->getRefcount() == 2); } +#endif + REQUIRE(modRefA->getRefcount() == 1); REQUIRE(modRefB->getRefcount() == 1); diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp index a59312ac..76efc835 100644 --- a/tests/Subtyping.test.cpp +++ b/tests/Subtyping.test.cpp @@ -65,7 +65,10 @@ struct SubtypeFixture : Fixture TypeArena arena; InternalErrorReporter iceReporter; UnifierSharedState sharedState{&ice}; + SimplifierPtr simplifier = newSimplifier(NotNull{&arena}, builtinTypes); Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + TypeCheckLimits limits; + TypeFunctionRuntime typeFunctionRuntime{NotNull{&iceReporter}, NotNull{&limits}}; ScopedFastFlag sff{FFlag::LuauSolverV2, true}; @@ -77,7 +80,9 @@ struct SubtypeFixture : Fixture Subtyping mkSubtyping() { - return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}}; + return Subtyping{ + builtinTypes, NotNull{&arena}, NotNull{simplifier.get()}, NotNull{&normalizer}, NotNull{&typeFunctionRuntime}, NotNull{&iceReporter} + }; } TypePackId pack(std::initializer_list tys) diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index fd72579b..11027e6f 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -322,12 +322,9 @@ n3 [label="TableType 3"]; TEST_CASE_FIXTURE(Fixture, "free") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; - - Type type{TypeVariant{FreeType{TypeLevel{0, 0}}}}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + Type type{TypeVariant{FreeType{TypeLevel{0, 0}, builtinTypes->neverType, builtinTypes->unknownType}}}; ToDotOptions opts; opts.showPointers = false; CHECK_EQ( @@ -433,7 +430,7 @@ n1 [label="FreeTypePack 1"]; TEST_CASE_FIXTURE(Fixture, "error_pack") { - TypePackVar pack{TypePackVariant{Unifiable::Error{}}}; + TypePackVar pack{TypePackVariant{ErrorTypePack{}}}; ToDotOptions opts; opts.showPointers = false; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 59d9b5fd..536a4081 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -13,7 +13,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(LuauSolverV2); LUAU_FASTFLAG(LuauAttributeSyntax); -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) TEST_SUITE_BEGIN("ToString"); @@ -45,7 +44,7 @@ TEST_CASE_FIXTURE(Fixture, "bound_types") TEST_CASE_FIXTURE(Fixture, "free_types") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check("local a"); LUAU_REQUIRE_NO_ERRORS(result); @@ -166,7 +165,7 @@ TEST_CASE_FIXTURE(Fixture, "named_metatable") TEST_CASE_FIXTURE(BuiltinsFixture, "named_metatable_toStringNamedFunction") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function createTbl(): NamedMetatable @@ -212,8 +211,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") CHECK( "t2 where " "t1 = { __index: t1, __mul: ((t2, number) -> t2) & ((t2, t2) -> t2), new: () -> t2 } ; " - "t2 = { @metatable t1, { x: number, y: number, z: number } }" == - a + "t2 = { @metatable t1, { x: number, y: number, z: number } }" == a ); } else @@ -594,7 +592,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function target(callback: nil) return callback(4, "hello") end @@ -964,12 +962,12 @@ TEST_CASE_FIXTURE(Fixture, "correct_stringification_user_defined_type_functions" std::vector{builtinTypes->numberType}, // Type Function Arguments {}, {AstName{"woohoo"}}, // Type Function Name - std::nullopt + {}, }; Type tv{tftt}; - if (FFlag::LuauSolverV2 && FFlag::LuauUserDefinedTypeFunctions) + if (FFlag::LuauSolverV2) CHECK_EQ(toString(&tv, {}), "woohoo"); } diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index edc2bf47..f179876c 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -12,7 +12,10 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions); +LUAU_FASTFLAG(LuauStoreCSTData) +LUAU_FASTFLAG(LuauExtendStatEndPosWithSemicolon) +LUAU_FASTFLAG(LuauAstTypeGroup2); +LUAU_FASTFLAG(LexerFixInterpStringStart) TEST_SUITE_BEGIN("TranspilerTests"); @@ -44,6 +47,37 @@ TEST_CASE("string_literals_containing_utf8") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("if_stmt_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( if This then Once() end)"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( if This then Once() end)"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( if This then Once() end)"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( if This then Once() end)"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( if This then Once() else Other() end)"; + CHECK_EQ(five, transpile(five).code); + + const std::string six = R"( if This then Once() else Other() end)"; + CHECK_EQ(six, transpile(six).code); + + const std::string seven = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(seven, transpile(seven).code); + + const std::string eight = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(eight, transpile(eight).code); + + const std::string nine = R"( if This then Once() elseif true then Other() end)"; + CHECK_EQ(nine, transpile(nine).code); +} + TEST_CASE("elseif_chains_indent_sensibly") { const std::string code = R"( @@ -64,17 +98,31 @@ TEST_CASE("elseif_chains_indent_sensibly") TEST_CASE("strips_type_annotations") { const std::string code = R"( local s: string= 'hello there' )"; - const std::string expected = R"( local s = 'hello there' )"; - - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + const std::string expected = R"( local s = 'hello there' )"; + CHECK_EQ(expected, transpile(code).code); + } + else + { + const std::string expected = R"( local s = 'hello there' )"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("strips_type_assertion_expressions") { const std::string code = R"( local s= some_function() :: any+ something_else() :: number )"; - const std::string expected = R"( local s= some_function() + something_else() )"; - - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + const std::string expected = R"( local s= some_function() + something_else() )"; + CHECK_EQ(expected, transpile(code).code); + } + else + { + const std::string expected = R"( local s= some_function() + something_else() )"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("function_taking_ellipsis") @@ -99,24 +147,89 @@ TEST_CASE("for_loop") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("for_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( for index = 1 , 10 do call(index) end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( for index = 1, 10 , 3 do call(index) end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( for index = 1, 10 do call(index) end )"; + CHECK_EQ(five, transpile(five).code); +} + TEST_CASE("for_in_loop") { const std::string code = R"( for k, v in ipairs(x)do end )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("for_in_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( for k , v in ipairs(x) do end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( for k, v in next , t do end )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( for k, v in ipairs(x) do end )"; + CHECK_EQ(five, transpile(five).code); +} + TEST_CASE("while_loop") { const std::string code = R"( while f(x)do print() end )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("while_loop_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( while f(x) do print() end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( while f(x) do print() end )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( while f(x) do print() end )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( while f(x) do print() end )"; + CHECK_EQ(four, transpile(four).code); +} + TEST_CASE("repeat_until_loop") { const std::string code = R"( repeat print() until f(x) )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("repeat_until_loop_condition_on_new_line") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + repeat + print() + until + f(x) )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("lambda") { const std::string one = R"( local p=function(o, m, g) return 77 end )"; @@ -126,6 +239,43 @@ TEST_CASE("lambda") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("local_assignment") +{ + const std::string one = R"( local x = 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local x, y, z = 1, 2, 3 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( local x )"; + CHECK_EQ(three, transpile(three).code); +} + +TEST_CASE("local_assignment_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( local x = 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local x = 1 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( local x = 1 )"; + CHECK_EQ(three, transpile(three).code); + + const std::string four = R"( local x , y = 1, 2 )"; + CHECK_EQ(four, transpile(four).code); + + const std::string five = R"( local x, y = 1, 2 )"; + CHECK_EQ(five, transpile(five).code); + + const std::string six = R"( local x, y = 1 , 2 )"; + CHECK_EQ(six, transpile(six).code); + + const std::string seven = R"( local x, y = 1, 2 )"; + CHECK_EQ(seven, transpile(seven).code); +} + TEST_CASE("local_function") { const std::string one = R"( local function p(o, m, g) return 77 end )"; @@ -135,6 +285,16 @@ TEST_CASE("local_function") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("local_function_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( local function p(o, m, ...) end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local function p(o, m, ...) end )"; + CHECK_EQ(two, transpile(two).code); +} + TEST_CASE("function") { const std::string one = R"( function p(o, m, g) return 77 end )"; @@ -144,6 +304,87 @@ TEST_CASE("function") CHECK_EQ(two, transpile(two).code); } +TEST_CASE("returns_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string one = R"( return 1 )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( return 1 , 2 )"; + CHECK_EQ(two, transpile(two).code); + + const std::string three = R"( return 1, 2 )"; + CHECK_EQ(three, transpile(three).code); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( export type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( export type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo< X, Y, Z...> = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_with_defaults_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = string )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE("table_literals") { const std::string code = R"( local t={1, 2, 3, foo='bar', baz=99,[5.5]='five point five', 'end'} )"; @@ -186,6 +427,59 @@ TEST_CASE("table_literal_closing_brace_at_correct_position") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("table_literal_with_semicolon_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1; y = 2 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_trailing_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1, y = 2, } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_spaces_around_separator") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1 , y = 2 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_spaces_around_equals") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { x = 1 } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_multiline_with_indexers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + local t = { + ["my first value"] = "x"; + ["my second value"] = "y"; + } + )"; + + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("method_calls") { const std::string code = R"( foo.bar.baz:quux() )"; @@ -203,8 +497,15 @@ TEST_CASE("spaces_between_keywords_even_if_it_pushes_the_line_estimation_off") // Luau::Parser doesn't exactly preserve the string representation of numbers in Lua, so we can find ourselves // falling out of sync with the original code. We need to push keywords out so that there's at least one space between them. const std::string code = R"( if math.abs(raySlope) < .01 then return 0 end )"; - const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE("numbers") @@ -216,8 +517,70 @@ TEST_CASE("numbers") TEST_CASE("infinity") { const std::string code = R"( local a = 1e500 local b = 1e400 )"; - const std::string expected = R"( local a = 1e500 local b = 1e500 )"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + const std::string expected = R"( local a = 1e500 local b = 1e500 )"; + CHECK_EQ(expected, transpile(code).code); + } +} + +TEST_CASE("numbers_with_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 123_456_789 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("hexadecimal_numbers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 0xFFFF )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("binary_numbers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = 0b0101 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("single_quoted_strings") +{ + const std::string code = R"( local a = 'hello world' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("double_quoted_strings") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = "hello world" )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("simple_interp_string") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = `hello world` )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("raw_strings") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = [[ hello world ]] )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("raw_strings_with_blocks") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local a = [==[ hello world ]==] )"; + CHECK_EQ(code, transpile(code).code); } TEST_CASE("escaped_strings") @@ -232,6 +595,33 @@ TEST_CASE("escaped_strings_2") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_newline") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + print("foo \ + bar") + )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("escaped_strings_raw") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( local x = [=[\v<((do|load)file|require)\s*\(?['"]\zs[^'"]+\ze['"]]=] )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("position_correctly_updated_when_writing_multiline_string") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + call([[ + testing + ]]) )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; @@ -244,6 +634,86 @@ TEST_CASE("binary_keywords") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("function_call_parentheses_no_args") +{ + const std::string code = R"( call() )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_one_arg") +{ + const std::string code = R"( call(arg) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args") +{ + const std::string code = R"( call(arg1, arg3, arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args_no_space") +{ + const std::string code = R"( call(arg1,arg3,arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_parentheses_multiple_args_space_before_commas") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call(arg1 ,arg3 ,arg3) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_spaces_before_parentheses") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call () )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_spaces_within_parentheses") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call( ) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_double_quotes") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call "string" )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_single_quotes") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call 'string' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_string_no_space") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call'string' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_table_literal") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call { x = 1 } )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("function_call_table_literal_no_space") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( call{x=1} )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -260,6 +730,19 @@ TEST_CASE("do_blocks") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("nested_do_block") +{ + const std::string code = R"( + do + do + local x = 1 + end + end + )"; + + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") { const std::string code = R"( @@ -269,6 +752,106 @@ TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") CHECK_EQ(code, transpile(code).code); } +TEST_CASE_FIXTURE(Fixture, "parentheses_multiline") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( +local test = ( + x +) + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( local test = 1; )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local test = 1 ; )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "do_block_ending_with_semicolon") +{ + std::string code = R"( + do + return; + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + if init then + x = string.sub(x, utf8.offset(x, init)); + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_stmt_semicolon_2") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + if (t < 1) then return c/2*t*t + b end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_stmt_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + for i,v in ... do + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "while_do_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + while true do + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "function_definition_semicolon") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauExtendStatEndPosWithSemicolon, true}, + }; + std::string code = R"( + function foo() + end; + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE("roundtrip_types") { const std::string code = R"( @@ -339,9 +922,16 @@ TEST_CASE("a_table_key_can_be_the_empty_string") TEST_CASE("always_emit_a_space_after_local_keyword") { std::string code = "do local aZZZZ = Workspace.P1.Shape local bZZZZ = Enum.PartType.Cylinder end"; - std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; - CHECK_EQ(expected, transpile(code).code); + if (FFlag::LuauStoreCSTData) + { + CHECK_EQ(code, transpile(code).code); + } + else + { + std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; + CHECK_EQ(expected, transpile(code).code); + } } TEST_CASE_FIXTURE(Fixture, "types_should_not_be_considered_cyclic_if_they_are_not_recursive") @@ -422,6 +1012,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "type_assertion_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = 5 :: number"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = "local a = 5 :: number"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") { std::string code = "local a = if 1 then 2 else 3"; @@ -429,6 +1029,80 @@ TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") CHECK_EQ(code, transpile(code).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else_multiple_conditions") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else_multiple_conditions_2") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + local x = if yes + then nil + else if no + then if this + then that + else other + else nil + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 else 3"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); + + code = "local a = if 1 then 2 elseif 3 then 4 else 5"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "if_then_else_spaces_between_else_if") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + return + if a then "was a" else + if b then "was b" else + if c then "was c" else + "was nothing!" + )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") { fileResolver.source["game/A"] = R"( @@ -444,6 +1118,34 @@ local a: Import.Type CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( local _: Foo.Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Foo .Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Foo. Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type <> )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type< > )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type< number> )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( local _: Type )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") { std::string code = R"( @@ -473,7 +1175,10 @@ TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") { std::string code = "local a: nil | (string & number)"; - CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); + if (FFlag::LuauAstTypeGroup2) + CHECK_EQ("local a: (string & number)?", transpile(code, {}, true).code); + else + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); } TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") @@ -497,6 +1202,26 @@ TEST_CASE_FIXTURE(Fixture, "transpile_varargs") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "index_name_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "local _ = a.name"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = "local _ = a .name"; + CHECK_EQ(two, transpile(two, {}, true).code); + + std::string three = "local _ = a. name"; + CHECK_EQ(three, transpile(three, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "index_name_ends_with_digit") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = "sparkles.Color = Color3.new()"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") { std::string code = "local a = {1, 2, 3} local b = a[2]"; @@ -504,6 +1229,22 @@ TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "index_expr_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "local _ = a[2]"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = "local _ = a [2]"; + CHECK_EQ(two, transpile(two, {}, true).code); + + std::string three = "local _ = a[ 2]"; + CHECK_EQ(three, transpile(three, {}, true).code); + + std::string four = "local _ = a[2 ]"; + CHECK_EQ(four, transpile(four, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_unary") { std::string code = R"( @@ -518,6 +1259,32 @@ local d = #e CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "unary_spaces_around_tokens") +{ + std::string code = R"( +local _ = -1 +local _ = - 1 +local _ = not true +local _ = not true +local _ = #e +local _ = # e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "binary_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( +local _ = 1+1 +local _ = 1 +1 +local _ = 1+ 1 + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") { std::string code = R"( @@ -548,6 +1315,16 @@ a ..= ' - result' CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "compound_assignment_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = R"( a += 1 )"; + CHECK_EQ(one, transpile(one, {}, true).code); + + std::string two = R"( a += 1 )"; + CHECK_EQ(two, transpile(two, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") { std::string code = "a, b, c = 1, 2, 3"; @@ -555,6 +1332,31 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_assign_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string one = "a = 1"; + CHECK_EQ(one, transpile(one).code); + + std::string two = "a = 1"; + CHECK_EQ(two, transpile(two).code); + + std::string three = "a = 1"; + CHECK_EQ(three, transpile(three).code); + + std::string four = "a , b = 1, 2"; + CHECK_EQ(four, transpile(four).code); + + std::string five = "a, b = 1, 2"; + CHECK_EQ(five, transpile(five).code); + + std::string six = "a, b = 1 , 2"; + CHECK_EQ(six, transpile(six).code); + + std::string seven = "a, b = 1, 2"; + CHECK_EQ(seven, transpile(seven).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") { std::string code = R"( @@ -684,13 +1486,58 @@ TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") TEST_CASE_FIXTURE(Fixture, "transpile_string_interp") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = `hello {name}` )"; CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline") +{ + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; + std::string code = R"( local _ = `hello { + name + }!` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_on_new_line") +{ + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; + std::string code = R"( + error( + `a {b} c` + ) + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp_multiline_escape") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( local _ = `hello \ + world!` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") { + ScopedFastFlag fflags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LexerFixInterpStringStart, true}, + }; std::string code = R"( local _ = ` bracket = \{, backtick = \` = {'ok'} ` )"; CHECK_EQ(code, transpile(code, {}, true).code); @@ -698,11 +1545,211 @@ TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") TEST_CASE_FIXTURE(Fixture, "transpile_type_functions") { - ScopedFastFlag sff{FFlag::LuauUserDefinedTypeFunctions, true}; - std::string code = R"( type function foo(arg1, arg2) if arg1 == arg2 then return arg1 end return arg2 end )"; CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_typeof_spaces_around_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type X = typeof(x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof(x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof (x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof( x) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type X = typeof(x ) )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_single_quoted_string_types") +{ + const std::string code = R"( type a = 'hello world' )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_double_quoted_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( type a = "hello world" )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_raw_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type a = [[ hello world ]] )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type a = [==[ hello world ]==] )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_escaped_string_types") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( type a = "\\b\\t\\n\\\\" )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_semicolon_separators") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + const std::string code = R"( + type Foo = { + bar: number; + baz: number; + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_access_modifiers") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + read bar: number, + write baz: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { read string } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { + read [string]: number, + read ["property"]: number + } )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_spaces_between_tokens") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar : number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number , } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number, } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { bar: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [ string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string ]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string] : number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = { [string]: number } )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_original_indexer_style") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + [number]: string + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { { number } } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_indexer_location") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + [number]: string, + property: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { + property: number, + [number]: string, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( + type Foo = { + property: number, + [number]: string, + property2: number, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_type_table_preserve_property_definition_style") +{ + ScopedFastFlag _{FFlag::LuauStoreCSTData, true}; + std::string code = R"( + type Foo = { + ["$$typeof1"]: string, + ['$$typeof2']: string, + } + )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE("transpile_types_preserve_parentheses_style") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauStoreCSTData, true}, + {FFlag::LuauAstTypeGroup2, true}, + }; + + std::string code = R"( type Foo = number )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = (number) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = ((number)) )"; + CHECK_EQ(code, transpile(code, {}, true).code); + + code = R"( type Foo = ( (number) ) )"; + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TxnLog.test.cpp b/tests/TxnLog.test.cpp index b4b18353..29f7c0f3 100644 --- a/tests/TxnLog.test.cpp +++ b/tests/TxnLog.test.cpp @@ -16,8 +16,8 @@ LUAU_FASTFLAG(LuauSolverV2) struct TxnLogFixture { - TxnLog log{/*useScopes*/ true}; - TxnLog log2{/*useScopes*/ true}; + TxnLog log; + TxnLog log2; TypeArena arena; BuiltinTypes builtinTypes; @@ -33,39 +33,6 @@ struct TxnLogFixture TEST_SUITE_BEGIN("TxnLog"); -TEST_CASE_FIXTURE(TxnLogFixture, "colliding_union_incoming_type_has_greater_scope") -{ - ScopedFastFlag sff{FFlag::LuauSolverV2, true}; - - log.replace(c, BoundType{a}); - log2.replace(a, BoundType{c}); - - CHECK(nullptr != log.pending(c)); - - log.concatAsUnion(std::move(log2), NotNull{&arena}); - - // 'a has greater scope than 'c, so we expect the incoming binding of 'a to - // be discarded. - - CHECK(nullptr == log.pending(a)); - - const PendingType* pt = log.pending(c); - REQUIRE(pt != nullptr); - - CHECK(!pt->dead); - const BoundType* bt = get_if(&pt->pending.ty); - - CHECK(a == bt->boundTo); - - log.commit(); - - REQUIRE(get(a)); - - const BoundType* bound = get(c); - REQUIRE(bound); - CHECK(a == bound->boundTo); -} - TEST_CASE_FIXTURE(TxnLogFixture, "colliding_union_incoming_type_has_lesser_scope") { ScopedFastFlag sff{FFlag::LuauSolverV2, true}; diff --git a/tests/TypeFunction.test.cpp b/tests/TypeFunction.test.cpp index 1398a0eb..adc7a091 100644 --- a/tests/TypeFunction.test.cpp +++ b/tests/TypeFunction.test.cpp @@ -13,15 +13,15 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) LUAU_DYNAMIC_FASTINT(LuauTypeFamilyApplicationCartesianProductLimit) +LUAU_FASTFLAG(LuauMetatableTypeFunctions) struct TypeFunctionFixture : Fixture { TypeFunction swapFunction; TypeFunctionFixture() - : Fixture(true, false) + : Fixture(false) { swapFunction = TypeFunction{ /* name */ "Swap", @@ -34,20 +34,20 @@ struct TypeFunctionFixture : Fixture if (isString(param)) { - return TypeFunctionReductionResult{ctx->builtins->numberType, false, {}, {}}; + return TypeFunctionReductionResult{ctx->builtins->numberType, Reduction::MaybeOk, {}, {}}; } else if (isNumber(param)) { - return TypeFunctionReductionResult{ctx->builtins->stringType, false, {}, {}}; + return TypeFunctionReductionResult{ctx->builtins->stringType, Reduction::MaybeOk, {}, {}}; } else if (is(param) || is(param) || is(param) || (ctx->solver && ctx->solver->hasUnresolvedConstraints(param))) { - return TypeFunctionReductionResult{std::nullopt, false, {param}, {}}; + return TypeFunctionReductionResult{std::nullopt, Reduction::MaybeOk, {param}, {}}; } else { - return TypeFunctionReductionResult{std::nullopt, true, {}, {}}; + return TypeFunctionReductionResult{std::nullopt, Reduction::Erroneous, {}, {}}; } } }; @@ -939,14 +939,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "index_wait_for_pending_no_crash") Exp = 0, MaxExp = 100 } - type Keys = index> - -- This function makes it think that there's going to be a pending expansion local function UpdateData(key: Keys, value) PlayerData[key] = value end - UpdateData("Coins", 2) )"); @@ -1280,18 +1277,211 @@ TEST_CASE_FIXTURE(ClassFixture, "rawget_type_function_errors_w_classes") CHECK(toString(result.errors[0]) == "Property '\"BaseField\"' does not exist on type 'BaseClass'"); } -TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") +TEST_CASE_FIXTURE(Fixture, "fuzz_len_type_function_follow") { - if (!FFlag::LuauUserDefinedTypeFunctions) + // Should not fail assertions + check(R"( + local _ + _ = true + for l0=_,_,# _ do + end + for l0=_,_ do + if _ then + _ += _ + end + end + )"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_assigns_correct_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) return; CheckResult result = check(R"( - type function foo() - return nil - end + type Identity = setmetatable<{}, { __index: {} }> )"); - LUAU_CHECK_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == "This syntax is not supported"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __index: { } }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __index: { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_assigns_correct_metatable_2") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __index: {} }> + type FooBar = setmetatable<{}, Identity> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __index: { } }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __index: { } }"); + + TypeId foobar = requireTypeAlias("FooBar"); + const MetatableType* mt2 = get(foobar); + REQUIRE(mt2); + CHECK_EQ(toString(mt2->metatable, {true}), "{ @metatable { __index: { } }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_metatable_with_metatable_metamethod") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, { __metatable: "blocked" }> + type Bad = setmetatable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeId id = requireTypeAlias("Identity"); + CHECK_EQ(toString(id, {true}), "{ @metatable { __metatable: \"blocked\" }, { } }"); + const MetatableType* mt = get(id); + REQUIRE(mt); + CHECK_EQ(toString(mt->metatable), "{ __metatable: \"blocked\" }"); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_invalid_set") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_type_function_errors_on_nontable_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, string> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_type_function_returns_nil_if_no_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type TableWithNoMetatable = getmetatable<{}> + type NumberWithNoMetatable = getmetatable + type BooleanWithNoMetatable = getmetatable + type BooleanLiteralWithNoMetatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tableResult = requireTypeAlias("TableWithNoMetatable"); + CHECK_EQ(toString(tableResult), "nil"); + + auto numberResult = requireTypeAlias("NumberWithNoMetatable"); + CHECK_EQ(toString(numberResult), "nil"); + + auto booleanResult = requireTypeAlias("BooleanWithNoMetatable"); + CHECK_EQ(toString(booleanResult), "nil"); + + auto booleanLiteralResult = requireTypeAlias("BooleanLiteralWithNoMetatable"); + CHECK_EQ(toString(booleanLiteralResult), "nil"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + local metatable = { __index = { w = 4 } } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type Metatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), "{ __index: { w: number } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable_for_union") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Identity = setmetatable<{}, {}> + type Metatable = getmetatable + type IntersectMetatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const PrimitiveType* stringType = get(builtinTypes->stringType); + REQUIRE(stringType->metatable); + + TypeArena arena = TypeArena{}; + + std::string expected1 = toString(arena.addType(UnionType{{*stringType->metatable, builtinTypes->emptyTableType}}), {true}); + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), expected1); + + std::string expected2 = toString(arena.addType(IntersectionType{{*stringType->metatable, builtinTypes->emptyTableType}}), {true}); + CHECK_EQ(toString(requireTypeAlias("IntersectMetatable"), {true}), expected2); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_correct_metatable_for_string") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + type Metatable = getmetatable + type Metatable2 = getmetatable<"foo"> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const PrimitiveType* stringType = get(builtinTypes->stringType); + REQUIRE(stringType->metatable); + + std::string expected = toString(*stringType->metatable, {true}); + + CHECK_EQ(toString(requireTypeAlias("Metatable"), {true}), expected); + CHECK_EQ(toString(requireTypeAlias("Metatable2"), {true}), expected); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_respects_metatable_metamethod") +{ + if (!FFlag::LuauSolverV2 || !FFlag::LuauMetatableTypeFunctions) + return; + + CheckResult result = check(R"( + local metatable = { __metatable = "Test" } + local obj = setmetatable({x = 1, y = 2, z = 3}, metatable) + type Metatable = getmetatable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAlias("Metatable")), "string"); } TEST_SUITE_END(); diff --git a/tests/TypeFunction.user.test.cpp b/tests/TypeFunction.user.test.cpp new file mode 100644 index 00000000..12c9ece2 --- /dev/null +++ b/tests/TypeFunction.user.test.cpp @@ -0,0 +1,1985 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "ClassFixture.h" +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauTypeFunFixHydratedClasses) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(LuauTypeFunSingletonEquality) +LUAU_FASTFLAG(LuauUserTypeFunTypeofReturnsType) +LUAU_FASTFLAG(LuauTypeFunReadWriteParents) +LUAU_FASTFLAG(LuauTypeFunPrintFix) + +TEST_SUITE_BEGIN("UserDefinedTypeFunctionTests"); + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_nil(arg) + return arg + end + type type_being_serialized = nil + local function ok(idx: serialize_nil): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_nil_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getnil() + local ty = types.singleton(nil) + if ty:is("nil") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnil<>): nil return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_unknown(arg) + return arg + end + type type_being_serialized = unknown + local function ok(idx: serialize_unknown): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_unknown_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getunknown() + local ty = types.unknown + if ty:is("unknown") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getunknown<>): unknown return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_never(arg) + return arg + end + type type_being_serialized = never + local function ok(idx: serialize_never): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_never_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getnever() + local ty = types.never + if ty:is("never") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnever<>): never return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_any(arg) + return arg + end + type type_being_serialized = any + local function ok(idx: serialize_any): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_any_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getany() + local ty = types.any + if ty:is("any") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getany<>): any return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_bool(arg) + return arg + end + type type_being_serialized = boolean + local function ok(idx: serialize_bool): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolean_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getboolean() + local ty = types.boolean + if ty:is("boolean") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolean<>): boolean return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_num(arg) + return arg + end + type type_being_serialized = number + local function ok(idx: serialize_num): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_number_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getnumber() + local ty = types.number + if ty:is("number") then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getnumber<>): number return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "thread_and_buffer_types") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + type function work_with_thread(x) + if x:is("thread") then + return types.thread + end + return types.string + end + type X = thread + local function ok(idx: work_with_thread): thread return idx end + )")); + + LUAU_REQUIRE_NO_ERRORS(check(R"( + type function work_with_buffer(x) + if x:is("buffer") then + return types.buffer + end + return types.string + end + type X = buffer + local function ok(idx: work_with_buffer): buffer return idx end + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_str(arg) + return arg + end + type type_being_serialized = string + local function ok(idx: serialize_str): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_string_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getstring() + local ty = types.string + if ty:is("string") then + return ty + end + -- this should never be returned + return types.boolean + end + local function ok(idx: getstring<>): string return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_boolsingleton(arg) + return arg + end + type type_being_serialized = true + local function ok(idx: serialize_boolsingleton): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_boolsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getboolsingleton() + local ty = types.singleton(true) + if ty:is("singleton") and ty:value() then + return ty + end + -- this should never be returned + return types.string + end + local function ok(idx: getboolsingleton<>): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_strsingleton(arg) + return arg + end + type type_being_serialized = "popcorn and movies!" + local function ok(idx: serialize_strsingleton): "popcorn and movies!" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strsingleton_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getstrsingleton() + local ty = types.singleton("hungry hippo") + if ty:is("singleton") and ty:value() == "hungry hippo" then + return ty + end + -- this should never be returned + return types.number + end + local function ok(idx: getstrsingleton<>): "hungry hippo" return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_union(arg) + return arg + end + type type_being_serialized = number | string | boolean + -- forcing an error here to check the exact type of the union + local function ok(idx: serialize_union): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_union_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getunion() + local ty = types.unionof(types.string, types.number, types.boolean) + if ty:is("union") then + -- creating a copy of `ty` + local arr = {} + for _, value in ty:components() do + table.insert(arr, value) + end + return types.unionof(table.unpack(arr)) + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the union + local function ok(idx: getunion<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "boolean | number | string"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_intersection(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number } & { boolean: boolean, string: string } + -- forcing an error here to check the exact type of the intersection + local function ok(idx: serialize_intersection): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_intersection_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getintersection() + local tbl1 = types.newtable(nil, nil, nil) + tbl1:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl1:setproperty(types.singleton("number"), types.number) -- {boolean: boolean, number: number} + local tbl2 = types.newtable(nil, nil, nil) + tbl2:setproperty(types.singleton("boolean"), types.boolean) -- {boolean: boolean} + tbl2:setproperty(types.singleton("string"), types.string) -- {boolean: boolean, string: string} + local ty = types.intersectionof(tbl1, tbl2) + if ty:is("intersection") then + -- creating a copy of `ty` + local arr = {} + for index, value in ty:components() do + table.insert(arr, value) + end + return types.intersectionof(table.unpack(arr)) + end + -- this should never be returned + return types.string + end + -- forcing an error here to check the exact type of the intersection + local function ok(idx: getintersection<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ boolean: boolean, number: number } & { boolean: boolean, string: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_negation_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getnegation() + local ty = types.negationof(types.string) + if ty:is("negation") then + return ty + end + -- this should never be returned + return types.number + end + + -- forcing an error here to check the exact type of the negation + local function ok(idx: getnegation<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "~string"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_negation_inner") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(t) + return types.negationof(t):inner() +end + +type function fail(t) + return t:inner() +end + +local function ok(idx: pass): number return idx end +local function notok(idx: fail): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK( + toString(result.errors[0]) == + R"('fail' type function errored at runtime: [string "fail"]:7: type.inner: cannot call inner method on non-negation type: `number` type)" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_table(arg) + return arg + end + type type_being_serialized = { boolean: boolean, number: number, [string]: number } + -- forcing an error here to check the exact type of the table + local function ok(idx: serialize_table): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [string]: number, boolean: boolean, number: number }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_table_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function gettable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number] = boolean} + ty:setproperty(types.singleton("number"), types.string) -- {string: number, number: string, [number] = boolean} + ty:setproperty(types.singleton("string"), nil) -- {number: string, [number] = boolean} + local ret = types.newtable(nil, nil, nil) -- {} + -- creating a copy of `ty` + for k, v in ty:properties() do + ret:setreadproperty(k, v.read) + ret:setwriteproperty(k, v.write) + end + if ret:is("table") then + ret:setindexer(types.boolean, types.string) -- {number: string, [boolean] = string} + return ret -- {number: string, [boolean] = string} + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the table + local function ok(idx: gettable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ [boolean]: string, number: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_metatable_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getmetatable() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + metatbl:setmetatable(types.newtable(nil, indexer, nil)) -- { { }, @metatable { [number]: boolean } } + local ret = metatbl:metatable() + if metatbl:is("table") and metatbl:metatable() then + return ret -- { @metatable { [number]: boolean } } + end + -- this should never be returned + return types.number + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getmetatable<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{boolean}"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_func(arg) + return arg + end + type type_being_serialized = (boolean, number, nil) -> (...string) + local function ok(idx: serialize_func): (boolean, number, nil) -> (...string) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_methods_work") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getfunction() + local ty = types.newfunction(nil, nil) -- () -> () + ty:setparameters({types.string, types.number}, nil) -- (string, number) -> () + ty:setreturns(nil, types.boolean) -- (string, number) -> (...boolean) + if ty:is("function") then + -- creating a copy of `ty` parameters + local arr = {} + for index, val in ty:parameters().head do + table.insert(arr, val) + end + return types.newfunction({head = arr}, ty:returns()) -- (string, number) -> (...boolean) + end + -- this should never be returned + return types.number + end + local function ok(idx: getfunction<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "(string, number) -> (...boolean)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_class(arg) + return arg + end + local function ok(idx: serialize_class): BaseClass return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_serialization_works2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauTypeFunFixHydratedClasses{FFlag::LuauTypeFunFixHydratedClasses, true}; + + CheckResult result = check(R"( + type function serialize_class(arg) + return arg + end + local function ok(idx: serialize_class): typeof(confusingBaseClassInstance) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_methods_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local indexer = arg:indexer() + local metatable = arg:metatable() + return types.newtable(props, indexer, metatable) + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ BaseField: number, read BaseMethod: (BaseClass, number) -> (), read Touched: Connection }"); +} + +TEST_CASE_FIXTURE(ClassFixture, "write_of_readonly_is_nil") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getclass(arg) + local props = arg:properties() + local table = types.newtable(props) + local singleton = types.singleton("BaseMethod") + + if table:writeproperty(singleton) then + return types.singleton(true) + else + return types.singleton(false) + end + end + -- forcing an error here to check the exact type of the metatable + local function ok(idx: getclass): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "false"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_check_mutability") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function checkmut() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(props, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metatbl = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if metatbl:is("table") and metatbl:metatable() then + return metatbl -- { @metatable { [number]: boolean }, { } } + end + -- this should never be returned + return types.number + end + local function ok(idx: checkmut<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable {boolean}, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_copy_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function getcopy() + local indexer = { + index = types.number, + readresult = types.boolean, + writeresult = types.boolean, + } + local ty = types.newtable(nil, indexer, nil) -- {[number]: boolean} + ty:setproperty(types.singleton("string"), types.number) -- {string: number, [number]: boolean} + local metaty = types.newtable(nil, nil, ty) -- { { }, @metatable { [number]: boolean, string: number } } + local copy = types.copy(metaty) + -- mutate the table + ty:setproperty(types.singleton("string"), nil) -- {[number]: boolean} + if copy:is("table") and copy:metatable() then + return copy -- { { }, @metatable { [number]: boolean, string: number } } + end + -- this should never be returned + return types.number + end + local function ok(idx: getcopy<>): never return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ @metatable { [number]: boolean, string: number }, { } }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_simple_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_cycle(arg) + return arg + end + type basety = { + first: basety2 + } + type basety2 = { + second: basety + } + local function ok(idx: serialize_cycle): basety return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_createtable_bad_metatable") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function badmetatable() + return types.newtable(nil, nil, types.number) + end + local function bad(arg: badmetatable<>) end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'badmetatable' type function errored at runtime: [string \"badmetatable\"]:3: types.newtable: expected to be given a table " + "type as a metatable, but got number instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_complex_cyclic_serialization_works") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function serialize_cycle2(arg) + return arg + end + type Employee = { + name: string, + department: Department? + } + type Department = { + name: string, + manager: Employee?, + employees: { Employee }, + company: Company? + } + type Company = { + name: string, + departments: { Department } + } + local function ok(idx: serialize_cycle2): Company return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_user_error_is_reported") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function errors_if_string(arg) + if arg:is("string") then + local a = 1 + error("We are in a math class! not english") + end + return arg + end + local function ok(idx: errors_if_string): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'errors_if_string' type function errored at runtime: [string \"errors_if_string\"]:5: We are in a math class! not english"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_call_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function hello(arg) + error(type(arg)) + end + local function ok(idx: hello): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: userdata"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_type_overrides_eq_metamethod") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function hello() + local p1 = types.string + local p2 = types.string + local t1 = types.newtable(nil, nil, nil) + t1:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + local t2 = types.newtable(nil, nil, nil) + t2:setproperty(types.singleton("string"), types.boolean) + t1:setmetatable(t1) + if p1 == p2 and t1 == t2 then + return types.number + end + end + local function ok(idx: hello<>): number return idx end + )"); + + LUAU_CHECK_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_function_type_cant_call_get_props") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function hello(arg) + local arr = arg:properties() + end + local function ok(idx: hello<() -> ()>): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK( + e->message == "'hello' type function errored at runtime: [string \"hello\"]:3: type.properties: expected self to be either a table or class, " + "but got function instead" + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function foo() + return "hi" + end + type function bar() + return types.singleton(foo()) + end + local function ok(idx: bar<>): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "\"hi\""); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function first(arg) + return arg + end + type function second(arg) + return types.singleton(first(arg)) + end + type function third() + return second("hi") + end + local function ok(idx: third<>): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "\"hi\""); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_each_other_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + -- this function should not see 'fourth' function when invoked from 'third' that sees it + type function first(arg) + return fourth(arg) + end + type function second(arg) + return types.singleton(first(arg)) + end + + do + type function fourth(arg) + return arg + end + type function third() + return second("hi") + end + local function ok(idx: third<>): nil return idx end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('third' type function errored at runtime: [string "first"]:4: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_no_shared_state") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function foo() + if not glob then + glob = 'a' + else + glob ..= 'b' + end + + return glob + end + type function bar(prefix) + return types.singleton(prefix:value() .. foo()) + end + local function ok1(idx: bar<'x'>): nil return idx end + local function ok2(idx: bar<'y'>): nil return idx end + )"); + + // We are only checking first errors, others are mostly duplicates + LUAU_REQUIRE_ERROR_COUNT(8, result); + CHECK(toString(result.errors[0]) == R"('bar' type function errored at runtime: [string "foo"]:4: attempt to modify a readonly table)"); + CHECK(toString(result.errors[1]) == R"(Type function instance bar<"x"> is uninhabited)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_math_reset") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function foo(x) + return types.singleton(tostring(math.random(1, 100))) + end + local x: foo<'a'> = ('' :: any) :: foo<'b'> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_optionify") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function optionify(tbl) + if not tbl:is("table") then + error("Argument is not a table") + end + for k, v in tbl:properties() do + tbl:setproperty(k, types.unionof(v.read, types.singleton(nil))) + end + return tbl + end + type Person = { + name: string, + age: number, + alive: boolean + } + local function ok(idx: optionify): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ age: number?, alive: boolean?, name: string? }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_calling_illegal_global") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function illegal(arg) + gcinfo() -- this should error + + return arg -- this should not be reached + end + + local function ok(idx: illegal): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); // There are 2 type function uninhabited error, 2 user defined type function error + UserDefinedTypeFunctionError* e = get(result.errors[0]); + REQUIRE(e); + CHECK(e->message == "'illegal' type function errored at runtime: [string \"illegal\"]:3: this function is not supported in type functions"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recursion_and_gc") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function foo(tbl) + local count = 0 + for k,v in tbl:properties() do count += 1 end + if count < 100 then + tbl:setproperty(types.singleton(`m{count}`), types.string) + foo(tbl) + end + for i = 1,100 do table.create(10000) end + return tbl + end + type Test = {} + local function ok(idx: foo): nil return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_recovery_no_upvalues") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local var + + type function save_upvalue(arg) + var = 1 + return arg + end + + type test = "test" + local function ok(idx: save_upvalue): "test" + return idx + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == R"(Type function cannot reference outer local 'var')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_follow") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type t0 = any + type function t0() + return types.any + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == R"(Redefinition of type 't0', previously defined at line 2)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_strip_indexer") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function stripindexer(tbl) + if not tbl:is("table") then + error("can only strip the indexer on a table!") + end + tbl:setindexer(types.never, types.never) + return tbl + end + + type map = { [number]: string, foo: string } + -- forcing an error here to check the exact type + local function ok(tbl: stripindexer): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypePackMismatch* tpm = get(result.errors[0]); + REQUIRE(tpm); + CHECK(toString(tpm->givenTp) == "{ foo: string }"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_type_methods_on_types") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function test(x) + return if types.is(x, "number") then types.string else types.boolean + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:3: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_types_functions_on_type") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function test(x) + return x.singleton("a") + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:3: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_metatable_writes") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function test(x) + local a = x.__index + a.is = function() return false end + return types.singleton(x.is("number")) + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:4: attempt to index nil with 'is')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "no_eq_field") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function test(x) + return types.singleton(x.__eq(x, types.number)) + end + local function ok(tbl: test): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"('test' type function errored at runtime: [string "test"]:3: attempt to call a nil value)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "tag_field") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function test(x) + return types.singleton(x.tag) + end + + local function ok1(tbl: test): never return tbl end + local function ok2(tbl: test): never return tbl end + local function ok3(tbl: test<{}>): never return tbl end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK(toString(result.errors[0]) == R"(Type pack '"number"' could not be converted into 'never'; at [0], "number" is not a subtype of never)"); + CHECK(toString(result.errors[1]) == R"(Type pack '"string"' could not be converted into 'never'; at [0], "string" is not a subtype of never)"); + CHECK(toString(result.errors[2]) == R"(Type pack '"table"' could not be converted into 'never'; at [0], "table" is not a subtype of never)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_serialization") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function makemttbl() + local metaprops = { + [types.singleton("ma")] = types.boolean + } + local mt = types.newtable(metaprops) + + local props = { + [types.singleton("a")] = types.number + } + return types.newtable(props, nil, mt) + end + + type function id(x) + return x + end + + local a: number = {} :: id> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == R"(Type '{ @metatable { ma: boolean }, { a: number } }' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "nonstrict_mode") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +--!nonstrict +type function foo() return types.string end +local a: foo<> = "a" + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "implicit_export") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( +type function concat(a, b) + return types.singleton(a:value() .. b:value()) +end +export type Concat = concat +local a: concat<'first', 'second'> +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")"); + + CheckResult bResult = check(R"( +local Test = require(game.A); +local b: Test.Concat<'third', 'fourth'> + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + CHECK(toString(requireType("b")) == R"("thirdfourth")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "local_scope") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function foo() + return "hi" +end +local function test() + type function bar() + return types.singleton(foo()) + end + + return ("" :: any) :: bar<> +end +local a = test() + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(toString(requireType("a")) == R"("hi")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "explicit_export") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( +export type function concat(a, b) + return types.singleton(a:value() .. b:value()) +end +local a: concat<'first', 'second'> +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CHECK(toString(requireType("game/A", "a")) == R"("firstsecond")"); + + CheckResult bResult = check(R"( +local Test = require(game.A); +local b: Test.concat<'third', 'fourth'> + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + CHECK(toString(requireType("b")) == R"("thirdfourth")"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + return types.any + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_error") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + error("test") + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); + CHECK(toString(result.errors[2]) == R"('t0' type function errored at runtime: [string "t0"]:5: test)"); + CHECK(toString(result.errors[3]) == R"(Type function instance t0 is uninhabited)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "print_to_error_plus_no_result") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + type function t0(a) + print("Where does this go") + print(a.tag) + end + local a: t0 + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Where does this go)"); + CHECK(toString(result.errors[1]) == R"(string)"); + CHECK(toString(result.errors[2]) == R"('t0' type function: returned a non-type value)"); + CHECK(toString(result.errors[3]) == R"(Type function instance t0 is uninhabited)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + return arg +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + return arg +end + +type test = (T) -> (T, U...) + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_serialization_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + return arg +end + +local function m(a, b) + return {x = a, y = b} +end + +type test = typeof(m) + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_cloning_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + return types.copy(arg) +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_cloning_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + return types.copy(arg) +end + +type test = (T) -> (T, U...) + +local function ok(idx: pass): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_equality") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + return types.singleton(types.copy(arg) == arg) +end + +type test = (T) -> (T, U...) + +local function ok(idx: pass): true return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + local generics = arg:generics() + local T = generics[1] + return types.newfunction({ head = {T} }, { head = {T} }, {T}) +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): (T) -> (T) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + local generics = arg:generics() + local T = generics[1] + local f = types.newfunction() + f:setparameters({T, T}); + f:setreturns({T}); + f:setgenerics({T}); + return f +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass): (T, T) -> (T) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass() + local T = types.generic("T") + assert(T.tag == "generic") + assert(T:name() == "T") + assert(T:ispack() == false) + + local Us, Vs = types.generic("U", true), types.generic("V", true) + assert(Us.tag == "generic") + assert(Us:name() == "U") + assert(Us:ispack() == true) + + local f = types.newfunction() + f:setparameters({T}, Us); + f:setreturns({T}, Vs); + f:setgenerics({T, Us, Vs}); + return f +end + +local function ok(idx: pass<>): (T, U...) -> (T, V...) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_4") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass() + local T, U = types.generic("T"), types.generic("U") + + -- (T) -> () + local func = types.newfunction({ head = {T} }, {}, {T}); + + -- { x: (T) -> (), y: U } + local tbl = types.newtable({ [types.singleton("x")] = func, [types.singleton("y")] = U }) + + -- (T, { x: (T) -> (), y: U }, U) -> () + return types.newfunction({ head = {T, tbl, U } }, {}, {T, U}) +end + +type test = (T, { x: (y: T) -> (), y: U }, U) -> () + +local function ok(idx: pass<>): test return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_5") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass() + local T = types.generic("T") + return types.newfunction({ head = {T} }, {}, {types.copy(T)}) +end + +local function ok(idx: pass<>): (T) -> () return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_6") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + local generics = arg:generics() + local T, U = generics[1], generics[2] + local f = types.newfunction() + f:setparameters({T}); + f:setreturns({U}); + f:setgenerics({T, U}); + return f +end + +local function m(a, b) + return {x = a, y = b} +end + +type test = typeof(m) + +local function ok(idx: pass): (T) -> (U) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_7") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + local p, r = arg:parameters(), arg:returns() + local f = types.newfunction() + f:setparameters(p.head, p.tail); + f:setreturns(r.head, r.tail); + f:setgenerics(arg:generics()); + return f +end + +type test = (T, U...) -> (T, U...) + +local function ok(idx: pass): (T, U...) -> (T, U...) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_8") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + local p, r = arg:parameters(), arg:returns() + local f = types.newfunction() + f:setparameters(p.head, p.tail); + f:setreturns(r.head, r.tail); + f:setgenerics(arg:generics()); + return f +end + +type test = (U...) -> (U...) + +local function ok(idx: pass): (T, T) -> (T) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_equality_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + + local tbl1 = types.newtable({ [types.singleton("x")] = T }) + local tbl2 = types.newtable({ [types.singleton("x")] = Us }) -- it is possible to have invalid types in-flight + + return types.singleton(tbl1 == tbl2) +end + +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_1") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({}, {}, {Us, T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK( + toString(result.errors[0]) == + R"('get' type function errored at runtime: [string "get"]:4: types.newfunction: generic type cannot follow a generic pack)" + ); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_2") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ head = {T} }, {}, {}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type 'T' is not in a scope of the active generic function)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_3") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, U = types.generic("T"), types.generic("U") + + -- (U) -> () + local func = types.newfunction({ head = {U} }, {}, {U}); + + -- broken: (T, (U) -> (), U) -> () + return types.newfunction({ head = {T, func, U } }, {}, {T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type 'U' is not in a scope of the active generic function)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_4") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ head = {T} }, { tail = Us }, {T, T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Duplicate type parameter 'T')"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_5") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Ts = types.generic("T"), types.generic("T", true) + return types.newfunction({ head = {T} }, { tail = Ts }, {T, Ts}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Duplicate type parameter 'T')"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_6") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ head = {Us} }, {}, {T, Us}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type pack 'U...' cannot be placed in a type position)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_generic_api_error_7") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function get() + local T, Us = types.generic("T"), types.generic("U", true) + return types.newfunction({ tail = Us }, {}, {T}) +end +local function ok(idx: get<>): false return idx end + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK(toString(result.errors[0]) == R"(Generic type pack 'U...' is not in a scope of the active generic function)"); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_variadic_api") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( +type function pass(arg) + local p, r = arg:parameters(), arg:returns() + local f = types.newfunction() + f:setparameters({p.tail}, p.head[1]); + f:setreturns({r.tail}, r.head[1]); + return f +end + +type test = (string, ...number) -> (number, ...string) + +local function ok(idx: pass): (number, ...string) -> (string, ...number) return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_eqsat_opaque") +{ + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::DebugLuauEqSatSimplification, true}}; + + CheckResult _ = check(R"( + type function t0(a) + error("test") + end + local v: t0 + )"); + TypeArena arena; + auto ty = requireType("v"); + auto simplifier = EqSatSimplification::newSimplifier(NotNull{&arena}, frontend.builtinTypes); + auto simplified = eqSatSimplify(NotNull{simplifier.get()}, ty); + REQUIRE(simplified); + CHECK_EQ("t0", toString(simplified->result)); // NOLINT(bugprone-unchecked-optional-access) +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_singleton_equality_bool") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauTypeFunSingletonEquality{FFlag::LuauTypeFunSingletonEquality, true}; + + CheckResult result = check(R"( +type function compare(arg) + return types.singleton(types.singleton(false) == arg) +end + +local function ok(idx: compare): true return idx end +local function ok(idx: compare): false return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_singleton_equality_string") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauTypeFunSingletonEquality{FFlag::LuauTypeFunSingletonEquality, true}; + + CheckResult result = check(R"( +type function compare(arg) + return types.singleton(types.singleton("") == arg) +end + +local function ok(idx: compare<"">): true return idx end +local function ok(idx: compare<"a">): false return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "typeof_type_userdata_returns_type") +{ + ScopedFastFlag solverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag luauUserTypeFunTypeofReturnsType{FFlag::LuauUserTypeFunTypeofReturnsType, true}; + + CheckResult result = check(R"( +type function test(t) + print(typeof(t)) + return t +end + +local _:test + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(toString(result.errors[0]) == R"(type)"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "udtf_print_tab_char_fix") +{ + ScopedFastFlag sffs[] = {{FFlag::LuauSolverV2, true}, {FFlag::LuauTypeFunPrintFix, true}}; + + CheckResult result = check(R"( + type function test(t) + print(1,2) + + return t + end + + local _:test + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // It should be \t and not \x1 + CHECK_EQ("1\t2", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(ClassFixture, "udtf_class_parent_ops") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + ScopedFastFlag readWriteParents{FFlag::LuauTypeFunReadWriteParents, true}; + + CheckResult result = check(R"( + type function readparentof(arg) + return arg:readparent() + end + + type function writeparentof(arg) + return arg:writeparent() + end + + local function ok1(idx: readparentof): BaseClass return idx end + local function ok2(idx: writeparentof): BaseClass return idx end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index bb5a2cdd..3972fd6b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/AstQuery.h" @@ -9,7 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauUserDefinedTypeFunctions) +LUAU_FASTFLAG(LuauFixInfiniteRecursionInNormalization) TEST_SUITE_BEGIN("TypeAliases"); @@ -105,7 +106,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_steal_hoisted_type_alias") TEST_CASE_FIXTURE(Fixture, "mismatched_generic_type_param") { // We erroneously report an extra error in this case when the new solver is enabled. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type T = (A...) -> () @@ -244,7 +245,7 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") { // CLI-116108 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -354,9 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ // Check that recursive intersection type doesn't generate an OOM TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; // FIXME + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any @@ -417,7 +416,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") TEST_CASE_FIXTURE(Fixture, "generic_param_remap") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); const std::string code = R"( -- An example of a forwarded use of a type that has different type arguments than parameters @@ -543,7 +542,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Cool = { a: number, b: string } @@ -564,7 +563,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Cool = { a: number, b: string } @@ -712,7 +711,7 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") { // CLI-116108 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( -- OK because forwarded types are used with their parameters. @@ -726,7 +725,7 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") { // CLI-116108 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( -- Not OK because forwarded types are used with different types than their parameters. @@ -750,7 +749,7 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") { // CLI-116108 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Tree1 = { data: T, children: {Tree2} } @@ -875,7 +874,7 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") { // CLI-116108 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( -- this would be an infinite type if we allowed it @@ -888,7 +887,7 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") TEST_CASE_FIXTURE(Fixture, "report_shadowed_aliases") { // CLI-116110 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); // We allow a previous type alias to depend on a future type alias. That exact feature enables a confusing example, like the following snippet, // which has the type alias FakeString point to the type alias `string` that which points to `number`. @@ -972,7 +971,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_locations") TEST_CASE_FIXTURE(BuiltinsFixture, "dont_lose_track_of_PendingExpansionTypes_after_substitution") { // CLI-114134 - We need egraphs to properly simplify these types. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); fileResolver.source["game/ReactCurrentDispatcher"] = R"( export type BasicStateAction = ((S) -> S) | S @@ -1155,7 +1154,7 @@ type Foo = Foo | string TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_function") { - if (!FFlag::LuauSolverV2 || !FFlag::LuauUserDefinedTypeFunctions) + if (!FFlag::LuauSolverV2) return; CheckResult result = check(R"( @@ -1167,18 +1166,47 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_adds_reduce_constraint_for_type_f LUAU_CHECK_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "user_defined_type_function_errors") +TEST_CASE_FIXTURE(Fixture, "bound_type_in_alias_segfault") { - if (!FFlag::LuauUserDefinedTypeFunctions) - return; + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!nonstrict + type Map = {[ K]: V} + function foo:bar(): Config end + type Config = Map & { fields: FieldConfigMap} + export type FieldConfig = {[ string]: any} + export type FieldConfigMap = Map> + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "gh1632_no_infinite_recursion_in_normalization") +{ + ScopedFastFlag flags[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauFixInfiniteRecursionInNormalization, true}, + }; CheckResult result = check(R"( - type function foo() - return nil - end + type Node = { + value: T, + next: Node?, + -- remove `prev`, solves issue + prev: Node?, + }; + + type List = { + head: Node? + } + + local function IsFront(list: List, nodeB: Node) + -- remove if statement below, solves issue + if (list.head == nodeB) then + end + end )"); - LUAU_CHECK_ERROR_COUNT(1, result); - CHECK(toString(result.errors[0]) == "This syntax is not supported"); + + LUAU_CHECK_NO_ERRORS(result); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 9912cc35..96443aeb 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" #include "Fixture.h" @@ -8,7 +9,10 @@ using namespace Luau; -LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauTableCloneClonesType3) +LUAU_FASTFLAG(LuauStringFormatErrorSuppression) +LUAU_FASTFLAG(LuauFreezeIgnorePersistent) TEST_SUITE_BEGIN("BuiltinTests"); @@ -132,7 +136,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -512,7 +516,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "buffer_is_a_type") TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_resume_anything_goes") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function nifty(x, y) @@ -802,6 +806,17 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") CHECK_EQ(tm->givenType, builtinTypes->numberType); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_trivial_arity") +{ + CheckResult result = check(R"( + string.format() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Argument count mismatch. Function 'string.format' expects at least 1 argument, but none are specified", toString(result.errors[0])); +} + TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") { CheckResult result = check(R"( @@ -1109,15 +1124,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") local c = tf3[2] local d = tf1.b + + local a2 = t1.a + local b2 = t2.b + local c2 = t3[2] )"); LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK("Key 'b' not found in table '{ a: number }'" == toString(result.errors[0])); + CHECK("Key 'b' not found in table '{ read a: number }'" == toString(result.errors[0])); else CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); CHECK(Location({13, 18}, {13, 23}) == result.errors[0].location); + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{ read a: number }", toString(requireTypeAtPosition({15, 19}))); + CHECK_EQ("{ read b: string }", toString(requireTypeAtPosition({16, 19}))); + CHECK_EQ("{boolean}", toString(requireTypeAtPosition({17, 19}))); + } + CHECK_EQ("number", toString(requireType("a"))); CHECK_EQ("string", toString(requireType("b"))); CHECK_EQ("boolean", toString(requireType("c"))); @@ -1126,12 +1152,134 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_is_generic") CHECK_EQ("any", toString(requireType("d"))); else CHECK_EQ("*error-type*", toString(requireType("d"))); + + CHECK_EQ("number", toString(requireType("a2"))); + CHECK_EQ("string", toString(requireType("b2"))); + CHECK_EQ("boolean", toString(requireType("c2"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_does_not_retroactively_block_mutation") +{ + CheckResult result = check(R"( + local t1 = {a = 42} + + t1.q = ":3" + + local tf1 = table.freeze(t1) + + local a = tf1.a + local b = t1.a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauSolverV2) + { + CHECK_EQ("{ a: number, q: string } | { read a: number, read q: string }", toString(requireType("t1"), {/*exhaustive */ true})); + // before the assignment, it's `t1` + CHECK_EQ("{ a: number, q: string }", toString(requireTypeAtPosition({3, 8}), {/*exhaustive */ true})); + // after the assignment, it's read-only. + CHECK_EQ("{ read a: number, read q: string }", toString(requireTypeAtPosition({8, 18}), {/*exhaustive */ true})); + } + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_no_generic_table") +{ + ScopedFastFlag sff{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + --!strict + type k = { + read k: string, + } + + function _(): k + return table.freeze({ + k = "", + }) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_on_metatable") +{ + CheckResult result = check(R"( + --!strict + local meta = { + __index = function() + return "foo" + end + } + + local myTable = setmetatable({}, meta) + table.freeze(myTable) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_no_args") +{ + CheckResult result = check(R"( + --!strict + table.freeze() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_errors_on_non_tables") +{ + CheckResult result = check(R"( + --!strict + table.freeze(42) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(tm->wantedType), "table"); + else + CHECK_EQ(toString(tm->wantedType), "{- -}"); + CHECK_EQ(toString(tm->givenType), "number"); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_persistent_skip") +{ + ScopedFastFlag luauFreezeIgnorePersistent{FFlag::LuauFreezeIgnorePersistent, true}; + + CheckResult result = check(R"( + table.freeze(table) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_persistent_skip") +{ + ScopedFastFlag luauFreezeIgnorePersistent{FFlag::LuauFreezeIgnorePersistent, true}; + + CheckResult result = check(R"( + table.clone(table) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "set_metatable_needs_arguments") { // In the new solver, nil can certainly be used where a generic is required, so all generic parameters are optional. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local a = {b=setmetatable} @@ -1427,4 +1575,91 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "find_capture_types3") CHECK_EQ(toString(requireType("e")), "number?"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_find_should_not_crash") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function StringSplit(input, separator) + string.find(input, separator) + if not separator then + separator = "%s+" + end + end + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_dot_clone_type_states") +{ + CheckResult result = check(R"( + local t1 = {} + t1.x = 5 + local t2 = table.clone(t1) + t2.y = 6 + t1.z = 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauTableCloneClonesType3) + { + CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, z: number }"); + CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number }"); + } + else + { + CHECK_EQ(toString(requireType("t1"), {true}), "{ x: number, y: number, z: number }"); + CHECK_EQ(toString(requireType("t2"), {true}), "{ x: number, y: number, z: number }"); + } +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_should_not_break") +{ + CheckResult result = check(R"( + local Immutable = {} + + function Immutable.Set(dictionary, key, value) + local new = table.clone(dictionary) + + new[key] = value + + return new + end + + return Immutable + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_clone_should_not_break_2") +{ + CheckResult result = check(R"( + function set(dictionary, key, value) + local new = table.clone(dictionary) + + new[key] = value + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_should_support_any") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + local x: any = "world" + print(string.format("Hello, %s!", x)) + )"); + + if (FFlag::LuauStringFormatErrorSuppression) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 16751559..53f1396d 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -128,7 +128,7 @@ TEST_CASE_FIXTURE(ClassFixture, "we_can_infer_that_a_parameter_must_be_a_particu TEST_CASE_FIXTURE(ClassFixture, "we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function makeClone(o) @@ -235,7 +235,7 @@ TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") TEST_CASE_FIXTURE(ClassFixture, "cannot_unify_class_instance_with_primitive") { // This is allowed in the new solver - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local v = Vector2.New(0, 5) @@ -472,7 +472,7 @@ Type 'number' could not be converted into 'string')"; TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") { // CLI-116433 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local i = ChildClass.New() @@ -665,12 +665,11 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") )"); if (FFlag::LuauSolverV2) - CHECK( - "Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0)) - ); + CHECK("Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0))); else CHECK_EQ( - toString(result.errors.at(0)), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" + toString(result.errors.at(0)), + "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" ); } { @@ -680,12 +679,11 @@ TEST_CASE_FIXTURE(ClassFixture, "indexable_classes") )"); if (FFlag::LuauSolverV2) - CHECK( - "Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0)) - ); + CHECK("Type 'boolean' could not be converted into 'number | string'" == toString(result.errors.at(0))); else CHECK_EQ( - toString(result.errors.at(0)), "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" + toString(result.errors.at(0)), + "Type 'boolean' could not be converted into 'number | string'; none of the union options are compatible" ); } diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index e1eaf5e9..90593891 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -9,6 +9,10 @@ using namespace Luau; +LUAU_FASTFLAG(LuauClipNestedAndRecursiveUnion) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauPreventReentrantTypeFunctionReduction) + TEST_SUITE_BEGIN("DefinitionTests"); TEST_CASE_FIXTURE(Fixture, "definition_file_simple") @@ -443,6 +447,26 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") CHECK_EQ(toString(requireType("y")), "string"); } +TEST_CASE_FIXTURE(Fixture, "class_definition_malformed_string") +{ + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile( + frontend.globals, + frontend.globals.globalScope, + R"( + declare class Foo + ["a\0property"]: string + end + )", + "@test", + /* captureComments */ false + ); + freeze(frontend.globals.globalTypes); + + REQUIRE(!result.success); + REQUIRE_EQ(result.parseResult.errors.size(), 1); + CHECK_EQ(result.parseResult.errors[0].getMessage(), "String literal contains malformed escape sequence or \\0"); +} TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") { @@ -472,11 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_indexer") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - unfreeze(frontend.globals.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile( - frontend.globals, - frontend.globals.globalScope, - R"( + loadDefinition(R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () @@ -486,13 +506,19 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") Text: string Channel: Channel end - )", - "@test", - /* captureComments */ false - ); - freeze(frontend.globals.globalTypes); + )"); - REQUIRE(result.success); + CheckResult result = check(R"( + local a: Channel + local b = a.Messages[1] + local c = b.Channel + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Channel"); + CHECK_EQ(toString(requireType("b")), "Message"); + CHECK_EQ(toString(requireType("c")), "Channel"); } TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") @@ -516,4 +542,47 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_has_source_module_name_set") CHECK_EQ(ctv->definitionModuleName, "@test"); } +TEST_CASE_FIXTURE(Fixture, "recursive_redefinition_reduces_rightfully") +{ + ScopedFastFlag _{FFlag::LuauClipNestedAndRecursiveUnion, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local t: {[string]: string} = {} + + local function f() + t = t + end + + t = t + )")); +} + +TEST_CASE_FIXTURE(Fixture, "vector3_overflow") +{ + ScopedFastFlag _{FFlag::LuauPreventReentrantTypeFunctionReduction, true}; + // We set this to zero to ensure that we either run to completion or stack overflow here. + ScopedFastInt sfi{FInt::LuauTypeInferRecursionLimit, 0}; + + loadDefinition(R"( + declare class Vector3 + function __add(self, other: Vector3): Vector3 + end + )"); + + CheckResult result = check(R"( +--!strict +local function graphPoint(t : number, points : { Vector3 }) : Vector3 + local n : number = #points - 1 + local p : Vector3 = (nil :: any) + for i = 0, n do + local x = points[i + 1] + p = p and p + x or x + end + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index c76745c7..f13524fb 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -16,11 +16,14 @@ using namespace Luau; -LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTINT(LuauTarjanChildLimit); +LUAU_FASTFLAG(DebugLuauAssertOnForcedConstraint) -LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTINT(LuauTarjanChildLimit) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(LuauSubtypingFixTailPack) +LUAU_FASTFLAG(LuauUngeneralizedTypesForRecursiveFunctions) TEST_SUITE_BEGIN("TypeInferFunctions"); @@ -312,7 +315,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") TEST_CASE_FIXTURE(Fixture, "too_many_arguments") { // This is not part of the new non-strict specification currently. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -604,7 +607,7 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") { // This is not part of the spec for the new non-strict mode currently. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -682,6 +685,11 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") { + // CLI-114134: this code *probably* wants the egraph in order + // to work properly. The new solver either falls over or + // forces so many constraints as to be unreliable. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) local i, j = left, mid @@ -744,6 +752,11 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") TEST_CASE_FIXTURE(BuiltinsFixture, "higher_order_function_4") { + // CLI-114134: this code *probably* wants the egraph in order + // to work properly. The new solver either falls over or + // forces so many constraints as to be unreliable. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function bottomupmerge(comp, a, b, left, mid, right) local i, j = left, mid @@ -883,7 +896,7 @@ TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_pro TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") { // new non-strict mode spec does not include this error yet. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -1018,7 +1031,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "calling_function_with_anytypepack_doesnt_lea TEST_CASE_FIXTURE(Fixture, "too_many_return_values") { // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1042,7 +1055,7 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values") TEST_CASE_FIXTURE(Fixture, "too_many_return_values_in_parentheses") { // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1066,7 +1079,7 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values_in_parentheses") TEST_CASE_FIXTURE(Fixture, "too_many_return_values_no_function") { // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1235,7 +1248,7 @@ TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") { // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); // Simple direct arg to arg propagation CheckResult result = check(R"( @@ -1354,7 +1367,7 @@ f(function(x) return x * 2 end) TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end @@ -1385,7 +1398,7 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function g1(a: T, f: (T) -> T) return f(a) end @@ -1452,7 +1465,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "variadic_any_is_compatible_with_a_generic_Ty TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") { // FIXME: CLI-116133 bidirectional type inference needs to push expected types in for higher-order function calls - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Table = { x: number, y: number } @@ -1495,7 +1508,7 @@ end TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") { // FIXME: CLI-116111 test disabled until type path stringification is improved - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type A = (number, number) -> string @@ -1518,7 +1531,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") { // FIXME: CLI-116111 test disabled until type path stringification is improved - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type A = (number, number) -> string @@ -1542,7 +1555,7 @@ Type 'string' could not be converted into 'number')"; TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") { // FIXME: CLI-116111 test disabled until type path stringification is improved - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type A = (number, number) -> (number) @@ -1565,7 +1578,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") { // FIXME: CLI-116111 test disabled until type path stringification is improved - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type A = (number, number) -> string @@ -1589,7 +1602,7 @@ Type 'string' could not be converted into 'number')"; TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") { // FIXME: CLI-116111 test disabled until type path stringification is improved - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type A = (number, number) -> (number, string) @@ -1720,7 +1733,7 @@ TEST_CASE_FIXTURE(Fixture, "inferred_higher_order_functions_are_quantified_at_th { // This test regresses in the new solver, but is sort of nonsensical insofar as `foo` is known to be `nil`, so it's "right" to not be able to call // it. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local foo @@ -1789,7 +1802,7 @@ TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") { // FIXME: CLI-116122 bug where `t:b` does not check against the type from the indexer annotation on `t`. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t: {[string]: () -> number} = {} @@ -1834,7 +1847,7 @@ TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") { // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function test(a: number, b: string, ...) @@ -1862,7 +1875,7 @@ wrapper(test) TEST_CASE_FIXTURE(BuiltinsFixture, "too_few_arguments_variadic_generic2") { // FIXME: CLI-116157 variadic and generic type packs seem to be interacting incorrectly. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function test(a: number, b: string, ...) @@ -2016,7 +2029,7 @@ u.b().foo() TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstrict") { // This behavior is not part of the current specification of the new type solver. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -2031,7 +2044,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "improved_function_arg_mismatch_error_nonstri TEST_CASE_FIXTURE(Fixture, "luau_subtyping_is_np_hard") { // The case that _should_ succeed here (`z = x`) does not currently in the new solver. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -2340,20 +2353,10 @@ TEST_CASE_FIXTURE(Fixture, "attempt_to_call_an_intersection_of_tables") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (DFFlag::LuauImproveNonFunctionCallError) - { - if (FFlag::LuauSolverV2) - CHECK_EQ(toString(result.errors[0]), "Cannot call a value of type { x: number } & { y: string }"); - else - CHECK_EQ(toString(result.errors[0]), "Cannot call a value of type {| x: number |}"); - } + if (FFlag::LuauSolverV2) + CHECK_EQ(toString(result.errors[0]), "Cannot call a value of type { x: number } & { y: string }"); else - { - if (FFlag::LuauSolverV2) - CHECK_EQ(toString(result.errors[0]), "Cannot call non-function { x: number } & { y: string }"); - else - CHECK_EQ(toString(result.errors[0]), "Cannot call non-function {| x: number |}"); - } + CHECK_EQ(toString(result.errors[0]), "Cannot call a value of type {| x: number |}"); } TEST_CASE_FIXTURE(BuiltinsFixture, "attempt_to_call_an_intersection_of_tables_with_call_metamethod") @@ -2565,8 +2568,14 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_return_type") { - if (!FFlag::LuauSolverV2) - return; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + // CLI-114134: This test: + // a) Has a kind of weird result (suggesting `number | false` is not great); + // b) Is force solving some constraints. + // We end up with a weird recursive type that, if you roughly look at it, is + // clearly `number`. Hopefully the egraph will be able to unfold this. + CheckResult result = check(R"( function fib(n) return n < 2 and 1 or fib(n-1) + fib(n-2) @@ -2576,9 +2585,7 @@ end LUAU_REQUIRE_ERRORS(result); auto err = get(result.errors.back()); LUAU_ASSERT(err); - CHECK("number" == toString(err->recommendedReturn)); - REQUIRE(1 == err->recommendedArgs.size()); - CHECK("number" == toString(err->recommendedArgs[0].second)); + CHECK("false | number" == toString(err->recommendedReturn)); } TEST_CASE_FIXTURE(BuiltinsFixture, "tf_suggest_arg_type") @@ -2845,17 +2852,12 @@ TEST_CASE_FIXTURE(Fixture, "cannot_call_union_of_functions") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (DFFlag::LuauImproveNonFunctionCallError) - { - std::string expected = R"(Cannot call a value of the union type: + std::string expected = R"(Cannot call a value of the union type: | () -> () | () -> () -> () We are unable to determine the appropriate result type for such a call.)"; - CHECK(expected == toString(result.errors[0])); - } - else - CHECK("Cannot call non-function (() -> () -> ()) | (() -> ())" == toString(result.errors[0])); + CHECK(expected == toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "fuzzer_missing_follow_in_ast_stat_fun") @@ -2905,7 +2907,7 @@ TEST_CASE_FIXTURE(Fixture, "unifier_should_not_bind_free_types") auto tm2 = get(result.errors[1]); REQUIRE(tm2); CHECK(toString(tm2->wantedTp) == "string"); - CHECK(toString(tm2->givenTp) == "buffer | class | function | number | string | table | thread | true"); + CHECK(toString(tm2->givenTp) == "(buffer | class | function | number | string | table | thread | true) & unknown"); } else { @@ -3011,4 +3013,79 @@ local u,v = id(3), id(id(44)) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "hidden_variadics_should_not_break_subtyping") +{ + CheckResult result = check(R"( + --!strict + type FooType = { + SetValue: (Value: number) -> () + } + + local Foo: FooType = { + SetValue = function(Value: number) + + end + } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "coroutine_wrap_result_call") +{ + ScopedFastFlag luauSubtypingFixTailPack{FFlag::LuauSubtypingFixTailPack, true}; + + CheckResult result = check(R"( + function foo(a, b) + coroutine.wrap(a)(b) + end + )"); + + // New solver still reports an error in this case, but the main goal of the test is to not crash +} + +TEST_CASE_FIXTURE(Fixture, "recursive_function_calls_should_not_use_the_generalized_type") +{ + ScopedFastFlag crashOnForce{FFlag::DebugLuauAssertOnForcedConstraint, true}; + ScopedFastFlag sff{FFlag::LuauUngeneralizedTypesForRecursiveFunctions, true}; + + CheckResult result = check(R"( + --!strict + + function random() + return true -- chosen by fair coin toss + end + + local f + f = 5 + function f() + if random() then f() end + end + )"); + + if (FFlag::LuauSolverV2) + LUAU_REQUIRE_NO_ERRORS(result); + else + LUAU_REQUIRE_ERRORS(result); // errors without typestate, obviously +} + +TEST_CASE_FIXTURE(Fixture, "recursive_function_calls_should_not_use_the_generalized_type_2") +{ + ScopedFastFlag crashOnForce{FFlag::DebugLuauAssertOnForcedConstraint, true}; + + CheckResult result = check(R"( + --!strict + + function random() + return true -- chosen by fair coin toss + end + + local function f() + if random() then f() end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index ffd01f24..0a53836b 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping); LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(LuauDeferBidirectionalInferenceForTableAssignment) using namespace Luau; @@ -143,7 +144,7 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t: { m: (number)->number } = { m = function(x:number) return x+1 end } @@ -261,7 +262,7 @@ TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions_errors") TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types_old_solver") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type T = { id: (a) -> a } @@ -287,7 +288,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types_new_solver") TEST_CASE_FIXTURE(Fixture, "generic_factories") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type T = { id: (a) -> a } @@ -310,7 +311,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_factories") TEST_CASE_FIXTURE(Fixture, "factories_of_generics") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type T = { id: (a) -> a } @@ -775,7 +776,7 @@ return exports TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names_old_solver") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f(a: T, ...: U...) end @@ -794,7 +795,7 @@ TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names_old_solver") TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type C = () -> () @@ -811,7 +812,7 @@ local d: D = c TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type C = () -> () @@ -854,6 +855,8 @@ end TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") { + ScopedFastFlag _{FFlag::LuauDeferBidirectionalInferenceForTableAssignment, true}; + CheckResult result = check(R"( --!strict -- At one point this produced a UAF @@ -865,15 +868,19 @@ local y: T = { a = { c = nil, d = 5 }, b = 37 } y.a.c = y )"); - LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauSolverV2) - CHECK( - toString(result.errors.at(0)) == - R"(Type '{ a: { c: nil, d: number }, b: number }' could not be converted into 'T'; type { a: { c: nil, d: number }, b: number }[read "a"][read "c"] (nil) is not exactly T[read "a"][read "c"][0] (T))" - ); + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto mismatch = get(result.errors.at(0)); + CHECK(mismatch); + CHECK_EQ(toString(mismatch->givenType), "{ a: { c: T?, d: number }, b: number }"); + CHECK_EQ(toString(mismatch->wantedType), "T"); + std::string reason = "at [read \"a\"][read \"d\"], number is not exactly string\n\tat [read \"b\"], number is not exactly string"; + CHECK_EQ(mismatch->reason, reason); + } else { + LUAU_REQUIRE_ERROR_COUNT(2, result); const std::string expected = R"(Type 'y' could not be converted into 'T' caused by: Property 'a' is not compatible. @@ -887,7 +894,7 @@ Type 'number' could not be converted into 'string' in an invariant context)"; TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -907,7 +914,7 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -927,7 +934,7 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -947,7 +954,7 @@ local TheDispatcher: Dispatcher = { TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function test(a: number) @@ -966,7 +973,7 @@ wrapper(test) TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function test2(a: number, b: string) @@ -1125,7 +1132,11 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - if (FFlag::LuauInstantiateInSubtyping) + // The new solver does not attempt to instantiate generics here, so if + // either the instantiate in subtyping flag _or_ the new solver flags + // are set, assert that we're getting back the original generic + // function definition. + if (FFlag::LuauInstantiateInSubtyping || FFlag::LuauSolverV2) CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); else CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); @@ -1148,7 +1159,11 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - if (FFlag::LuauInstantiateInSubtyping) + // The new solver does not attempt to instantiate generics here, so if + // either the instantiate in subtyping flag _or_ the new solver flags + // are set, assert that we're getting back the original generic + // function definition. + if (FFlag::LuauInstantiateInSubtyping || FFlag::LuauSolverV2) CHECK_EQ("((a) -> (b...), a) -> (b...)", toString(tm->givenType)); else CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); @@ -1455,7 +1470,7 @@ TEST_CASE_FIXTURE(Fixture, "no_extra_quantification_for_generic_functions") TEST_CASE_FIXTURE(Fixture, "do_not_always_instantiate_generic_intersection_types") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1490,8 +1505,9 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "higher_rank_polymorphism_should_not_accept_instantiated_arguments") { + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + ScopedFastFlag sffs[] = { - {FFlag::LuauSolverV2, false}, {FFlag::LuauInstantiateInSubtyping, true}, }; @@ -1571,7 +1587,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_implicit_explicit_name_clash") TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_functions_work_in_subtyping") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); if (!FFlag::LuauSolverV2) return; @@ -1587,4 +1603,31 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "generic_type_functions_work_in_subtyping") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_type_subtyping_nested_bounds_with_new_mappings") +{ + // Test shows how going over mapped generics in a subtyping check can generate more mapped generics when making a subtyping check between bounds. + // It has previously caused iterator invalidation in the new solver, but this specific test doesn't trigger a UAF, only shows an example. + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( +type Dispatch = (A) -> () +type BasicStateAction = ((S) -> S) | S + +function updateReducer(reducer: (S, A) -> S, initialArg: I, init: ((I) -> S)?): (S, Dispatch) + return 1 :: any +end + +function basicStateReducer(state: S, action: BasicStateAction): S + return action +end + +function updateState(initialState: (() -> S) | S): (S, Dispatch>) + return updateReducer(basicStateReducer, initialState) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 50e28505..ca92083b 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -332,9 +332,9 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") { - ScopedFastFlag dcr{ - FFlag::LuauSolverV2, false - }; // CLI-116476 Subtyping between type alias and an equivalent but not named type isn't working. + // CLI-116476 Subtyping between type alias and an equivalent but not named type isn't working. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( type X = { x: (number) -> number } type Y = { y: (string) -> string } @@ -372,7 +372,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") { - ScopedFastFlag dcr{FFlag::LuauSolverV2, false}; // CLI- + DOES_NOT_PASS_NEW_SOLVER_GUARD(); // After normalization, previous 'table_intersection_write_sealed_indirect' is identical to this one CheckResult result = check(R"( type XY = { x: (number) -> number, y: (string) -> string } @@ -581,9 +581,9 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") { - ScopedFastFlag dcr{ - FFlag::LuauSolverV2, false - }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function f(x: ((number) -> number) & ((string) -> string)) local y : ((number | string) -> (number | string)) = x -- OK @@ -811,9 +811,9 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") { - ScopedFastFlag dcr{ - FFlag::LuauSolverV2, false - }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function f() function g(x : ((number) -> number) & ((nil) -> unknown)) @@ -833,9 +833,9 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") { - ScopedFastFlag dcr{ - FFlag::LuauSolverV2, false - }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function f() function g(x : ((number) -> number?) & ((unknown) -> string?)) @@ -939,9 +939,9 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") { - ScopedFastFlag dcr{ - FFlag::LuauSolverV2, false - }; // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + // CLI-116474 Semantic subtyping of assignments needs to decide how to interpret intersections of functions + DOES_NOT_PASS_NEW_SOLVER_GUARD(); + CheckResult result = check(R"( function f(x : ((string?) -> (string | number)) & ((number?) -> ...number)) local y : ((nil) -> (number, number?)) = x -- OK diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index de79654b..41753b66 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -15,9 +15,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauOkWithIteratingOverTableProperties) - -LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) TEST_SUITE_BEGIN("TypeInferLoops"); @@ -155,7 +152,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") { // CLI-116494 The generics K and V are leaking out of the next() function somehow. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local n @@ -195,10 +192,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") LUAU_REQUIRE_ERROR_COUNT(1, result); - if (DFFlag::LuauImproveNonFunctionCallError) - CHECK_EQ("Cannot call a value of type string", toString(result.errors[0])); - else - CHECK_EQ("Cannot call non-function string", toString(result.errors[0])); + CHECK_EQ("Cannot call a value of type string", toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_with_just_one_iterator_is_ok") @@ -276,7 +270,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") { // We report a spuriouus duplicate error here. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local bad_iter = 5 @@ -293,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") { // Spurious duplicate errors - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function hasDivisors(value: number, table) @@ -345,7 +339,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_t TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") { // CLI-116496 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function prime_iter(state, index) @@ -699,8 +693,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "loop_typecheck_crash_on_empty_optional") if (FFlag::LuauSolverV2) return; - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( local t = {} for _ in t do @@ -765,7 +757,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") { // CLI-116498 Sometimes you can iterate over tables with no indexers. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t: {string} = {} @@ -782,10 +774,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") { // CLI-116498 Sometimes you can iterate over tables with no indexers. - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - {FFlag::LuauOkWithIteratingOverTableProperties, true} - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t = {} @@ -937,8 +926,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_68448_iterators_need_not_accept_nil") TEST_CASE_FIXTURE(Fixture, "iterate_over_free_table") { - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( function print(x) end @@ -1093,9 +1080,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "dcr_iteration_on_never_gives_never") TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") { // CLI-116498 - Sometimes you can iterate over tables with no indexer. - ScopedFastFlag sff0{FFlag::LuauSolverV2, false}; - - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f() @@ -1118,8 +1103,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties") TEST_CASE_FIXTURE(BuiltinsFixture, "iterate_over_properties_nonstrict") { - ScopedFastFlag sff{FFlag::LuauOkWithIteratingOverTableProperties, true}; - CheckResult result = check(R"( --!nonstrict local function f() diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 42f1229f..ce4490fa 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -152,6 +152,43 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "require_a_variadic_function") CHECK(get(*iter.tail())); } +TEST_CASE_FIXTURE(BuiltinsFixture, "cross_module_table_freeze") +{ + fileResolver.source["game/A"] = R"( + --!strict + return { + a = 1, + } + )"; + + fileResolver.source["game/B"] = R"( + --!strict + return table.freeze(require(game.A)) + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr a = frontend.moduleResolver.getModule("game/A"); + REQUIRE(a != nullptr); + // confirm that no cross-module mutation happened here! + if (FFlag::LuauSolverV2) + CHECK(toString(a->returnType) == "{ a: number }"); + else + CHECK(toString(a->returnType) == "{| a: number |}"); + + ModulePtr b = frontend.moduleResolver.getModule("game/B"); + REQUIRE(b != nullptr); + // confirm that no cross-module mutation happened here! + if (FFlag::LuauSolverV2) + CHECK(toString(b->returnType) == "{ read a: number }"); + else + CHECK(toString(b->returnType) == "{| a: number |}"); +} + TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") { CheckResult result = check(R"( @@ -425,7 +462,12 @@ local b: B.T = a LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + { + CHECK( + toString(result.errors.at(0)) == + "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'; at [read \"x\"], number is not exactly string" + ); + } else { const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' @@ -466,7 +508,12 @@ local b: B.T = a LUAU_REQUIRE_ERROR_COUNT(1, result); if (FFlag::LuauSolverV2) - CHECK(toString(result.errors.at(0)) == "Type 'T' could not be converted into 'T'; at [read \"x\"], number is not exactly string"); + { + CHECK( + toString(result.errors.at(0)) == + "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'; at [read \"x\"], number is not exactly string" + ); + } else { const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' @@ -530,4 +577,209 @@ return l0 CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); } +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_scope_is_nullptr_after_shallow_copy") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + frontend.options.retainFullTypeGraphs = false; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from ReactTypes.lua +type CoreBinding = {} +type BindingMap = {} +export type Binding = CoreBinding & BindingMap + +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local Types = require(game.A) +type Binding = Types.Binding + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "ensure_free_variables_are_generialized_across_function_boundaries") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( +-- Roughly taken from react-shallow-renderer +function createUpdater(renderer) + local updater = { + _renderer = renderer, + } + + function updater.enqueueForceUpdate(publicInstance, callback, _callerName) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueReplaceState( + publicInstance, + completeState, + callback, + _callerName + ) + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + function updater.enqueueSetState(publicInstance, partialState, callback, _callerName) + local currentState = updater._renderer._newState or publicInstance.state + updater._renderer.render( + updater._renderer, + updater._renderer._element, + updater._renderer._context + ) + end + + return updater +end + +local ReactShallowRenderer = {} + +function ReactShallowRenderer:_reset() + self._updater = createUpdater(self) +end + +return ReactShallowRenderer + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( +local ReactShallowRenderer = require(game.A); + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "untitled_segfault_number_13") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( + -- minimized from roblox-requests/http/src/response.lua + local Response = {} + Response.__index = Response + function Response.new(content_type) + -- creates response object from original request and roblox http response + local self = setmetatable({}, Response) + self.content_type = content_type + return self + end + + function Response:xml(ignore_content_type) + if ignore_content_type or self.content_type:find("+xml") or self.content_type:find("/xml") then + else + end + end + + --------------- + + return Response + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local _ = require(game.A); + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "spooky_blocked_type_laundered_by_bound_type") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + fileResolver.source["game/A"] = R"( + local Cache = {} + + Cache.settings = {} + + Cache.data = {} + + function Cache.should_cache(url) + url = url:split("?")[1] + + for key, _ in pairs(Cache.settings) do + if url:match('') then + return key + end + end + + return "" + end + + function Cache.is_cached(url, req_id) + -- check local server cache first + + local setting_key = Cache.should_cache(url) + local settings = Cache.settings[setting_key] + + if not setting_key then + return false + end + + if Cache.data[req_id] ~= nil then + return true + end + + if Cache.settings[setting_key].cache_globally then + return false + else + return true + end + end + + function Cache.get_expire(url) + local setting_key = Cache.should_cache(url) + return Cache.settings[setting_key].expires or math.huge + end + + return Cache + )"; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local _ = require(game.A); + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "cycles_dont_make_everything_any") +{ + fileResolver.source["game/A"] = R"( + --!strict + local module = {} + + function module.foo() + return 2 + end + + function module.bar() + local m = require(game.B) + return m.foo() + 1 + end + + return module + )"; + + fileResolver.source["game/B"] = R"( + --!strict + local module = {} + + function module.foo() + return 2 + end + + function module.bar() + local m = require(game.A) + return m.foo() + 1 + end + + return module + )"; + + frontend.check("game/A"); + + CHECK("module" == toString(frontend.moduleResolver.getModule("game/B")->returnType)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 3ccc04c1..a2143b35 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -19,7 +19,7 @@ TEST_SUITE_BEGIN("TypeInferOOP"); TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") { // CLI-116571 method calls are missing arity checking? - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local someTable = {} @@ -37,7 +37,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defi TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") { // CLI-116571 method calls are missing arity checking? - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local someTable = {} diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d8bf9152..32338a68 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -17,7 +17,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauRemoveBadRelationalOperatorWarning) +LUAU_FASTFLAG(LuauDoNotGeneralizeInTypeFunctions) TEST_SUITE_BEGIN("TypeInferOperators"); @@ -631,7 +631,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") { // CLI-116463 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -801,7 +801,10 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") "Operator '+' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __add", toString(result.errors[0]) ); - CHECK_EQ("Operator '-' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __sub", toString(result.errors[1])); + CHECK_EQ( + "Operator '-' could not be applied to operands of types unknown and unknown; there is no corresponding overload for __sub", + toString(result.errors[1]) + ); } else { @@ -812,19 +815,19 @@ TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") TEST_CASE_FIXTURE(BuiltinsFixture, "and_binexps_dont_unify") { - CheckResult result = check(R"( - --!strict - local t = {} - while true and t[1] do - print(t[1].test) - end - )"); + ScopedFastFlag _{FFlag::LuauDoNotGeneralizeInTypeFunctions, true}; - // This infers a type for `t` of `{unknown}`, and so it makes sense that `t[1].test` would error. - if (FFlag::LuauSolverV2) - LUAU_REQUIRE_ERROR_COUNT(1, result); - else - LUAU_REQUIRE_NO_ERRORS(result); + // `t` will be inferred to be of type `{ { test: unknown } }` which is + // reasonable, in that it's empty with no bounds on its members. Optimally + // we might emit an error here that the `print(...)` expression is + // unreachable. + LUAU_REQUIRE_NO_ERRORS(check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )")); } TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") @@ -860,7 +863,7 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato )"); // If DCR is off and the flag to remove this check in the old solver is on, the expected behavior is no errors. - if (!FFlag::LuauSolverV2 && FFlag::LuauRemoveBadRelationalOperatorWarning) + if (!FFlag::LuauSolverV2) { LUAU_REQUIRE_NO_ERRORS(result); return; @@ -885,7 +888,7 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") { // There's an extra spurious warning here when the new solver is enabled. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1426,7 +1429,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array_simplified") TEST_CASE_FIXTURE(BuiltinsFixture, "luau_polyfill_is_array") { // CLI-116480 Subtyping bug: table should probably be a subtype of {[unknown]: unknown} - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1576,10 +1579,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compare_singleton_string_to_string") end )"); - if (FFlag::LuauRemoveBadRelationalOperatorWarning) - LUAU_REQUIRE_NO_ERRORS(result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + // There is a flag to gate turning this off, and this warning is not + // implemented in the new solver, so assert there are no errors. + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "no_infinite_expansion_of_free_type" * doctest::timeout(1.0)) @@ -1613,4 +1615,26 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "compound_operator_on_upvalue") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_operator_follow") +{ + CheckResult result = check(R"( +local t1 = {} +local t2 = {} +local mt = {} + +mt.__eq = function(a, b) + return false +end + +setmetatable(t1, mt) +setmetatable(t2, mt) + +if t1 == t2 then + +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index c3cce9df..2c76f123 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -120,4 +120,34 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "property_of_buffers") LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "properties_of_vectors") +{ + CheckResult result = check(R"( + local a = vector.create(1, 2, 3) + local b = vector.create(4, 5, 6) + + local t1 = { + a + b, + a - b, + a * 3, + a * b, + 3 * b, + a / 3, + a / b, + 3 / b, + a // 4, + a // b, + 4 // b, + -a, + } + local t2 = { + a.x, + a.y, + a.z, + } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 514c31c8..5cafedbf 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + #include "Luau/TypeInfer.h" #include "Luau/RecursionCounter.h" @@ -11,6 +12,8 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2); +LUAU_FASTFLAG(DebugLuauEqSatSimplification); +LUAU_FASTFLAG(LuauStoreCSTData); LUAU_FASTINT(LuauNormalizeCacheLimit); LUAU_FASTINT(LuauTarjanChildLimit); LUAU_FASTINT(LuauTypeInferIterationLimit); @@ -45,7 +48,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string expected = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean' then + local a1:boolean=a + elseif a.fn() then + local a2:{fn:()->(a,b...)}=a + end + end + )" + : R"( function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -55,7 +67,16 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expectedWithNewSolver = R"( + const std::string expectedWithNewSolver = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean' then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn() then + local a2:{fn:()->(unknown,...unknown)}&(class|function|nil|number|string|thread|buffer|table)=a + end + end + )" + : R"( function f(a:{fn:()->(unknown,...unknown)}): () if type(a) == 'boolean'then local a1:{fn:()->(unknown,...unknown)}&boolean=a @@ -65,15 +86,36 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - if (FFlag::LuauSolverV2) + const std::string expectedWithEqSat = FFlag::LuauStoreCSTData ? R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean' then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn() then + local a2:{fn:()->(unknown,...unknown)}&negate=a + end + end + )" + : R"( + function f(a:{fn:()->(unknown,...unknown)}): () + if type(a) == 'boolean'then + local a1:{fn:()->(unknown,...unknown)}&boolean=a + elseif a.fn()then + local a2:{fn:()->(unknown,...unknown)}&negate=a + end + end + )"; + + if (FFlag::LuauSolverV2 && !FFlag::DebugLuauEqSatSimplification) CHECK_EQ(expectedWithNewSolver, decorateWithTypes(code)); + else if (FFlag::LuauSolverV2 && FFlag::DebugLuauEqSatSimplification) + CHECK_EQ(expectedWithEqSat, decorateWithTypes(code)); else CHECK_EQ(expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(BuiltinsFixture, "luau-polyfill.Array.filter") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); // This test exercises the fact that we should reduce sealed/unsealed/free tables // res is a unsealed table with type {((T & ~nil)?) & any} @@ -172,7 +214,7 @@ TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") // For now, infer it as just a free table. TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_table") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local a = {} @@ -192,7 +234,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "setmetatable_constrains_free_type_into_free_ // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Node = { value: T, child: Node? } @@ -217,7 +259,7 @@ TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") // We should be type checking the metamethod at the call site of setmetatable. TEST_CASE_FIXTURE(BuiltinsFixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local tab = {a = 1} @@ -390,11 +432,9 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { - ScopedFastFlag sff[] = { - // I'm not sure why this is broken without DCR, but it seems to be fixed - // when DCR is enabled. - {FFlag::LuauSolverV2, false}, - }; + // I'm not sure why this is broken without DCR, but it seems to be fixed + // when DCR is enabled. + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f() return end @@ -524,10 +564,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId free1 = arena.freshType(builtinTypes, scope.get()); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId free2 = arena.freshType(builtinTypes, scope.get()); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; @@ -535,9 +575,6 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; - if (FFlag::LuauSolverV2) - u.enableNewSolver(); - u.tryUnify(option1, option2); CHECK(!u.failure); @@ -553,7 +590,7 @@ TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_zero_iterators") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function no_iter() end @@ -655,13 +692,15 @@ struct IsSubtypeFixture : Fixture { bool isSubtype(TypeId a, TypeId b) { + SimplifierPtr simplifier = newSimplifier(NotNull{&getMainModule()->internalTypes}, builtinTypes); + ModulePtr module = getMainModule(); REQUIRE(module); if (!module->hasModuleScope()) FAIL("isSubtype: module scope data is not available"); - return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, ice); + return ::Luau::isSubtype(a, b, NotNull{module->getModuleScope().get()}, builtinTypes, NotNull{simplifier.get()}, ice); } }; } // namespace @@ -847,7 +886,7 @@ Type 'number?' could not be converted into 'number' in an invariant context)"; TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function foo(t, x) @@ -921,7 +960,7 @@ TEST_CASE_FIXTURE(Fixture, "expected_type_should_be_a_helpful_deduction_guide_fo TEST_CASE_FIXTURE(Fixture, "floating_generics_should_not_be_allowed") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local assign : (target: T, source0: U?, source1: V?, source2: W?, ...any) -> T & U & V & W = (nil :: any) @@ -952,10 +991,10 @@ TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") std::unique_ptr scope = std::make_unique(builtinTypes->anyTypePack); - TypeId free1 = arena.addType(FreeType{scope.get()}); + TypeId free1 = arena.freshType(builtinTypes, scope.get()); TypeId option1 = arena.addType(UnionType{{nilType, free1}}); - TypeId free2 = arena.addType(FreeType{scope.get()}); + TypeId free2 = arena.freshType(builtinTypes, scope.get()); TypeId option2 = arena.addType(UnionType{{nilType, free2}}); InternalErrorReporter iceHandler; @@ -963,9 +1002,6 @@ TEST_CASE_FIXTURE(Fixture, "free_options_can_be_unified_together") Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; Unifier u{NotNull{&normalizer}, NotNull{scope.get()}, Location{}, Variance::Covariant}; - if (FFlag::LuauSolverV2) - u.enableNewSolver(); - u.tryUnify(option1, option2); CHECK(!u.failure); @@ -992,7 +1028,7 @@ TEST_CASE_FIXTURE(Fixture, "unify_more_complex_unions_that_include_nil") TEST_CASE_FIXTURE(Fixture, "optional_class_instances_are_invariant_old_solver") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); createSomeClasses(&frontend); @@ -1076,7 +1112,7 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion") { // The new solver doesn't recurse as heavily in this situation. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{FInt::LuauTypeInferRecursionLimit, 100}; @@ -1271,7 +1307,7 @@ TEST_CASE_FIXTURE(Fixture, "table_containing_non_final_type_is_erroneously_cache TableType* table = getMutable(tableTy); REQUIRE(table); - TypeId freeTy = arena.freshType(&globalScope); + TypeId freeTy = arena.freshType(builtinTypes, &globalScope); table->props["foo"] = Property::rw(freeTy); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 9c26f165..fa62fbea 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,57 +8,72 @@ #include "doctest.h" LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(DebugLuauEqSatSimplification) +LUAU_FASTFLAG(LuauGeneralizationRemoveRecursiveUpperBound2) +LUAU_FASTFLAG(LuauIntersectNotNil) +LUAU_FASTFLAG(LuauSkipNoRefineDuringRefinement) using namespace Luau; namespace { -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, - const ScopePtr& scope, - const AstExprCall& expr, - WithPredicate withPredicate -) + +struct MagicInstanceIsA final : MagicFunction { - if (expr.args.size != 1) - return std::nullopt; + std::optional> handleOldSolver( + TypeChecker& typeChecker, + const ScopePtr& scope, + const AstExprCall& expr, + WithPredicate withPredicate + ) override + { + if (expr.args.size != 1) + return std::nullopt; - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; - ModulePtr module = typeChecker.currentModule; - TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); - return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} + ModulePtr module = typeChecker.currentModule; + TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); + return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; + } -void dcrMagicRefinementInstanceIsA(const MagicRefinementContext& ctx) -{ - if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) - return; + bool infer(const MagicFunctionCallContext&) override + { + return false; + } - auto index = ctx.callSite->func->as(); - auto str = ctx.callSite->args.data[0]->as(); - if (!index || !str) - return; + void refine(const MagicRefinementContext& ctx) override + { + if (ctx.callSite->args.size != 1 || ctx.discriminantTypes.empty()) + return; - std::optional discriminantTy = ctx.discriminantTypes[0]; - if (!discriminantTy) - return; + auto index = ctx.callSite->func->as(); + auto str = ctx.callSite->args.data[0]->as(); + if (!index || !str) + return; + + std::optional discriminantTy = ctx.discriminantTypes[0]; + if (!discriminantTy) + return; + + std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); + if (!tfun) + return; + + LUAU_ASSERT(get(*discriminantTy)); + asMutable(*discriminantTy)->ty.emplace(tfun->type); + } +}; - std::optional tfun = ctx.scope->lookupType(std::string(str->value.data, str->value.size)); - if (!tfun) - return; - LUAU_ASSERT(get(*discriminantTy)); - asMutable(*discriminantTy)->ty.emplace(tfun->type); -} struct RefinementClassFixture : BuiltinsFixture { @@ -82,8 +97,7 @@ struct RefinementClassFixture : BuiltinsFixture TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypeId isA = arena.addType(FunctionType{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; + getMutable(isA)->magic = std::make_shared(); getMutable(inst)->props = { {"Name", Property{builtinTypes->stringType}}, @@ -448,10 +462,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "call_an_incompatible_function_after_using_ty LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK("Type 'string' could not be converted into 'number'" == toString(result.errors[0])); - CHECK(Location{{ 7, 18}, {7, 19}} == result.errors[0].location); + CHECK(Location{{7, 18}, {7, 19}} == result.errors[0].location); CHECK("Type 'string' could not be converted into 'number'" == toString(result.errors[1])); - CHECK(Location{{ 13, 18}, {13, 19}} == result.errors[1].location); + CHECK(Location{{13, 18}, {13, 19}} == result.errors[1].location); } TEST_CASE_FIXTURE(BuiltinsFixture, "impossible_type_narrow_is_not_an_error") @@ -488,8 +502,15 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") if (FFlag::LuauSolverV2) { - // CLI-115281 - Types produced by refinements don't always get simplified - CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); + if (FFlag::DebugLuauEqSatSimplification) + { + CHECK("{ x: number }" == toString(requireTypeAtPosition({4, 23}))); + } + else + { + // CLI-115281 - Types produced by refinements don't always get simplified + CHECK("{ x: number? } & { x: ~(false?) }" == toString(requireTypeAtPosition({4, 23}))); + } CHECK("number" == toString(requireTypeAtPosition({5, 26}))); } @@ -732,11 +753,15 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "nonoptional_type_can_narrow_to_nil_if_sense_ if (FFlag::LuauSolverV2) { // CLI-115281 Types produced by refinements do not consistently get simplified - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" - CHECK_EQ("(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ( + "(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({6, 24})) + ); // type(v) ~= "nil" - CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" - CHECK_EQ("(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" + CHECK_EQ( + "(boolean | buffer | class | function | number | string | table | thread) & string", toString(requireTypeAtPosition({12, 24})) + ); // equivalent to type(v) ~= "nil" } else { @@ -1375,7 +1400,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { // CLI-115286 - Refining via type(x) == 'vector' does not work in the new solver - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f(vec) @@ -1569,7 +1594,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "isa_type_refinement_must_be_known_ahe { // CLI-115087 - The new solver does not consistently combine tables with // class types when they appear in the upper bounds of a free type. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f(x): Instance @@ -1867,9 +1892,8 @@ TEST_CASE_FIXTURE(Fixture, "refine_a_property_of_some_global") if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(3, result); - - CHECK_EQ("*error-type* | buffer | class | function | number | string | table | thread | true", toString(requireTypeAtPosition({4, 30}))); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 30}))); } } @@ -2049,10 +2073,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refinements_should_preserve_error_suppressio end )"); - if (FFlag::LuauSolverV2) - LUAU_REQUIRE_NO_ERRORS(result); - else - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "many_refinements_on_val") @@ -2266,37 +2287,37 @@ TEST_CASE_FIXTURE(Fixture, "more_complex_long_disjunction_of_refinements_shouldn { CHECK_NOTHROW(check(R"( script:connect(function(obj) - if script.Parent.SeatNumber.Value == "1D" or - script.Parent.SeatNumber.Value == "2D" or - script.Parent.SeatNumber.Value == "3D" or - script.Parent.SeatNumber.Value == "4D" or - script.Parent.SeatNumber.Value == "5D" or - script.Parent.SeatNumber.Value == "6D" or - script.Parent.SeatNumber.Value == "7D" or - script.Parent.SeatNumber.Value == "8D" or - script.Parent.SeatNumber.Value == "9D" or - script.Parent.SeatNumber.Value == "10D" or - script.Parent.SeatNumber.Value == "11D" or - script.Parent.SeatNumber.Value == "12D" or - script.Parent.SeatNumber.Value == "13D" or - script.Parent.SeatNumber.Value == "14D" or - script.Parent.SeatNumber.Value == "15D" or - script.Parent.SeatNumber.Value == "16D" or - script.Parent.SeatNumber.Value == "1C" or - script.Parent.SeatNumber.Value == "2C" or - script.Parent.SeatNumber.Value == "3C" or - script.Parent.SeatNumber.Value == "4C" or - script.Parent.SeatNumber.Value == "5C" or - script.Parent.SeatNumber.Value == "6C" or - script.Parent.SeatNumber.Value == "7C" or - script.Parent.SeatNumber.Value == "8C" or - script.Parent.SeatNumber.Value == "9C" or - script.Parent.SeatNumber.Value == "10C" or - script.Parent.SeatNumber.Value == "11C" or - script.Parent.SeatNumber.Value == "12C" or - script.Parent.SeatNumber.Value == "13C" or - script.Parent.SeatNumber.Value == "14C" or - script.Parent.SeatNumber.Value == "15C" or + if script.Parent.SeatNumber.Value == "1D" or + script.Parent.SeatNumber.Value == "2D" or + script.Parent.SeatNumber.Value == "3D" or + script.Parent.SeatNumber.Value == "4D" or + script.Parent.SeatNumber.Value == "5D" or + script.Parent.SeatNumber.Value == "6D" or + script.Parent.SeatNumber.Value == "7D" or + script.Parent.SeatNumber.Value == "8D" or + script.Parent.SeatNumber.Value == "9D" or + script.Parent.SeatNumber.Value == "10D" or + script.Parent.SeatNumber.Value == "11D" or + script.Parent.SeatNumber.Value == "12D" or + script.Parent.SeatNumber.Value == "13D" or + script.Parent.SeatNumber.Value == "14D" or + script.Parent.SeatNumber.Value == "15D" or + script.Parent.SeatNumber.Value == "16D" or + script.Parent.SeatNumber.Value == "1C" or + script.Parent.SeatNumber.Value == "2C" or + script.Parent.SeatNumber.Value == "3C" or + script.Parent.SeatNumber.Value == "4C" or + script.Parent.SeatNumber.Value == "5C" or + script.Parent.SeatNumber.Value == "6C" or + script.Parent.SeatNumber.Value == "7C" or + script.Parent.SeatNumber.Value == "8C" or + script.Parent.SeatNumber.Value == "9C" or + script.Parent.SeatNumber.Value == "10C" or + script.Parent.SeatNumber.Value == "11C" or + script.Parent.SeatNumber.Value == "12C" or + script.Parent.SeatNumber.Value == "13C" or + script.Parent.SeatNumber.Value == "14C" or + script.Parent.SeatNumber.Value == "15C" or script.Parent.SeatNumber.Value == "16C" then end) )")); @@ -2324,4 +2345,185 @@ end) )")); } +TEST_CASE_FIXTURE(Fixture, "refinements_table_intersection_limits" * doctest::timeout(0.5)) +{ + CheckResult result = check(R"( +--!strict +type Dir = { + a: number?, b: number?, c: number?, d: number?, e: number?, f: number?, + g: number?, h: number?, i: number?, j: number?, k: number?, l: number?, + m: number?, n: number?, o: number?, p: number?, q: number?, r: number?, +} + +local function test(dirs: {Dir}) + for k, dir in dirs + local success, message = pcall(function() + assert(dir.a == nil or type(dir.a) == "number") + assert(dir.b == nil or type(dir.b) == "number") + assert(dir.c == nil or type(dir.c) == "number") + assert(dir.d == nil or type(dir.d) == "number") + assert(dir.e == nil or type(dir.e) == "number") + assert(dir.f == nil or type(dir.f) == "number") + assert(dir.g == nil or type(dir.g) == "number") + assert(dir.h == nil or type(dir.h) == "number") + assert(dir.i == nil or type(dir.i) == "number") + assert(dir.j == nil or type(dir.j) == "number") + assert(dir.k == nil or type(dir.k) == "number") + assert(dir.l == nil or type(dir.l) == "number") + assert(dir.m == nil or type(dir.m) == "number") + assert(dir.n == nil or type(dir.n) == "number") + assert(dir.o == nil or type(dir.o) == "number") + assert(dir.p == nil or type(dir.p) == "number") + assert(dir.q == nil or type(dir.q) == "number") + assert(dir.r == nil or type(dir.r) == "number") + assert(dir.t == nil or type(dir.t) == "number") + assert(dir.u == nil or type(dir.u) == "number") + assert(dir.v == nil or type(dir.v) == "number") + local checkpoint = dir + + checkpoint.w = 1 + end) + assert(success) + end +end + )"); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_refinement") +{ + CheckResult result = check(R"( + local function f(x: Instance | Vector3) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Instance", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_error") +{ + CheckResult result = check(R"( + local function f(x: Part) + if typeof(x) == "Instance" then + local foo : Folder = x + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeof_instance_isa_refinement") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) == "Instance" then + local foo = x + if foo:IsA("Folder") then + local bar = foo + end + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 32}))); + CHECK_EQ("string", toString(requireTypeAtPosition({8, 28}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "remove_recursive_upper_bound_when_generalizing") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::DebugLuauEqSatSimplification, true}, + {FFlag::LuauGeneralizationRemoveRecursiveUpperBound2, true}, + }; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local t = {"hello"} + local v = t[2] + if type(v) == "nil" then + local foo = v + end + )")); + + CHECK_EQ("(nil & string)?", toString(requireTypeAtPosition({4, 24}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "nonnil_refinement_on_generic") +{ + ScopedFastFlag sff{FFlag::LuauIntersectNotNil, true}; + + CheckResult result = check(R"( + local function printOptional(item: T?, printer: (T) -> string): string + if item ~= nil then + return printer(item) + else + return "" + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + CHECK_EQ("T & ~nil", toString(requireTypeAtPosition({3, 31}))); + else + CHECK_EQ("T", toString(requireTypeAtPosition({3, 31}))); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "truthy_refinement_on_generic") +{ + ScopedFastFlag sff{FFlag::LuauIntersectNotNil, true}; + + CheckResult result = check(R"( + local function printOptional(item: T?, printer: (T) -> string): string + if item then + return printer(item) + else + return "" + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::LuauSolverV2) + CHECK_EQ("T & ~(false?)", toString(requireTypeAtPosition({3, 31}))); + else + CHECK_EQ("T", toString(requireTypeAtPosition({3, 31}))); +} + +TEST_CASE_FIXTURE(Fixture, "truthy_call_of_function_with_table_value_as_argument_should_not_refine_value_as_never") +{ + ScopedFastFlag sff{FFlag::LuauSkipNoRefineDuringRefinement, true}; + + CheckResult result = check(R"( + type Item = {} + + local function predicate(value: Item): boolean + return true + end + + local function checkValue(value: Item) + if predicate(value) then + local _ = value + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("Item", toString(requireTypeAtPosition({8, 27}))); + CHECK_EQ("Item", toString(requireTypeAtPosition({9, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 43b1305e..3aa4efee 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -153,7 +153,7 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(g: ((true, string) -> ()) & ((false, number) -> ())) @@ -334,6 +334,27 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "indexer_can_be_union_of_singletons") +{ + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type Target = "A" | "B" + + type Test = {[Target]: number} + + local test: Test = {} + + test.A = 2 + test.C = 4 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(8 == result.errors[0].location.begin.line); +} + TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { CheckResult result = check(R"( @@ -442,7 +463,7 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function foo(f, x) @@ -462,7 +483,7 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function foo(f, x): "hello"? -- anyone there? diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 9d76e7bd..4b7ce57c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -18,9 +18,13 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauFixIndexerSubtypingOrdering) -LUAU_FASTFLAG(LuauAcceptIndexingTableUnionsIntersections) - -LUAU_DYNAMIC_FASTFLAG(LuauImproveNonFunctionCallError) +LUAU_FASTFLAG(LuauTrackInteriorFreeTypesOnScope) +LUAU_FASTFLAG(LuauTrackInteriorFreeTablesOnScope) +LUAU_FASTFLAG(LuauDontInPlaceMutateTableType) +LUAU_FASTFLAG(LuauAllowNonSharedTableTypesInLiteral) +LUAU_FASTFLAG(LuauFollowTableFreeze) +LUAU_FASTFLAG(LuauPrecalculateMutatedFreeTypes2) +LUAU_FASTFLAG(LuauDeferBidirectionalInferenceForTableAssignment) TEST_SUITE_BEGIN("TableTests"); @@ -318,7 +322,7 @@ TEST_CASE_FIXTURE(Fixture, "call_method_with_explicit_self_argument") TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") { // CLI-114792 Dot vs colon warnings aren't in the new solver yet. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local T = {} @@ -371,7 +375,7 @@ TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon_but_correctly") TEST_CASE_FIXTURE(Fixture, "used_colon_instead_of_dot") { // CLI-114792 Dot vs colon warnings aren't in the new solver yet. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local T = {} @@ -396,7 +400,7 @@ TEST_CASE_FIXTURE(Fixture, "used_colon_instead_of_dot") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { // CLI-114792 We don't report MissingProperties in many places where the old solver does. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local a = {} @@ -536,7 +540,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_width_subtyping_3") TEST_CASE_FIXTURE(Fixture, "table_unification_4") { // CLI-114134 - Use egraphs to simplify types better. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function foo(o) @@ -567,7 +571,7 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { // CLI-114872 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -586,7 +590,7 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignmen TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") { // CLI-114873 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -811,7 +815,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") { // This code is totally different in the new solver. We instead create a new type state for t2. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t1: { [string]: string } = {} @@ -893,7 +897,7 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function empty() return {} end @@ -930,7 +934,7 @@ TEST_CASE_FIXTURE(Fixture, "indexer_on_sealed_table_must_unify_with_free_table") // CLI-114134 What should be happening here is that the type of `t` should // be reduced from `{number} & {string}` to `never`, but that's not // happening. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function F(t): {number} @@ -993,7 +997,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_from_a_table_should_prefer_properti TEST_CASE_FIXTURE(Fixture, "any_when_indexing_into_an_unsealed_table_with_no_indexer_in_nonstrict_mode") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -1145,7 +1149,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_inferred") TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add_both_ways") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type VectorMt = { __add: (Vector, number) -> Vector } @@ -1523,7 +1527,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "found_multiple_like_keys") TEST_CASE_FIXTURE(BuiltinsFixture, "dont_suggest_exact_match_keys") { // CLI-114977 Unsealed table writes don't account for order properly - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t = {} @@ -1566,7 +1570,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "getmetatable_returns_pointer_to_metatable") TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_mismatch_should_fail") { // This test is invalid because we now create a new type state for t1 at the assignment. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t1 = {x = 1} @@ -1610,7 +1614,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "property_lookup_through_tabletypevar_metatab TEST_CASE_FIXTURE(BuiltinsFixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") { // This test is invalid because we now create a new type state for t at the assignment. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local t = {x = 1} @@ -1665,7 +1669,7 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key") TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") { // CLI-114792 We don't report MissingProperties - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(t: {}): { [string]: string, a: string } @@ -1926,18 +1930,124 @@ TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") TEST_CASE_FIXTURE(Fixture, "ok_to_set_nil_even_on_non_lvalue_base_expr") { - // CLI-100076 Assigning nil to an indexer should always succeed - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + ScopedFastFlag _{FFlag::LuauSolverV2, true}; - CheckResult result = check(R"( + LUAU_REQUIRE_NO_ERRORS(check(R"( local function f(): { [string]: number } return { ["foo"] = 1 } end f()["foo"] = nil + )")); + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function f( + t: {known_prop: boolean, [string]: number}, + key: string + ) + t[key] = nil + t["hello"] = nil + t.undefined = nil + end + )")); + + auto result = check(R"( + local function f(t: {known_prop: boolean, [string]: number, }) + t.known_prop = nil + end )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(Location{{2, 27}, {2, 30}}, result.errors[0].location); + CHECK_EQ("Type 'nil' could not be converted into 'boolean'", toString(result.errors[0])); + + loadDefinition(R"( + declare class FancyHashtable + [string]: number + real_property: string + end + )"); + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function removekey(fh: FancyHashtable, other_key: string) + fh["hmmm"] = nil + fh[other_key] = nil + fh.dne = nil + end + )")); + + result = check(R"( + local function removekey(fh: FancyHashtable) + fh.real_property = nil + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0].location, Location{{2, 31}, {2, 34}}); + CHECK_EQ(toString(result.errors[0]), "Type 'nil' could not be converted into 'string'"); +} + +TEST_CASE_FIXTURE(Fixture, "ok_to_set_nil_on_generic_map") +{ + LUAU_REQUIRE_NO_ERRORS(check(R"( + type MyMap = { [K]: V } + function set(m: MyMap, k: K, v: V) + m[k] = v + end + function unset(m: MyMap, k: K) + m[k] = nil + end + local m: MyMap = {} + set(m, "foo", true) + unset(m, "foo") + )")); +} + +TEST_CASE_FIXTURE(Fixture, "key_setting_inference_given_nil_upper_bound") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function setkey_object(t: { [string]: number }, v) + t.foo = v + t.foo = nil + end + local function setkey_constindex(t: { [string]: number }, v) + t["foo"] = v + t["foo"] = nil + end + local function setkey_unknown(t: { [string]: number }, k, v) + t[k] = v + t[k] = nil + end + )")); + CHECK_EQ(toString(requireType("setkey_object")), "({ [string]: number }, number) -> ()"); + CHECK_EQ(toString(requireType("setkey_constindex")), "({ [string]: number }, number) -> ()"); + CHECK_EQ(toString(requireType("setkey_unknown")), "({ [string]: number }, string, number) -> ()"); + + LUAU_REQUIRE_NO_ERRORS(check(R"( + local function on_number(v: number): () end + local function setkey_object(t: { [string]: number }, v) + t.foo = v + on_number(v) + end + )")); + CHECK_EQ(toString(requireType("setkey_object")), "({ [string]: number }, number) -> ()"); +} + +TEST_CASE_FIXTURE(Fixture, "explicit_nil_indexer") +{ + + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + auto result = check(R"( + local function _(t: { [string]: number? }): number + return t.hello + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0].location, Location{{2, 12}, {2, 26}}); + CHECK(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "ok_to_provide_a_subtype_during_construction") @@ -2096,7 +2206,7 @@ local Test: {Table} = { TEST_CASE_FIXTURE(Fixture, "common_table_element_general") { // CLI-115275 - Bidirectional inference does not always propagate indexer types into the expression - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Table = { @@ -2218,7 +2328,7 @@ foo({ TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call_tail") { // CLI-115239 - Bidirectional checking does not work for __call metamethods - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Foo = {x: number | string} @@ -2267,12 +2377,12 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(get(result.errors[0])); - CHECK(Location{{6, 45}, {6, 46}} == result.errors[0].location); + CHECK(get(result.errors[0])); - CHECK(get(result.errors[1])); + // This is not actually the expected behavior, but the typemismatch we were seeing before was for the wrong reason. + // The behavior of this test is just regressed generally in the new solver, and will need to be consciously addressed. } // TODO: test behavior is wrong with LuauInstantiateInSubtyping until we can re-enable the covariant requirement for instantiation in subtyping @@ -2406,7 +2516,7 @@ could not be converted into // // Second, nil <: unknown, so we consider that parameter to be optional. LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK("Type 'b1' could not be converted into 'a1'; at [read \"y\"], string is not exactly number" == toString(result.errors[0])); + CHECK("Type 'b1' could not be converted into 'a1'; at table()[read \"y\"], string is not exactly number" == toString(result.errors[0])); } else if (FFlag::LuauInstantiateInSubtyping) { @@ -2504,7 +2614,7 @@ Type 'number' could not be converted into 'string' in an invariant context)"; TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { // Table properties like HasSuper.p must be invariant. The new solver rightly rejects this program. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -2554,7 +2664,7 @@ Table type '{ x: number, y: number }' not compatible with type 'Super' because t TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { // CLI-114791 Bidirectional inference should be able to cause the inference engine to forget that a table literal has some property - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -2572,7 +2682,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") TEST_CASE_FIXTURE(BuiltinsFixture, "recursive_metatable_type_call") { // CLI-114782 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local b @@ -2582,10 +2692,7 @@ b() LUAU_REQUIRE_ERROR_COUNT(1, result); - if (DFFlag::LuauImproveNonFunctionCallError) - CHECK_EQ(toString(result.errors[0]), R"(Cannot call a value of type t1 where t1 = { @metatable { __call: t1 }, { } })"); - else - CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call a value of type t1 where t1 = { @metatable { __call: t1 }, { } })"); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") @@ -2653,12 +2760,15 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") { + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( local x: {number} | number | string local y = #x )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + // CLI-119936: This shouldn't double error but does under the new solver. + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(BuiltinsFixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") @@ -2826,26 +2936,25 @@ TEST_CASE_FIXTURE(Fixture, "table_length") TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") { - // CLI-100076 - Assigning a table key to `nil` in the presence of an indexer should always be permitted - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; - - CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); - LUAU_REQUIRE_ERROR_COUNT(0, result); + LUAU_REQUIRE_NO_ERRORS(check("local a = {} a[0] = 7 a[0] = nil")); } TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") { + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( local a = {} a[0] = 7 a[0] = 't' + a[0] = nil )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK(tm->wantedType == builtinTypes->numberType); + CHECK_EQ("number?", toString(tm->wantedType)); CHECK(tm->givenType == builtinTypes->stringType); } @@ -3263,10 +3372,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") if (FFlag::LuauSolverV2) { - if (DFFlag::LuauImproveNonFunctionCallError) - CHECK("Cannot call a value of type a" == toString(result.errors[0])); - else - CHECK("Cannot call non-function { @metatable { __call: number }, { } }" == toString(result.errors[0])); + CHECK("Cannot call a value of type a" == toString(result.errors[0])); } else { @@ -3274,7 +3380,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_must_be_callable") Location{{5, 20}, {5, 21}}, CannotCallNonFunction{builtinTypes->numberType}, }; - CHECK(result.errors[0] == e); } } @@ -3300,7 +3405,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_call_metamethod_generic") TEST_CASE_FIXTURE(BuiltinsFixture, "table_simple_call") { // The new solver can see that this function is safe to oversaturate. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local a = setmetatable({ x = 2 }, { @@ -3593,7 +3698,7 @@ local b = a.x TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shape_type") { // CLI-115087 The new solver cannot infer that a table-like type is actually string - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f(s) @@ -3682,7 +3787,7 @@ Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutel TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { // CLI-115087 The new solver cannot infer that a table-like type is actually string - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function f(s): string @@ -3697,6 +3802,11 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { + ScopedFastFlag sffs[] = { + {FFlag::LuauTrackInteriorFreeTypesOnScope, true}, + {FFlag::LuauTrackInteriorFreeTablesOnScope, true}, + }; + CheckResult result = check(R"( local function f(s): string local foo = s:absolutely_no_scalar_has_this_method() @@ -3706,17 +3816,14 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ if (FFlag::LuauSolverV2) { - LUAU_REQUIRE_ERROR_COUNT(4, result); + LUAU_REQUIRE_ERROR_COUNT(3, result); CHECK(toString(result.errors[0]) == "Parameter 's' has been reduced to never. This function is not callable with any possible value."); - // FIXME: These free types should have been generalized by now. CHECK( toString(result.errors[1]) == - "Parameter 's' is required to be a subtype of '{- read absolutely_no_scalar_has_this_method: ('a <: (never) -> ('b, c...)) -}' here." + "Parameter 's' is required to be a subtype of '{ read absolutely_no_scalar_has_this_method: (never) -> (unknown, ...unknown) }' here." ); CHECK(toString(result.errors[2]) == "Parameter 's' is required to be a subtype of 'string' here."); - CHECK(get(result.errors[3])); - CHECK_EQ("(never) -> string", toString(requireType("f"))); } else @@ -3737,7 +3844,7 @@ Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutel TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") { // We need egraphs to simplify the type of `out` here. CLI-114134 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function stringByteList(str) @@ -4235,9 +4342,7 @@ TEST_CASE_FIXTURE(Fixture, "identify_all_problematic_table_fields") TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type W = {read x: number} @@ -4374,7 +4479,7 @@ TEST_CASE_FIXTURE(Fixture, "write_annotations_are_unsupported_even_with_the_new_ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported") { - ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, false}}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type W = {read x: number} @@ -4398,7 +4503,7 @@ TEST_CASE_FIXTURE(Fixture, "read_and_write_only_table_properties_are_unsupported TEST_CASE_FIXTURE(Fixture, "read_ond_write_only_indexers_are_unsupported") { - ScopedFastFlag sff[] = {{FFlag::LuauSolverV2, false}}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type T = {read [string]: number} @@ -4798,8 +4903,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table") { - ScopedFastFlag sff{FFlag::LuauAcceptIndexingTableUnionsIntersections, true}; - CheckResult result = check(R"( local test = if true then { "meow", "woof" } else { 4, 81 } local test2 = test[1] @@ -4816,8 +4919,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table") TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table2") { - ScopedFastFlag sff{FFlag::LuauAcceptIndexingTableUnionsIntersections, true}; - CheckResult result = check(R"( local test = if true then {} else {} local test2 = test[1] @@ -4832,4 +4933,254 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexing_branching_table2") CHECK("any" == toString(requireType("test2"))); } +TEST_CASE_FIXTURE(BuiltinsFixture, "length_of_array_is_number") +{ + CheckResult result = check(R"( + local function TestFunc(ranges: {number}): number + if true then + ranges = {} :: {number} + end + local numRanges: number = #ranges + return numRanges + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "subtyping_with_a_metatable_table_path") +{ + // Builtin functions have to be setup for the new solver + if (!FFlag::LuauSolverV2) + return; + + CheckResult result = check(R"( + type self = {} & {} + type Class = typeof(setmetatable()) + local function _(): Class + return setmetatable({}::self, {}) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + "Type pack '{ @metatable { }, { } & { } }' could not be converted into 'Class'; at [0].metatable(), { } is not a subtype of nil\n" + "\ttype { @metatable { }, { } & { } }[0].table()[0] ({ }) is not a subtype of Class[0].table() (nil)\n" + "\ttype { @metatable { }, { } & { } }[0].table()[1] ({ }) is not a subtype of Class[0].table() (nil)", + toString(result.errors[0]) + ); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_union_type") +{ + + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + // This will have one (legitimate) error but previously would crash. + auto result = check(R"( + local function set(key, value) + local Message = {} + function Message.new(message) + local self = message or {} + setmetatable(self, Message) + return self + end + local self = Message.new(nil) + self[key] = value + end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + "Cannot add indexer to table '{ @metatable t1, (nil & ~(false?)) | { } } where t1 = { new: (a) -> { @metatable t1, (a & ~(false?)) | { " + "} } }'", + toString(result.errors[0]) + ); +} + +TEST_CASE_FIXTURE(Fixture, "function_check_constraint_too_eager") +{ + // NOTE: All of these examples should have no errors, but + // bidirectional inference is known to be broken. + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauPrecalculateMutatedFreeTypes2, true}, + }; + + auto result = check(R"( + local function doTheThing(_: { [string]: unknown }) end + doTheThing({ + ['foo'] = 5, + ['bar'] = 'heyo', + }) + )"); + LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_CHECK_NO_ERROR(result, ConstraintSolvingIncompleteError); + + LUAU_CHECK_ERROR_COUNT(1, check(R"( + type Input = { [string]: unknown } + + local i : Input = { + [('%s'):format('3.14')]=5, + ['stringField']='Heyo' + } + )")); + + // This example previously asserted due to eagerly mutating the underlying + // table type. + result = check(R"( + type Input = { [string]: unknown } + + local function doTheThing(_: Input) end + + doTheThing({ + [('%s'):format('3.14')]=5, + ['stringField']='Heyo' + }) + )"); + LUAU_CHECK_ERROR_COUNT(1, result); + LUAU_CHECK_NO_ERROR(result, ConstraintSolvingIncompleteError); +} + + +TEST_CASE_FIXTURE(BuiltinsFixture, "read_only_property_reads") +{ + ScopedFastFlag newSolver{FFlag::LuauSolverV2, true}; + + // none of the `t.id` accesses here should error + auto result = check(R"( + --!strict + type readonlyTable = {read id: number} + local t:readonlyTable = {id = 1} + + local _:{number} = {[t.id] = 1} + local _:{number} = {[t.id::number] = 1} + + local arr:{number} = {} + arr[t.id] = 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "multiple_fields_in_literal") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontInPlaceMutateTableType, true}, + }; + + auto result = check(R"( + type Foo = { + [string]: { + Min: number, + Max: number + } + } + local Foos: Foo = { + ["Foo"] = { + Min = -1, + Max = 1 + }, + ["Foo"] = { + Min = -1, + Max = 1 + } + } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "multiple_fields_from_fuzzer") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontInPlaceMutateTableType, true}, + {FFlag::LuauAllowNonSharedTableTypesInLiteral, true}, + }; + + // This would trigger an assert previously, so we really only care that + // there are errors (and there will be: lots of syntax errors). + LUAU_CHECK_ERRORS(check(R"( + function _(l0,l0) _(_,{n0=_,n0=_,},if l0:n0()[_] then _) + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "write_only_table_field_duplicate") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauDontInPlaceMutateTableType, true}, + {FFlag::LuauAllowNonSharedTableTypesInLiteral, true}, + }; + + auto result = check(R"( + type WriteOnlyTable = { write x: number } + local wo: WriteOnlyTable = { + x = 42, + x = 13, + } + )"); + + LUAU_CHECK_ERROR_COUNT(1, result); + CHECK_EQ("write keyword is illegal here", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "table_freeze_musnt_assert") +{ + ScopedFastFlag sffs[] = { + {FFlag::LuauSolverV2, true}, + {FFlag::LuauFollowTableFreeze, true}, + }; + + auto result = check(R"( + local m = {} + function m.foo() + local self = { entries = entries, _caches = {}} + local self = setmetatable(self, {}) + table.freeze(self) + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "optional_property_with_call") +{ + ScopedFastFlag _{FFlag::LuauDeferBidirectionalInferenceForTableAssignment, true}; + + LUAU_CHECK_NO_ERRORS(check(R"( + type t = { + key: boolean?, + time: number, + } + + local function num(): number + return 0 + end + + local _: t = { + time = num(), + } + )")); +} + +TEST_CASE_FIXTURE(Fixture, "empty_union_container_overflow") +{ + LUAU_REQUIRE_NO_ERRORS(check(R"( + --!strict + local CellRenderer = {} + function CellRenderer:init(props) + self._separators = { + unhighlight = function() + local cellKey, prevCellKey = self.props.cellKey, self.props.prevCellKey + self.props.onUpdateSeparators({ cellKey, prevCellKey }) + end, + updateProps = function (select, newProps) + local cellKey, prevCellKey = self.props.cellKey, self.props.prevCellKey + self.props.onUpdateSeparators({ if select == 'leading' then prevCellKey else cellKey }) + end + } + end + )")); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 09c6c05b..217391b8 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,13 +16,16 @@ #include -LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr); -LUAU_FASTFLAG(LuauSolverV2); -LUAU_FASTFLAG(LuauInstantiateInSubtyping); -LUAU_FASTINT(LuauCheckRecursionLimit); -LUAU_FASTINT(LuauNormalizeCacheLimit); -LUAU_FASTINT(LuauRecursionLimit); -LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) +LUAU_FASTFLAG(LuauSolverV2) +LUAU_FASTFLAG(LuauInstantiateInSubtyping) +LUAU_FASTINT(LuauCheckRecursionLimit) +LUAU_FASTINT(LuauNormalizeCacheLimit) +LUAU_FASTINT(LuauRecursionLimit) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauAstTypeGroup2) +LUAU_FASTFLAG(LuauNewNonStrictWarnOnUnknownGlobals) +LUAU_FASTFLAG(LuauInferLocalTypesInMultipleAssignments) using namespace Luau; @@ -145,9 +148,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") { - ScopedFastFlag sff[]{ - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nocheck @@ -224,7 +225,7 @@ TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local o @@ -265,7 +266,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "weird_case") TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -384,7 +385,7 @@ TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") // checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_count") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); #if defined(LUAU_ENABLE_ASAN) int limit = 250; @@ -442,7 +443,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") TEST_CASE_FIXTURE(Fixture, "globals") { // The new solver does not permit assignments to globals like this. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -456,7 +457,7 @@ TEST_CASE_FIXTURE(Fixture, "globals") TEST_CASE_FIXTURE(Fixture, "globals2") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!nonstrict @@ -506,7 +507,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CHECK_NOTHROW(check(R"( --!nonstrict @@ -600,7 +601,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_in_error") { { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -622,7 +623,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tc_after_error_recovery_no_replacement_name_ } { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -697,7 +698,7 @@ TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( foo @@ -708,7 +709,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstExprError") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local a = foo: @@ -819,7 +820,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "no_heap_use_after_free_error") end )"); - if (FFlag::LuauSolverV2) + if (FFlag::LuauSolverV2 && !FFlag::LuauNewNonStrictWarnOnUnknownGlobals) LUAU_REQUIRE_NO_ERRORS(result); else LUAU_REQUIRE_ERRORS(result); @@ -877,7 +878,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") CheckResult result = check(R"(local a = if true then "true" else "false")"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); + CHECK("string" == toString(aType)); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") @@ -888,7 +889,7 @@ local a = if false then "a" elseif false then "b" else "c" )"); LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveType::String); + CHECK("string" == toString(aType)); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") @@ -1099,7 +1100,7 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( --!strict @@ -1197,13 +1198,29 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_recursion_limit_normalizer") validateErrors(result.errors); REQUIRE_MESSAGE(!result.errors.empty(), getErrors(result)); - CHECK(1 == result.errors.size()); - if (FFlag::LuauSolverV2) - CHECK(Location{{3, 22}, {3, 42}} == result.errors[0].location); + { + CHECK(3 == result.errors.size()); + if (FFlag::LuauAstTypeGroup2) + CHECK(Location{{2, 22}, {2, 42}} == result.errors[0].location); + else + CHECK(Location{{2, 22}, {2, 41}} == result.errors[0].location); + CHECK(Location{{3, 22}, {3, 42}} == result.errors[1].location); + if (FFlag::LuauAstTypeGroup2) + CHECK(Location{{3, 22}, {3, 41}} == result.errors[2].location); + else + CHECK(Location{{3, 23}, {3, 40}} == result.errors[2].location); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[1])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[2])); + } else + { + CHECK(1 == result.errors.size()); + CHECK(Location{{3, 12}, {3, 46}} == result.errors[0].location); - CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + CHECK_EQ("Code is too complex to typecheck! Consider simplifying the code around this area", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") @@ -1222,7 +1239,7 @@ TEST_CASE_FIXTURE(Fixture, "type_infer_cache_limit_normalizer") TEST_CASE_FIXTURE(Fixture, "follow_on_new_types_in_substitution") { // CLI-114134 - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local obj = {} @@ -1563,7 +1580,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "lti_must_record_contributing_locations") */ TEST_CASE_FIXTURE(BuiltinsFixture, "be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local function concat(target: {T}, ...: {T} | T): {T} @@ -1683,4 +1700,112 @@ TEST_CASE_FIXTURE(Fixture, "leading_ampersand_no_type") CHECK("*error-type*" == toString(requireTypeAlias("Amp"))); } +TEST_CASE_FIXTURE(Fixture, "react_lua_follow_free_type_ub") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_REQUIRE_NO_ERRORS(check(R"( + return function(Roact) + local Tree = Roact.Component:extend("Tree") + + function Tree:render() + local breadth, components, depth, id, wrap = + self.props.breadth, self.props.components, self.props.depth, self.props.id, self.props.wrap + local Box = components.Box + if depth == 0 then + Roact.createElement(Box, {}) + else + Roact.createElement(Tree, {}) + end + + end + end + )")); +} + +TEST_CASE_FIXTURE(Fixture, "visit_error_nodes_in_lvalue") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + // This should always fail to parse, but shouldn't assert. Previously this + // would assert as we end up _roughly_ parsing this (with a lot of error + // nodes) as: + // + // do + // x :: T, y = z + // end + // + // We assume that `T` has some resolved type that is set up during + // constraint generation and resolved during constraint solving to + // be used during typechecking. We didn't descend into error nodes + // in lvalue positions. + LUAU_REQUIRE_ERRORS(check(R"( + --!strict + (::, + )")); +} + +TEST_CASE_FIXTURE(Fixture, "avoid_blocking_type_function") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!strict + local function foo(a : string?) + local b = a or "" + return b:upper() + end + )")); +} + +TEST_CASE_FIXTURE(Fixture, "avoid_double_reference_to_free_type") +{ + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + + LUAU_CHECK_NO_ERRORS(check(R"( + --!strict + local function wtf(name: string?) + local message + message = "invalid alternate fiber: " .. (name or "UNNAMED alternate") + end + )")); +} + +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_types_of_globals") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + + CheckResult result = check(R"( + --!strict + foo = 5 + print(foo) + )"); + + CHECK_EQ("number", toString(requireTypeAtPosition({3, 14}))); + + REQUIRE_EQ(1, result.errors.size()); + CHECK_EQ("Unknown global 'foo'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "multiple_assignment") +{ + ScopedFastFlag sff_LuauSolverV2{FFlag::LuauSolverV2, true}; + ScopedFastFlag sff_InferLocalTypesInMultipleAssignments{FFlag::LuauInferLocalTypesInMultipleAssignments, true}; + + CheckResult result = check(R"( + local function requireString(arg: string) end + local function requireNumber(arg: number) end + + local function f(): ...number end + + local w: "a", x, y, z = "a", 1, f() + requireString(w) + requireNumber(x) + requireNumber(y) + requireNumber(z) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 10ddd097..66bf034e 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -17,7 +17,7 @@ LUAU_FASTFLAG(LuauUnifierRecursionOnRestart); struct TryUnifyFixture : Fixture { // Cannot use `TryUnifyFixture` under DCR. - ScopedFastFlag noDcr{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; @@ -32,7 +32,7 @@ TEST_SUITE_BEGIN("TryUnifyTests"); TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") { Type numberOne{TypeVariant{PrimitiveType{PrimitiveType::Number}}}; - Type numberTwo = numberOne; + Type numberTwo = numberOne.clone(); state.tryUnify(&numberTwo, &numberOne); @@ -42,12 +42,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { - Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + Type functionOne{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; - Type functionTwo{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))} - }; + Type functionTwo{TypeVariant{FunctionType( + arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}) + )}}; state.tryUnify(&functionTwo, &functionOne); CHECK(!state.failure); @@ -60,17 +61,19 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { - TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionOne{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) + TypePackVar argPackOne{TypePack{{arena.freshType(builtinTypes, globalScope->level)}, std::nullopt}}; + Type functionOne{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->numberType})) }}; - Type functionOneSaved = functionOne; + Type functionOneSaved = functionOne.clone(); - TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; - Type functionTwo{TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) + TypePackVar argPackTwo{TypePack{{arena.freshType(builtinTypes, globalScope->level)}, std::nullopt}}; + Type functionTwo{TypeVariant{ + FunctionType(arena.addTypePack({arena.freshType(builtinTypes, globalScope->level)}), arena.addTypePack({builtinTypes->stringType})) }}; - Type functionTwoSaved = functionTwo; + Type functionTwoSaved = functionTwo.clone(); state.tryUnify(&functionTwo, &functionOne); CHECK(state.failure); @@ -83,11 +86,11 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") { Type tableOne{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + TableType{{{"foo", {arena.freshType(builtinTypes, globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; Type tableTwo{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + TableType{{{"foo", {arena.freshType(builtinTypes, globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; CHECK_NE(*getMutable(&tableOne)->props["foo"].type(), *getMutable(&tableTwo)->props["foo"].type()); @@ -106,7 +109,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { Type tableOne{TypeVariant{ TableType{ - {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, + {{"foo", {arena.freshType(builtinTypes, globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed @@ -115,7 +118,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") Type tableTwo{TypeVariant{ TableType{ - {{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, + {{"foo", {arena.freshType(builtinTypes, globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, std::nullopt, globalScope->level, TableState::Unsealed @@ -154,7 +157,7 @@ TEST_CASE_FIXTURE(Fixture, "uninhabited_intersection_sub_anything") TEST_CASE_FIXTURE(Fixture, "uninhabited_table_sub_never") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(arg : { prop : string & number }) : never @@ -166,7 +169,7 @@ TEST_CASE_FIXTURE(Fixture, "uninhabited_table_sub_never") TEST_CASE_FIXTURE(Fixture, "uninhabited_table_sub_anything") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(arg : { prop : string & number }) : boolean @@ -178,7 +181,7 @@ TEST_CASE_FIXTURE(Fixture, "uninhabited_table_sub_anything") TEST_CASE_FIXTURE(Fixture, "members_of_failed_typepack_unification_are_unified_with_errorType") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(arg: number) end @@ -195,7 +198,7 @@ TEST_CASE_FIXTURE(Fixture, "members_of_failed_typepack_unification_are_unified_w TEST_CASE_FIXTURE(Fixture, "result_of_failed_typepack_unification_is_constrained") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(arg: number) return arg end @@ -295,7 +298,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") { - Type redirect{FreeType{TypeLevel{}}}; + Type redirect{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; Type table{TableType{}}; Type metatable{MetatableType{&redirect, &table}}; redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type @@ -318,7 +321,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { - TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); + TypeId a = arena.freshType(builtinTypes, TypeLevel{}); TypeId b = builtinTypes->numberType; state.tryUnify(a, b); @@ -338,50 +341,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") CHECK_EQ(a->owningArena, &arena); } -TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table") -{ - TableType::Props freeProps{ - {"foo", {builtinTypes->numberType}}, - }; - - TypeId free = arena.addType(TableType{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); - - TableType::Props indexProps{ - {"foo", {builtinTypes->stringType}}, - }; - - TypeId index = arena.addType(TableType{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); - - TableType::Props mtProps{ - {"__index", {index}}, - }; - - TypeId mt = arena.addType(TableType{mtProps, std::nullopt, TypeLevel{}, TableState::Sealed}); - - TypeId target = arena.addType(TableType{TableState::Unsealed, TypeLevel{}}); - TypeId metatable = arena.addType(MetatableType{target, mt}); - - state.enableNewSolver(); - state.tryUnify(metatable, free); - state.log.commit(); - - REQUIRE_EQ(state.errors.size(), 1); - const std::string expected = R"(Type - '{ @metatable {| __index: {| foo: string |} |}, { } }' -could not be converted into - '{- foo: number -}' -caused by: - Type 'number' could not be converted into 'string')"; - CHECK_EQ(expected, toString(state.errors[0])); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") { TypePackVar variadicAny{VariadicTypePack{builtinTypes->anyType}}; TypePackVar packTmp{TypePack{{builtinTypes->anyType}, &variadicAny}}; TypePackVar packSub{TypePack{{builtinTypes->anyType, builtinTypes->anyType}, &packTmp}}; - Type freeTy{FreeType{TypeLevel{}}}; + Type freeTy{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; TypePackVar packSuper{TypePack{{&freeTy}, &freeTp}}; @@ -400,101 +366,6 @@ local l0:(any)&(typeof(_)),l0:(any)|(any) = _,_ LUAU_REQUIRE_ERRORS(result); } -static TypeId createTheType(TypeArena& arena, NotNull builtinTypes, Scope* scope, TypeId freeTy) -{ - /* - ({| - render: ( - (('a) -> ()) | {| current: 'a |} - ) -> nil - |}) -> () - */ - TypePackId emptyPack = arena.addTypePack({}); - - return arena.addType(FunctionType{ - arena.addTypePack({arena.addType(TableType{ - TableType::Props{ - {{"render", - Property(arena.addType(FunctionType{ - arena.addTypePack({arena.addType(UnionType{ - {arena.addType(FunctionType{arena.addTypePack({freeTy}), emptyPack}), - arena.addType(TableType{TableType::Props{{"current", {freeTy}}}, std::nullopt, TypeLevel{}, scope, TableState::Sealed})} - })}), - arena.addTypePack({builtinTypes->nilType}) - }))}} - }, - std::nullopt, - TypeLevel{}, - scope, - TableState::Sealed - })}), - emptyPack - }); -}; - -// See CLI-71190 -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_two_unions_under_dcr_does_not_create_a_BoundType_cycle") -{ - const std::shared_ptr scope = globalScope; - const std::shared_ptr nestedScope = std::make_shared(scope); - - const TypeId outerType = arena.freshType(scope.get()); - const TypeId outerType2 = arena.freshType(scope.get()); - - const TypeId innerType = arena.freshType(nestedScope.get()); - - state.enableNewSolver(); - - SUBCASE("equal_scopes") - { - TypeId one = createTheType(arena, builtinTypes, scope.get(), outerType); - TypeId two = createTheType(arena, builtinTypes, scope.get(), outerType2); - - state.tryUnify(one, two); - state.log.commit(); - - ToStringOptions opts; - - CHECK(follow(outerType) == follow(outerType2)); - } - - SUBCASE("outer_scope_is_subtype") - { - TypeId one = createTheType(arena, builtinTypes, scope.get(), outerType); - TypeId two = createTheType(arena, builtinTypes, scope.get(), innerType); - - state.tryUnify(one, two); - state.log.commit(); - - ToStringOptions opts; - - CHECK(follow(outerType) == follow(innerType)); - - // The scope of outerType exceeds that of innerType. The latter should be bound to the former. - const BoundType* bt = get_if(&innerType->ty); - REQUIRE(bt); - CHECK(bt->boundTo == outerType); - } - - SUBCASE("outer_scope_is_supertype") - { - TypeId one = createTheType(arena, builtinTypes, scope.get(), innerType); - TypeId two = createTheType(arena, builtinTypes, scope.get(), outerType); - - state.tryUnify(one, two); - state.log.commit(); - - ToStringOptions opts; - - CHECK(follow(outerType) == follow(innerType)); - - // The scope of outerType exceeds that of innerType. The latter should be bound to the former. - const BoundType* bt = get_if(&innerType->ty); - REQUIRE(bt); - CHECK(bt->boundTo == outerType); - } -} - TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_full_restart_recursion") { ScopedFastFlag luauUnifierRecursionOnRestart{FFlag::LuauUnifierRecursionOnRestart, true}; diff --git a/tests/TypeInfer.typePacks.test.cpp b/tests/TypeInfer.typePacks.test.cpp index 8b489c44..858c3052 100644 --- a/tests/TypeInfer.typePacks.test.cpp +++ b/tests/TypeInfer.typePacks.test.cpp @@ -787,7 +787,7 @@ local d: Y ()> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Y = { a: T } @@ -811,7 +811,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors2") TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors3") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Y = { a: (T) -> U... } @@ -824,7 +824,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors3") TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors4") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type Packed = (T) -> T @@ -953,11 +953,10 @@ a = b if (FFlag::LuauSolverV2) { - const std::string expected = - "Type\n" - " '() -> (number, ...boolean)'\n" - "could not be converted into\n" - " '() -> (number, ...string)'; at returns().tail().variadic(), boolean is not a subtype of string"; + const std::string expected = "Type\n" + " '() -> (number, ...boolean)'\n" + "could not be converted into\n" + " '() -> (number, ...string)'; at returns().tail().variadic(), boolean is not a subtype of string"; CHECK(expected == toString(result.errors[0])); } @@ -1065,7 +1064,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "detect_cyclic_typepacks2") TEST_CASE_FIXTURE(Fixture, "unify_variadic_tails_in_arguments") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function foo(...: string): number diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 6cdec4af..247894d1 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -9,7 +9,6 @@ using namespace Luau; LUAU_FASTFLAG(LuauSolverV2) -LUAU_FASTFLAG(LuauAcceptIndexingTableUnionsIntersections) TEST_SUITE_BEGIN("UnionTypes"); @@ -36,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint") { // CLI-114134 We need egraphs to consistently reduce the cyclic union // introduced by the increment here. - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local count = 0 @@ -122,7 +121,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_arguments") TEST_CASE_FIXTURE(Fixture, "optional_arguments_table") { // CLI-115588 - Bidirectional inference does not happen for assignments - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local a:{a:string, b:string?} @@ -419,6 +418,9 @@ TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors_2") TEST_CASE_FIXTURE(Fixture, "optional_length_error") { + + ScopedFastFlag _{FFlag::LuauSolverV2, true}; + CheckResult result = check(R"( type A = {number} function f(a: A?) @@ -426,8 +428,10 @@ TEST_CASE_FIXTURE(Fixture, "optional_length_error") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + // CLI-119936: This shouldn't double error but does under the new solver. + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Operator '#' could not be applied to operand of type A?; there is no corresponding overload for __len", toString(result.errors[0])); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[1])); } TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") @@ -473,7 +477,7 @@ end TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( local x = { x = 3 } @@ -617,7 +621,7 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypeId badCyclicUnionTy = arena.freshType(frontend.globals.globalScope.get()); + TypeId badCyclicUnionTy = arena.freshType(builtinTypes, frontend.globals.globalScope.get()); UnionType u; u.options.push_back(badCyclicUnionTy); @@ -638,16 +642,14 @@ TEST_CASE_FIXTURE(Fixture, "indexing_into_a_cyclic_union_doesnt_crash") )"); // this is a cyclic union of number arrays, so it _is_ a table, even if it's a nonsense type. - // no need to generate a NotATable error here. - if (FFlag::LuauAcceptIndexingTableUnionsIntersections) - LUAU_REQUIRE_NO_ERRORS(result); - else - LUAU_REQUIRE_ERROR_COUNT(1, result); + // no need to generate a NotATable error here. The new solver automatically handles this and + // correctly reports no errors. + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } @@ -723,7 +725,7 @@ TEST_CASE_FIXTURE(Fixture, "union_of_generic_typepack_functions") TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f() @@ -743,7 +745,7 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generics") TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f() @@ -764,7 +766,7 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(x : (number) -> number?) @@ -783,7 +785,7 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(x : () -> (number | string)) @@ -802,7 +804,7 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(x : (...nil) -> (...number?)) @@ -848,7 +850,7 @@ could not be converted into TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(x : () -> (number?, ...number)) @@ -879,7 +881,9 @@ TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("(({ read x: unknown } & { x: number }) | ({ read x: unknown } & { x: string })) -> { x: number } | { x: string }", toString(requireType("f"))); + CHECK_EQ( + "(({ read x: unknown } & { x: number }) | ({ read x: unknown } & { x: string })) -> { x: number } | { x: string }", toString(requireType("f")) + ); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types_2") @@ -916,7 +920,7 @@ TEST_CASE_FIXTURE(Fixture, "union_table_any_property") TEST_CASE_FIXTURE(Fixture, "union_function_any_args") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(sup : ((...any) -> (...any))?, sub : ((number) -> (...any))) @@ -940,7 +944,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_any") TEST_CASE_FIXTURE(Fixture, "generic_function_with_optional_arg") { - ScopedFastFlag sff{FFlag::LuauSolverV2, false}; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CheckResult result = check(R"( function f(x : T?) : {T} diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 7d8ed38f..85425f77 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -199,14 +199,14 @@ TEST_CASE_FIXTURE(TypePackFixture, "std_distance") TEST_CASE("content_reassignment") { - TypePackVar myError{Unifiable::Error{}, /*presistent*/ true}; + TypePackVar myError{ErrorTypePack{}, /*presistent*/ true}; TypeArena arena; TypePackId futureError = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); asMutable(futureError)->reassign(myError); - CHECK(get(futureError) != nullptr); + CHECK(get(futureError) != nullptr); CHECK(!futureError->persistent); CHECK(futureError->owningArena == &arena); } diff --git a/tests/TypePath.test.cpp b/tests/TypePath.test.cpp index 2481f27a..b281dcab 100644 --- a/tests/TypePath.test.cpp +++ b/tests/TypePath.test.cpp @@ -17,6 +17,7 @@ using namespace Luau::TypePath; LUAU_FASTFLAG(LuauSolverV2); LUAU_DYNAMIC_FASTINT(LuauTypePathMaximumTraverseSteps); +LUAU_FASTFLAG(LuauFreeTypesMustHaveBounds); struct TypePathFixture : Fixture { @@ -235,6 +236,23 @@ TEST_CASE_FIXTURE(ClassFixture, "metatables") } SUBCASE("table") + { + TYPESOLVE_CODE(R"( + type Table = { foo: number } + type Metatable = { bar: number } + local tbl: Table = { foo = 123 } + local mt: Metatable = { bar = 456 } + local res = setmetatable(tbl, mt) + )"); + + // Tricky test setup because 'setmetatable' mutates the argument 'tbl' type + auto result = traverseForType(requireType("res"), Path(TypeField::Table), builtinTypes); + auto expected = lookupType("Table"); + REQUIRE(expected); + CHECK(result == follow(*expected)); + } + + SUBCASE("metatable") { TYPESOLVE_CODE(R"( local mt = { foo = 123 } @@ -260,7 +278,7 @@ TEST_CASE_FIXTURE(TypePathFixture, "bounds") TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypeId ty = arena.freshType(frontend.globals.globalScope.get()); + TypeId ty = arena.freshType(frontend.builtinTypes, frontend.globals.globalScope.get()); FreeType* ft = getMutable(ty); SUBCASE("upper") @@ -521,9 +539,7 @@ TEST_SUITE_BEGIN("TypePathToString"); TEST_CASE("field") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); CHECK(toString(PathBuilder().prop("foo").build()) == R"(["foo"])"); } @@ -550,9 +566,7 @@ TEST_CASE("empty_path") TEST_CASE("prop") { - ScopedFastFlag sff[] = { - {FFlag::LuauSolverV2, false}, - }; + DOES_NOT_PASS_NEW_SOLVER_GUARD(); Path p = PathBuilder().prop("foo").build(); CHECK(p == Path(TypePath::Property{"foo"})); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 9e21b1e0..1e5fdaf1 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -219,7 +219,7 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_only_cyclic_union") */ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { - Type ftv11{FreeType{TypeLevel{}}}; + Type ftv11{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; TypePackVar tp24{TypePack{{&ftv11}}}; TypePackVar tp17{TypePack{}}; @@ -469,8 +469,8 @@ TEST_CASE("content_reassignment") myAny.documentationSymbol = "@global/any"; TypeArena arena; - - TypeId futureAny = arena.addType(FreeType{TypeLevel{}}); + BuiltinTypes builtinTypes; + TypeId futureAny = arena.freshType(NotNull{&builtinTypes}, TypeLevel{}); asMutable(futureAny)->reassign(myAny); CHECK(get(futureAny) != nullptr); diff --git a/tests/VisitType.test.cpp b/tests/VisitType.test.cpp index 186afaa5..86063ae8 100644 --- a/tests/VisitType.test.cpp +++ b/tests/VisitType.test.cpp @@ -4,6 +4,7 @@ #include "Luau/RecursionCounter.h" +#include "Luau/Type.h" #include "doctest.h" using namespace Luau; @@ -54,7 +55,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_throw_when_limit_is_high_enough") TEST_CASE_FIXTURE(Fixture, "some_free_types_do_not_have_bounds") { - Type t{FreeType{TypeLevel{}}}; + Type t{FreeType{TypeLevel{}, builtinTypes->neverType, builtinTypes->unknownType}}; (void)toString(&t); } diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.luau similarity index 100% rename from tests/conformance/apicalls.lua rename to tests/conformance/apicalls.luau diff --git a/tests/conformance/assert.lua b/tests/conformance/assert.luau similarity index 100% rename from tests/conformance/assert.lua rename to tests/conformance/assert.luau diff --git a/tests/conformance/attrib.lua b/tests/conformance/attrib.luau similarity index 100% rename from tests/conformance/attrib.lua rename to tests/conformance/attrib.luau diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.luau similarity index 98% rename from tests/conformance/basic.lua rename to tests/conformance/basic.luau index 98f8000e..7f382485 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.luau @@ -36,7 +36,7 @@ assert(foo(1, 2, 3) == 2) assert(concat(pcall(function () end)) == "true") assert(concat(pcall(function () return nil end)) == "true,nil") assert(concat(pcall(function () return 1,2,3 end)) == "true,1,2,3") -assert(concat(pcall(function () error("oops") end)) == "false,basic.lua:39: oops") +assert(concat(pcall(function () error("oops") end)) == "false,basic.luau:39: oops") -- assignments assert((function() local a = 1 a = 2 return a end)() == 2) @@ -92,6 +92,16 @@ assert((function() local a = 1 a = a - 2 return a end)() == -1) assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) +-- binary ops with fp specials, neg zero, large constants +-- argument is passed into anonymous function to prevent constant folding +assert((function(a) return tostring(a + 0) end)(-0) == "0") +assert((function(a) return tostring(a - 0) end)(-0) == "-0") +assert((function(a) return tostring(0 - a) end)(0) == "0") +assert((function(a) return tostring(a - a) end)(1 / 0) == "nan") +assert((function(a) return tostring(a * 0) end)(0 / 0) == "nan") +assert((function(a) return tostring(a / (2^1000)) end)(2^1000) == "1") +assert((function(a) return tostring(a / (2^-1000)) end)(2^-1000) == "1") + -- floor division should always round towards -Infinity assert((function() local a = 1 a = a // 2 return a end)() == 0) assert((function() local a = 3 a = a // 2 return a end)() == 1) @@ -290,7 +300,7 @@ assert((function() local t = {[1] = 1, [2] = 2} return t[1] + t[2] end)() == 3) assert((function() return table.concat({}, ',') end)() == "") assert((function() return table.concat({1}, ',') end)() == "1") assert((function() return table.concat({1,2}, ',') end)() == "1,2") -assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == +assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17") @@ -770,7 +780,7 @@ assert(tostring(0) == "0") assert(tostring(-0) == "-0") -- test newline handling in long strings -assert((function() +assert((function() local s1 = [[ ]] local s2 = [[ diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.luau similarity index 100% rename from tests/conformance/bitwise.lua rename to tests/conformance/bitwise.luau diff --git a/tests/conformance/buffers.lua b/tests/conformance/buffers.luau similarity index 87% rename from tests/conformance/buffers.lua rename to tests/conformance/buffers.luau index 5da2a688..370fb8a8 100644 --- a/tests/conformance/buffers.lua +++ b/tests/conformance/buffers.luau @@ -599,6 +599,90 @@ end misc(table.create(16, 0)) +local function bitops(size, base) + local b = buffer.create(size) + + buffer.writeu32(b, base / 8, 0x12345678) + + assert(buffer.readbits(b, base, 8) == buffer.readu8(b, base / 8)) + assert(buffer.readbits(b, base, 16) == buffer.readu16(b, base / 8)) + assert(buffer.readbits(b, base, 32) == buffer.readu32(b, base / 8)) + + buffer.writebits(b, base, 32, 0) + + buffer.writebits(b, base, 1, 1) + assert(buffer.readi8(b, base / 8) == 1) + + buffer.writebits(b, base + 1, 1, 1) + assert(buffer.readi8(b, base / 8) == 3) + + -- construct 00000010 00000000_01000000_00010000_00001000 00001000_00010000_01000010_00100101 + buffer.writebits(b, base + 0, 1, 0b1) + buffer.writebits(b, base + 1, 2, 0b10) + buffer.writebits(b, base + 3, 3, 0b100) + buffer.writebits(b, base + 6, 4, 0b1000) + buffer.writebits(b, base + 10, 5, 0b10000) + buffer.writebits(b, base + 15, 6, 0b100000) + buffer.writebits(b, base + 21, 7, 0b1000000) + buffer.writebits(b, base + 28, 8, 0b10000000) + buffer.writebits(b, base + 36, 9, 0b100000000) + buffer.writebits(b, base + 45, 10, 0b1000000000) + buffer.writebits(b, base + 55, 11, 0b10000000000) + + assert(buffer.readbits(b, base + 0, 32) == 0b00001000_00010000_01000010_00100101) + assert(buffer.readbits(b, base + 32, 32) == 0b00000000_01000000_00010000_00001000) + + assert(buffer.readu32(b, base / 8 + 0) == 0b00001000_00010000_01000010_00100101) + assert(buffer.readu32(b, base / 8 + 4) == 0b00000000_01000000_00010000_00001000) + + -- slide the window to touch 5 bytes + assert(buffer.readbits(b, base + 1, 32) == 0b00000100000010000010000100010010) + assert(buffer.readbits(b, base + 2, 32) == 0b00000010000001000001000010001001) + assert(buffer.readbits(b, base + 3, 32) == 0b00000001000000100000100001000100) + assert(buffer.readbits(b, base + 4, 32) == 0b10000000100000010000010000100010) + assert(buffer.readbits(b, base + 5, 32) == 0b01000000010000001000001000010001) + assert(buffer.readbits(b, base + 6, 32) == 0b00100000001000000100000100001000) + assert(buffer.readbits(b, base + 7, 32) == 0b00010000000100000010000010000100) + assert(buffer.readbits(b, base + 8, 32) == 0b00001000000010000001000001000010) + + assert(buffer.readbits(b, base + 1, 15) == 0b010000100010010) + assert(buffer.readbits(b, base + 2, 15) == 0b001000010001001) + assert(buffer.readbits(b, base + 3, 15) == 0b000100001000100) + assert(buffer.readbits(b, base + 4, 15) == 0b000010000100010) + assert(buffer.readbits(b, base + 5, 15) == 0b000001000010001) + assert(buffer.readbits(b, base + 6, 15) == 0b100000100001000) + assert(buffer.readbits(b, base + 7, 15) == 0b010000010000100) + assert(buffer.readbits(b, base + 8, 15) == 0b001000001000010) + + -- zero bit + buffer.writebits(b, base, 0, 0b1) + assert(buffer.readbits(b, base, 32) == 0b00001000_00010000_01000010_00100101) + assert(buffer.readbits(b, base, 0) == 0) + assert(buffer.readbits(b, size * 8, 0) == 0) + + -- bounds + assert(ecall(function() buffer.readbits(b, -1, 0) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readbits(b, size * 8, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readbits(b, size * 8 - 1, 2) end) == "buffer access out of bounds") + assert(ecall(function() buffer.readbits(b, 0, 64) end) == "bit count is out of range of [0; 32]") + + assert(ecall(function() buffer.writebits(b, -1, 0, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writebits(b, size * 8, 1, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writebits(b, size * 8 - 1, 2, 1) end) == "buffer access out of bounds") + assert(ecall(function() buffer.writebits(b, 0, 64, 1) end) == "bit count is out of range of [0; 32]") + + + return b +end + +do + bitops(16, 0) + bitops(17, 8) + + -- a very large buffer and bit offsets can now be over 32 bits + bitops(1024 * 1024 * 1024, 6 * 1024 * 1024 * 1024) +end + local function testslowcalls() getfenv() @@ -619,6 +703,7 @@ local function testslowcalls() fromtostring() fill() misc(table.create(16, 0)) + bitops(16, 0) end testslowcalls() diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.luau similarity index 94% rename from tests/conformance/calls.lua rename to tests/conformance/calls.luau index 621a921a..63ad81e1 100644 --- a/tests/conformance/calls.lua +++ b/tests/conformance/calls.luau @@ -236,4 +236,12 @@ if not limitedstack then assert(not err and string.find(msg, "error")) end +-- testing deep nested calls with a large thread stack +if not limitedstack then + function recurse(n, ...) return n <= 1 and (1 + #{...}) or recurse(n-1, table.unpack(table.create(4000, 1))) + 1 end + + local ok, msg = pcall(recurse, 19000) + assert(not ok and string.find(msg, "not enough memory")) +end + return('OK') diff --git a/tests/conformance/clear.lua b/tests/conformance/clear.luau similarity index 100% rename from tests/conformance/clear.lua rename to tests/conformance/clear.luau diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.luau similarity index 99% rename from tests/conformance/closure.lua rename to tests/conformance/closure.luau index 10dc322f..fba65706 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.luau @@ -284,7 +284,7 @@ function foo () error("foo") end -local fooerr = "closure.lua:284: foo" +local fooerr = "closure.luau:284: foo" function goo() foo() end x = coroutine.wrap(goo) diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.luau similarity index 100% rename from tests/conformance/constructs.lua rename to tests/conformance/constructs.luau diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.luau similarity index 100% rename from tests/conformance/coroutine.lua rename to tests/conformance/coroutine.luau diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.luau similarity index 100% rename from tests/conformance/coverage.lua rename to tests/conformance/coverage.luau diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.luau similarity index 100% rename from tests/conformance/datetime.lua rename to tests/conformance/datetime.luau diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.luau similarity index 88% rename from tests/conformance/debug.lua rename to tests/conformance/debug.luau index e044ea45..89d43480 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.luau @@ -35,7 +35,7 @@ end local co2 = coroutine.create(halp) coroutine.resume(co2, 0 / 0, 42) -assert(debug.traceback(co2) == "debug.lua:31 function halp\n") +assert(debug.traceback(co2) == "debug.luau:31 function halp\n") assert(debug.info(co2, 0, "l") == 31) assert(debug.info(co2, 0, "f") == halp) @@ -64,7 +64,7 @@ assert(baz(1, "n") == "baz") assert(baz(2, "n") == "") -- main/anonymous assert(baz(3, "n") == nil) assert(baz(0, "s") == "[C]") -assert(baz(1, "s") == "debug.lua") +assert(baz(1, "s") == "debug.luau") assert(baz(0, "l") == -1) assert(baz(1, "l") > 42) assert(baz(0, "f") == debug.info) @@ -87,7 +87,7 @@ end assert(#(quux(1, "nlsf")) == 4) assert(quux(1, "nlsf")[1] == "quux") assert(quux(1, "nlsf")[2] > 64) -assert(quux(1, "nlsf")[3] == "debug.lua") +assert(quux(1, "nlsf")[3] == "debug.luau") assert(quux(1, "nlsf")[4] == quux) -- info arity @@ -138,4 +138,17 @@ end) coroutine.resume(wrapped2) +local wrapped3 = coroutine.create(function() + local thread = coroutine.create(function(target) + for i = 1, 100 do pcall(debug.info, target, 0, "?f") end + return 123 + end) + + local success, res = coroutine.resume(thread, coroutine.running()) + assert(success) + assert(res == 123) +end) + +coroutine.resume(wrapped3) + return 'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.luau similarity index 100% rename from tests/conformance/debugger.lua rename to tests/conformance/debugger.luau diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.luau similarity index 100% rename from tests/conformance/errors.lua rename to tests/conformance/errors.luau diff --git a/tests/conformance/events.lua b/tests/conformance/events.luau similarity index 100% rename from tests/conformance/events.lua rename to tests/conformance/events.luau diff --git a/tests/conformance/exceptions.lua b/tests/conformance/exceptions.luau similarity index 100% rename from tests/conformance/exceptions.lua rename to tests/conformance/exceptions.luau diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.luau similarity index 100% rename from tests/conformance/gc.lua rename to tests/conformance/gc.luau diff --git a/tests/conformance/ifelseexpr.lua b/tests/conformance/ifelseexpr.luau similarity index 100% rename from tests/conformance/ifelseexpr.lua rename to tests/conformance/ifelseexpr.luau diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.luau similarity index 100% rename from tests/conformance/interrupt.lua rename to tests/conformance/interrupt.luau diff --git a/tests/conformance/iter.lua b/tests/conformance/iter.luau similarity index 100% rename from tests/conformance/iter.lua rename to tests/conformance/iter.luau diff --git a/tests/conformance/literals.lua b/tests/conformance/literals.luau similarity index 100% rename from tests/conformance/literals.lua rename to tests/conformance/literals.luau diff --git a/tests/conformance/locals.lua b/tests/conformance/locals.luau similarity index 100% rename from tests/conformance/locals.lua rename to tests/conformance/locals.luau diff --git a/tests/conformance/math.lua b/tests/conformance/math.luau similarity index 91% rename from tests/conformance/math.lua rename to tests/conformance/math.luau index 98d5b317..586023ed 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.luau @@ -388,6 +388,37 @@ assert(math.pow(noinline(2), 2) == 4) assert(math.pow(noinline(4), 0.5) == 2) assert(math.pow(noinline(-2), 2) == 4) +-- map +assert(math.map(0, -1, 1, 0, 2) == 1) +assert(math.map(1, 1, 4, 0, 2) == 0) +assert(math.map(2.5, 1, 4, 0, 2) == 1) +assert(math.map(4, 1, 4, 0, 2) == 2) +assert(math.map(1, 1, 4, 2, 0) == 2) +assert(math.map(2.5, 1, 4, 2, 0) == 1) +assert(math.map(4, 1, 4, 2, 0) == 0) +assert(math.map(1, 4, 1, 2, 0) == 0) +assert(math.map(2.5, 4, 1, 2, 0) == 1) +assert(math.map(4, 4, 1, 2, 0) == 2) +assert(math.map(-8, 0, 4, 0, 2) == -4) +assert(math.map(16, 0, 4, 0, 2) == 8) + +-- lerp basics +assert(math.lerp(1, 5, 0) == 1) +assert(math.lerp(1, 5, 1) == 5) +assert(math.lerp(1, 5, 0.5) == 3) +assert(math.lerp(1, 5, 1.5) == 7) +assert(math.lerp(1, 5, -0.5) == -1) +assert(math.lerp(1, 5, noinline(0.5)) == 3) + +-- lerp properties +local sq2, sq3 = math.sqrt(2), math.sqrt(3) +assert(math.lerp(sq2, sq3, 0) == sq2) -- exact at 0 +assert(math.lerp(sq2, sq3, 1) == sq3) -- exact at 1 +assert(math.lerp(-sq3, sq2, 1) == sq2) -- exact at 1 (fails for a + t*(b-a)) +assert(math.lerp(sq2, sq2, sq2 / 2) <= math.lerp(sq2, sq2, 1)) -- monotonic (fails for a*t + b*(1-t)) +assert(math.lerp(-sq3, sq2, 1) <= math.sqrt(2)) -- bounded (fails for a + t*(b-a)) +assert(math.lerp(sq2, sq2, sq2 / 2) == sq2) -- consistent (fails for a*t + b*(1-t)) + assert(tostring(math.pow(-2, 0.5)) == "nan") -- test that fastcalls return correct number of results @@ -450,5 +481,6 @@ assert(math.sign("2") == 1) assert(math.sign("-2") == -1) assert(math.sign("0") == 0) assert(math.round("1.8") == 2) +assert(math.lerp("1", "5", 0.5) == 3) return('OK') diff --git a/tests/conformance/move.lua b/tests/conformance/move.luau similarity index 100% rename from tests/conformance/move.lua rename to tests/conformance/move.luau diff --git a/tests/conformance/native.lua b/tests/conformance/native.luau similarity index 85% rename from tests/conformance/native.lua rename to tests/conformance/native.luau index 03845013..16172bab 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.luau @@ -513,4 +513,68 @@ end assert(extramath3(2) == "number") assert(extramath3("2") == "number") +local function slotcachelimit1() + local tbl = { + f1 = function() return 1 end, + f2 = function() return 2 end, + f3 = function() return 3 end, + f4 = function() return 4 end, + f5 = function() return 5 end, + f6 = function() return 6 end, + f7 = function() return 7 end, + f8 = function() return 8 end, + f9 = function() return 9 end, + f10 = function() return 10 end, + f11 = function() return 11 end, + f12 = function() return 12 end, + f13 = function() return 13 end, + f14 = function() return 14 end, + f15 = function() return 15 end, + f16 = function() return 16 end, + } + + local lookup = { + [tbl.f1] = 1, + [tbl.f2] = 2, + [tbl.f3] = 3, + [tbl.f4] = 4, + [tbl.f5] = 5, + [tbl.f6] = 6, + [tbl.f7] = 7, + [tbl.f8] = 8, + [tbl.f9] = 9, + [tbl.f10] = 10, + [tbl.f11] = 11, + [tbl.f12] = 12, + [tbl.f13] = 13, + [tbl.f14] = 14, + [tbl.f15] = 15, + [tbl.f16] = 16, + } + + assert(is_native()) + + return lookup +end + +slotcachelimit1() + +local function slotcachelimit2(foo, size) + local c1 = foo(vector.create(size.X, size.Y, size.Z)) + local c2 = foo(vector.create(-size.X, size.Y, size.Z)) + local c3 = foo(vector.create(-size.X, -size.Y, size.Z)) + local c4 = foo(vector.create(-size.X, -size.Y, -size.Z)) + local c5 = foo(vector.create(size.X, -size.Y, -size.Z)) + local c6 = foo(vector.create(size.X, size.Y, -size.Z)) + local c7 = foo(vector.create(size.X, -size.Y, size.Z)) + local c8 = foo(vector.create(-size.X, size.Y, -size.Z)) + local max = vector.create(math.max(c1.X, c2.X, c3.X, c4.X, c5.X, c6.X, c7.X, c8.X), math.max(c1.Y, c2.Y, c3.Y, c4.Y, c5.Y, c6.Y, c7.Y, c8.Y), math.max(c1.Z, c2.Z, c3.Z, c4.Z, c5.Z, c6.Z, c7.Z, c8.Z)) + local min = vector.create(math.min(c1.X, c2.X, c3.X, c4.X, c5.X, c6.X, c7.X, c8.X), math.min(c1.Y, c2.Y, c3.Y, c4.Y, c5.Y, c6.Y, c7.Y, c8.Y), math.min(c1.Z, c2.Z, c3.Z, c4.Z, c5.Z, c6.Z, c7.Z, c8.Z)) + + assert(is_native()) + return max - min +end + +slotcachelimit2(function(a) return -a end, vector.create(1, 2, 3)) + return('OK') diff --git a/tests/conformance/native_types.lua b/tests/conformance/native_types.luau similarity index 100% rename from tests/conformance/native_types.lua rename to tests/conformance/native_types.luau diff --git a/tests/conformance/native_userdata.lua b/tests/conformance/native_userdata.luau similarity index 100% rename from tests/conformance/native_userdata.lua rename to tests/conformance/native_userdata.luau diff --git a/tests/conformance/ndebug_upvalues.lua b/tests/conformance/ndebug_upvalues.luau similarity index 100% rename from tests/conformance/ndebug_upvalues.lua rename to tests/conformance/ndebug_upvalues.luau diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.luau similarity index 87% rename from tests/conformance/pcall.lua rename to tests/conformance/pcall.luau index c2be2708..abb242d3 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.luau @@ -77,9 +77,9 @@ checkresults({ "yield", "return", true, 1, 2, 3}, colog(function() return pcall( checkresults({ "yield", 1, "yield", 2, "return", true, true, 3}, colog(function() return pcall(function() coroutine.yield(1) return pcall(function() coroutine.yield(2) return 3 end) end) end)) -- error after yield tests -checkresults({ "yield", "return", false, "pcall.lua:80: foo" }, colog(function() return pcall(function() coroutine.yield() error("foo") end) end)) -checkresults({ "yield", "yield", "return", true, false, "pcall.lua:81: foo" }, colog(function() return pcall(function() coroutine.yield() return pcall(function() coroutine.yield() error("foo") end) end) end)) -checkresults({ "yield", "yield", "return", false, "pcall.lua:82: bar" }, colog(function() return pcall(function() coroutine.yield() pcall(function() coroutine.yield() error("foo") end) error("bar") end) end)) +checkresults({ "yield", "return", false, "pcall.luau:80: foo" }, colog(function() return pcall(function() coroutine.yield() error("foo") end) end)) +checkresults({ "yield", "yield", "return", true, false, "pcall.luau:81: foo" }, colog(function() return pcall(function() coroutine.yield() return pcall(function() coroutine.yield() error("foo") end) end) end)) +checkresults({ "yield", "yield", "return", false, "pcall.luau:82: bar" }, colog(function() return pcall(function() coroutine.yield() pcall(function() coroutine.yield() error("foo") end) error("bar") end) end)) -- returning lots of results (past MINSTACK limits) local res = {pcall(function() return table.unpack(table.create(100, 'a')) end)} @@ -100,15 +100,15 @@ checkresults({ true, 2 }, xpcall(function(...) return select('#', ...) end, erro checkresults({ "yield", "return", true, 42 }, colog(function() return xpcall(function() coroutine.yield() return 42 end, error) end)) -- xpcall immediate error handling -checkresults({ false, "pcall.lua:103: foo" }, xpcall(function() error("foo") end, function(err) return err end)) +checkresults({ false, "pcall.luau:103: foo" }, xpcall(function() error("foo") end, function(err) return err end)) checkresults({ false, "bar" }, xpcall(function() error("foo") end, function(err) return "bar" end)) checkresults({ false, 1 }, xpcall(function() error("foo") end, function(err) return 1, 2 end)) -checkresults({ false, "pcall.lua:106: foo\npcall.lua:106\npcall.lua:106\n" }, xpcall(function() error("foo") end, debug.traceback)) +checkresults({ false, "pcall.luau:106: foo\npcall.luau:106\npcall.luau:106\n" }, xpcall(function() error("foo") end, debug.traceback)) checkresults({ false, "error in error handling" }, xpcall(function() error("foo") end, function(err) error("bar") end)) -- xpcall error handling after yields -checkresults({ "yield", "return", false, "pcall.lua:110: foo" }, colog(function() return xpcall(function() coroutine.yield() error("foo") end, function(err) return err end) end)) -checkresults({ "yield", "return", false, "pcall.lua:111: foo\npcall.lua:111\npcall.lua:111\n" }, colog(function() return xpcall(function() coroutine.yield() error("foo") end, debug.traceback) end)) +checkresults({ "yield", "return", false, "pcall.luau:110: foo" }, colog(function() return xpcall(function() coroutine.yield() error("foo") end, function(err) return err end) end)) +checkresults({ "yield", "return", false, "pcall.luau:111: foo\npcall.luau:111\npcall.luau:111\n" }, colog(function() return xpcall(function() coroutine.yield() error("foo") end, debug.traceback) end)) -- xpcall error handling during error handling inside xpcall after yields checkresults({ "yield", "return", true, false, "error in error handling" }, colog(function() return xpcall(function() return xpcall(function() coroutine.yield() error("foo") end, function(err) error("bar") end) end, error) end)) @@ -126,7 +126,7 @@ coroutine.yield(weird) weird() end -checkresults({ false, "pcall.lua:129: cannot resume dead coroutine" }, pcall(function() for _ in coroutine.wrap(pcall), weird do end end)) +checkresults({ false, "pcall.luau:129: cannot resume dead coroutine" }, pcall(function() for _ in coroutine.wrap(pcall), weird do end end)) -- c++ exception checkresults({ false, "oops" }, pcall(cxxthrow)) @@ -168,6 +168,10 @@ checkresults({ false, "oops" }, xpcall(function() table.create(1e6) end, functio checkresults({ false, "error in error handling" }, xpcall(function() error("oops") end, function(e) table.create(1e6) end)) checkresults({ false, "not enough memory" }, xpcall(function() table.create(1e6) end, function(e) table.create(1e6) end)) +co = coroutine.create(function() table.create(1e6) end) +coroutine.resume(co) +checkresults({ false, "not enough memory" }, coroutine.close(co)) + -- ensure that pcall and xpcall close upvalues when handling error local upclo local function uptest(y) diff --git a/tests/conformance/pm.lua b/tests/conformance/pm.luau similarity index 100% rename from tests/conformance/pm.lua rename to tests/conformance/pm.luau diff --git a/tests/conformance/safeenv.lua b/tests/conformance/safeenv.luau similarity index 100% rename from tests/conformance/safeenv.lua rename to tests/conformance/safeenv.luau diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.luau similarity index 100% rename from tests/conformance/sort.lua rename to tests/conformance/sort.luau diff --git a/tests/conformance/strconv.lua b/tests/conformance/strconv.luau similarity index 100% rename from tests/conformance/strconv.lua rename to tests/conformance/strconv.luau diff --git a/tests/conformance/stringinterp.lua b/tests/conformance/stringinterp.luau similarity index 100% rename from tests/conformance/stringinterp.lua rename to tests/conformance/stringinterp.luau diff --git a/tests/conformance/strings.lua b/tests/conformance/strings.luau similarity index 87% rename from tests/conformance/strings.lua rename to tests/conformance/strings.luau index 370641d9..857a4bb9 100644 --- a/tests/conformance/strings.lua +++ b/tests/conformance/strings.luau @@ -61,7 +61,7 @@ assert(#"\0\0\0" == 3) assert(#"1234567890" == 10) assert(string.byte("a") == 97) -assert(string.byte("á") > 127) +assert(string.byte("\xe4") > 127) assert(string.byte(string.char(255)) == 255) assert(string.byte(string.char(0)) == 0) assert(string.byte("\0") == 0) @@ -76,10 +76,10 @@ assert(string.byte("hi", 9, 10) == nil) assert(string.byte("hi", 2, 1) == nil) assert(string.char() == "") assert(string.char(0, 255, 0) == "\0\255\0") -assert(string.char(0, string.byte("á"), 0) == "\0á\0") -assert(string.char(string.byte("ál\0óu", 1, -1)) == "ál\0óu") -assert(string.char(string.byte("ál\0óu", 1, 0)) == "") -assert(string.char(string.byte("ál\0óu", -10, 100)) == "ál\0óu") +assert(string.char(0, string.byte("\xe4"), 0) == "\0\xe4\0") +assert(string.char(string.byte("\xe4l\0óu", 1, -1)) == "\xe4l\0óu") +assert(string.char(string.byte("\xe4l\0óu", 1, 0)) == "") +assert(string.char(string.byte("\xe4l\0óu", -10, 100)) == "\xe4l\0óu") assert(pcall(function() return string.char(256) end) == false) assert(pcall(function() return string.char(-1) end) == false) print('+') @@ -87,7 +87,7 @@ print('+') assert(string.upper("ab\0c") == "AB\0C") assert(string.lower("\0ABCc%$") == "\0abcc%$") assert(string.rep('teste', 0) == '') -assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') +assert(string.rep('tés\00tê', 2) == 'tés\0têtés\000tê') assert(string.rep('', 10) == '') assert(string.rep('', 1e9) == '') assert(pcall(string.rep, 'x', 2e9) == false) @@ -115,15 +115,18 @@ assert(pcall(function() return tostring(nothing()) end) == false) print('+') -x = '"ílo"\n\\' -assert(string.format('%q%s', x, x) == '"\\"ílo\\"\\\n\\\\""ílo"\n\\') +x = '"ílo"\n\\' +assert(string.format('%q%s', x, x) == '"\\"ílo\\"\\\n\\\\""ílo"\n\\') assert(string.format('%q', "\0") == [["\000"]]) assert(string.format('%q', "\r") == [["\r"]]) -assert(string.format("\0%c\0%c%x\0", string.byte("á"), string.byte("b"), 140) == - "\0á\0b8c\0") +assert(string.format("\0%c\0%c%x\0", string.byte("\xe4"), string.byte("b"), 140) == + "\0\xe4\0b8c\0") assert(string.format('') == "") assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) == string.format("%c%c%c%c", 34, 48, 90, 100)) +assert(string.format("%c%c%c%c", 1, 0, 2, 3) == '\1\0\2\3') +assert(string.format("%5c%5c%5c%5c", 1, 0, 2, 3) == ' \1 \0 \2 \3') +assert(string.format("%-5c%-5c%-5c%-5c", 1, 0, 2, 3) == '\1 \0 \2 \3 ') assert(string.format("%s\0 is not \0%s", 'not be', 'be') == 'not be\0 is not \0be') assert(string.format("%%%d %010d", 10, 23) == "%10 0000000023") assert(tonumber(string.format("%f", 10.3)) == 10.3) @@ -184,7 +187,7 @@ assert(pcall(function() string.format("%#*", "bad form") end) == false) -assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) +assert(loadstring("return 1\n--comentário sem EOL no final")() == 1) assert(table.concat{} == "") @@ -244,16 +247,16 @@ end if not trylocale("collate") then print("locale not supported") else - assert("alo" < "álo" and "álo" < "amo") + assert("alo" < "álo" and "álo" < "amo") end if not trylocale("ctype") then print("locale not supported") else - assert(string.gsub("áéíóú", "%a", "x") == "xxxxx") - assert(string.gsub("áÁéÉ", "%l", "x") == "xÁxÉ") - assert(string.gsub("áÁéÉ", "%u", "x") == "áxéx") - assert(string.upper"áÁé{xuxu}ção" == "ÁÁÉ{XUXU}ÇÃO") + assert(string.gsub("áéíóú", "%a", "x") == "xxxxx") + assert(string.gsub("áÃéÉ", "%l", "x") == "xÃxÉ") + assert(string.gsub("áÃéÉ", "%u", "x") == "áxéx") + assert(string.upper"áÃé{xuxu}ção" == "ÃÃÉ{XUXU}ÇÃO") end os.setlocale("C") diff --git a/tests/conformance/tables.lua b/tests/conformance/tables.luau similarity index 100% rename from tests/conformance/tables.lua rename to tests/conformance/tables.luau diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.luau similarity index 100% rename from tests/conformance/tmerror.lua rename to tests/conformance/tmerror.luau diff --git a/tests/conformance/tpack.lua b/tests/conformance/tpack.luau similarity index 100% rename from tests/conformance/tpack.lua rename to tests/conformance/tpack.luau diff --git a/tests/conformance/types.lua b/tests/conformance/types.luau similarity index 100% rename from tests/conformance/types.lua rename to tests/conformance/types.luau diff --git a/tests/conformance/userdata.lua b/tests/conformance/userdata.luau similarity index 100% rename from tests/conformance/userdata.lua rename to tests/conformance/userdata.luau diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.luau similarity index 100% rename from tests/conformance/utf8.lua rename to tests/conformance/utf8.luau diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.luau similarity index 100% rename from tests/conformance/vararg.lua rename to tests/conformance/vararg.luau diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.luau similarity index 100% rename from tests/conformance/vector.lua rename to tests/conformance/vector.luau diff --git a/tests/conformance/vector_library.luau b/tests/conformance/vector_library.luau new file mode 100644 index 00000000..dd5f2d1b --- /dev/null +++ b/tests/conformance/vector_library.luau @@ -0,0 +1,197 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing vector library') + +-- detect vector size +local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub((err:find(": ") or -1) + 2, #err) +end + +-- make sure we cover both builtin and C impl +assert(vector.create(1, 2) == vector.create("1", "2")) +assert(vector.create(1, 2, 4) == vector.create("1", "2", "4")) + +-- 'create' +local v12 = vector.create(1, 2) +local v123 = vector.create(1, 2, 3) +assert(v12.x == 1 and v12.y == 2 and v12.z == 0) +assert(v123.x == 1 and v123.y == 2 and v123.z == 3) + +-- testing 'dot' with error handling and different call kinds to mostly check details in the codegen +assert(vector.dot(vector.create(1, 2, 4), vector.create(5, 6, 7)) == 45) +assert(ecall(function() vector.dot(vector.create(1, 2, 4)) end) == "missing argument #2 to 'dot' (vector expected)") +assert(ecall(function() vector.dot(vector.create(1, 2, 4), 2) end) == "invalid argument #2 to 'dot' (vector expected, got number)") + +local function doDot1(a: vector, b) + return vector.dot(a, b) +end + +local function doDot2(a: vector, b) + return (vector.dot(a, b)) +end + +local v124 = vector.create(1, 2, 4) + +assert(doDot1(v124, vector.create(5, 6, 7)) == 45) +assert(doDot2(v124, vector.create(5, 6, 7)) == 45) +assert(ecall(function() doDot1(v124, "a") end) == "invalid argument #2 to 'dot' (vector expected, got string)") +assert(ecall(function() doDot2(v124, "a") end) == "invalid argument #2 to 'dot' (vector expected, got string)") +assert(select("#", doDot1(v124, vector.create(5, 6, 7))) == 1) +assert(select("#", doDot2(v124, vector.create(5, 6, 7))) == 1) + +-- 'cross' tests and next ones will only test basic results +assert(vector.cross(vector.create(1, 0, 0), vector.create(0, 1, 0)) == vector.create(0, 0, 1)) +assert(vector.cross(vector.create(0, 1, 0), vector.create(1, 0, 0)) == vector.create(0, 0, -1)) +assert(select("#", vector.cross(vector.zero, vector.one)) == 1) + +-- 'normalize' +assert(vector.normalize(vector.create(0.5, 0, 0)) == vector.create(1, 0, 0)) +assert(select("#", vector.normalize(vector.one)) == 1) + +-- 'magnitude' +assert(vector.magnitude(vector.create(1, 2, 2)) == 3) +assert(select("#", vector.magnitude(vector.one)) == 1) + +-- 'abs' +assert(vector.abs(-vector.one) == vector.one) +assert(vector.abs(vector.create(math.huge, 0, 0)).x == math.abs(math.huge)) +assert(vector.abs(vector.create(0/0, 0, 0)).x ~= 0/0) +assert(select("#", vector.abs(vector.one)) == 1) + +-- 'floor' +assert(vector.floor(vector.create(1, 2, 3)) == vector.create(1, 2, 3)) +assert(vector.floor(vector.create(1.5, 2.4, 3)) == vector.create(1, 2, 3)) +assert(vector.floor(vector.create(-1.5, -2.4, -3)) == vector.create(-2, -3, -3)) +assert(select("#", vector.floor(vector.one)) == 1) + +-- 'ceil' +assert(vector.ceil(vector.create(1, 2, 3)) == vector.create(1, 2, 3)) +assert(vector.ceil(vector.create(1.5, 2.4, 3)) == vector.create(2, 3, 3)) +assert(vector.ceil(vector.create(-1.5, -2.4, -3)) == vector.create(-1, -2, -3)) +assert(select("#", vector.ceil(vector.one)) == 1) + +-- 'sign' +assert(vector.sign(vector.zero) == vector.zero) +assert(vector.sign(vector.one) == vector.one) +assert(vector.sign(vector.create(-10, 0, 10)) == vector.create(-1, 0, 1)) +assert(vector.sign(vector.create(math.huge, 0, -math.huge)) == vector.create(1, 0, -1)) +-- negative zero and nan are consistent with math library, even if implementation defined +assert(vector.sign(vector.create(-0, 0, 0)).x == math.sign(-0)) +assert(vector.sign(vector.create(0/0, 0, 0)).x == math.sign(0/0)) +assert(select("#", vector.sign(vector.one)) == 1) + +-- 'angle' +assert(math.abs(vector.angle(vector.create(1, 2, 3), vector.create(4, 5, 6)) - 0.2257259) < 0.00001) +assert(select("#", vector.angle(vector.zero, vector.one)) == 1) +assert(select("#", vector.angle(vector.one, -vector.one, vector.zero)) == 1) + +do + -- random (non-unit) vectors + local rand = { + vector.create(-1.05, -0.04, 1.06), + vector.create(-0.75, 1.71, 1.29), + vector.create(1.94, 0.76, -0.93), + vector.create(0.02, -1.58, 0.20), + vector.create(1.64, -0.76, -0.73), + vector.create(-2.44, 0.66, 1.06), + vector.create(-2.61, 1.01, 0.50), + vector.create(1.21, -2.28, -0.45), + vector.create(-0.31, -0.12, 1.96), + vector.create(1.16, -0.07, -1.93) + } + + -- numeric answers to the tests below (in degrees) + local ans = { + -105.1702, + -69.49491, + 0.0, + -102.9083, + 0.0, + 0.0, + 180.0, + -0.02797646, + -90.0, + 165.8858 + } + + for i,v in ans do + ans[i] = math.rad(ans[i]) + end + + local function fuzzyeq(x, y, eps) return x == y or math.abs(x - y) < (eps or 1e-6) end + + assert(fuzzyeq(vector.angle(rand[10], rand[1]), math.abs(ans[10]))) + assert(fuzzyeq(vector.angle(rand[2], rand[3]), math.abs(ans[1]))) + assert(fuzzyeq(vector.angle(rand[4], rand[5]), math.abs(ans[2]))) + assert(fuzzyeq(vector.angle(vector.zero, rand[6]), math.abs(ans[3]))) + assert(fuzzyeq(vector.angle(vector.one, rand[7]), math.abs(ans[4]))) + assert(fuzzyeq(vector.angle(vector.zero, vector.zero), math.abs(ans[5]))) + assert(fuzzyeq(vector.angle(rand[8], rand[8]), math.abs(ans[6]))) + assert(fuzzyeq(vector.angle(-rand[8], rand[8]), math.abs(ans[7]))) + assert(fuzzyeq(vector.angle(rand[9], rand[9] + vector.create(0, 1, 0) * 0.001), math.abs(ans[8]), 1e-3)) -- slightly more generous eps + assert(fuzzyeq(vector.angle(vector.create(1, 0, 0), vector.create(0, 1, 0)), math.abs(ans[9]))) + + assert(fuzzyeq(vector.angle(rand[10], rand[1], rand[2]), ans[10])) + assert(fuzzyeq(vector.angle(rand[2], rand[3], rand[4]), ans[1])) + assert(fuzzyeq(vector.angle(rand[4], rand[5], rand[5]), ans[2])) + assert(fuzzyeq(vector.angle(vector.zero, rand[6], rand[10]), ans[3])) + assert(fuzzyeq(vector.angle(vector.one, rand[7], rand[10]), ans[4])) + assert(fuzzyeq(vector.angle(vector.zero, vector.zero, vector.zero), ans[5])) + assert(fuzzyeq(vector.angle(rand[8], rand[8], rand[10]), ans[6])) + assert(fuzzyeq(vector.angle(rand[9], rand[9] + vector.create(0, 1, 0) * 0.001, rand[10]), ans[8], 1e-3)) -- slightly more generous eps + assert(fuzzyeq(vector.angle(vector.create(1, 0, 0), vector.create(0, 1, 0), rand[10]), ans[9])) +end + +-- 'min'/'max' +assert(vector.max(vector.create(-1, 2, 0.5)) == vector.create(-1, 2, 0.5)) +assert(vector.min(vector.create(-1, 2, 0.5)) == vector.create(-1, 2, 0.5)) + +assert(ecall(function() vector.min() end) == "missing argument #1 to 'min' (vector expected)") +assert(ecall(function() vector.max() end) == "missing argument #1 to 'max' (vector expected)") + +assert(select("#", vector.max(vector.zero, vector.one)) == 1) +assert(select("#", vector.min(vector.zero, vector.one)) == 1) + +assert(vector.max(vector.create(-1, 2, 3), vector.create(3, 2, 1)) == vector.create(3, 2, 3)) +assert(vector.min(vector.create(-1, 2, 3), vector.create(3, 2, 1)) == vector.create(-1, 2, 1)) + +assert(vector.max(vector.create(1, 2, 3),vector.create(2, 3, 4),vector.create(3, 4, 5),vector.create(4, 5, 6)) == vector.create(4, 5, 6)) +assert(vector.min(vector.create(1, 2, 3),vector.create(2, 3, 4),vector.create(3, 4, 5),vector.create(4, 5, 6)) == vector.create(1, 2, 3)) + +-- clamp +assert(vector.clamp(vector.create(1, 1, 1), vector.create(0, 1, 2), vector.create(3, 3, 3)) == vector.create(1, 1, 2)) +assert(vector.clamp(vector.create(1, 1, 1), vector.create(-1, -1, -1), vector.create(0, 1, 2)) == vector.create(0, 1, 1)) +assert(select("#", vector.clamp(vector.zero, vector.zero, vector.one)) == 1) + +-- validate component access +assert(vector.create(1, 2, 3).x == 1) +assert(vector.create(1, 2, 3).X == 1) +assert(vector.create(1, 2, 3)['X'] == 1) +assert(vector.create(1, 2, 3).y == 2) +assert(vector.create(1, 2, 3).Y == 2) +assert(vector.create(1, 2, 3)['Y'] == 2) +assert(vector.create(1, 2, 3).z == 3) +assert(vector.create(1, 2, 3).Z == 3) +assert(vector.create(1, 2, 3)['Z'] == 3) + +local function getcomp(v: vector, field: string) + return v[field] +end + +assert(getcomp(vector.create(1, 2, 3), 'x') == 1) +assert(getcomp(vector.create(1, 2, 3), 'y') == 2) +assert(getcomp(vector.create(1, 2, 3), 'z') == 3) + +assert(ecall(function() return vector.create(1, 2, 3).zz end) == "attempt to index vector with 'zz'") + +-- additional checks for 4-component vectors +if vector_size == 4 then + assert(vector.create(1, 2, 3, 4).w == 4) + assert(vector.create(1, 2, 3, 4).W == 4) + assert(vector.create(1, 2, 3, 4)['W'] == 4) +end + +return 'OK' diff --git a/tests/main.cpp b/tests/main.cpp index 4612829b..005a3e61 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Common.h" +#include "Luau/CodeGenCommon.h" + #define DOCTEST_CONFIG_IMPLEMENT // Our calls to parseOption/parseFlag don't provide a prefix so set the prefix to the empty string. #define DOCTEST_CONFIG_OPTIONS_PREFIX "" @@ -28,6 +30,7 @@ #endif #include + #include // Indicates if verbose output is enabled; can be overridden via --verbose diff --git a/tests/require/with_config/.luaurc b/tests/require/with_config/.luaurc index 28ebca11..7e7abf18 100644 --- a/tests/require/with_config/.luaurc +++ b/tests/require/with_config/.luaurc @@ -1,5 +1,4 @@ { - "paths": ["GlobalLuauLibraries"], "aliases": { "dep": "this_should_be_overwritten_by_child_luaurc", "otherdep": "src/other_dependency" diff --git a/tests/require/with_config/src/.luaurc b/tests/require/with_config/src/.luaurc index 8c1ae683..90c6b646 100644 --- a/tests/require/with_config/src/.luaurc +++ b/tests/require/with_config/src/.luaurc @@ -1,5 +1,4 @@ { - "paths": ["../ProjectLuauLibraries"], "aliases": { "dep": "dependency", "subdir": "subdirectory" diff --git a/tests/require/with_config/src/fail_requirer.luau b/tests/require/with_config/src/fail_requirer.luau deleted file mode 100644 index 0454f922..00000000 --- a/tests/require/with_config/src/fail_requirer.luau +++ /dev/null @@ -1,2 +0,0 @@ --- shouldn't attempt to search paths array because of "./" prefix -return require("./library") diff --git a/tests/require/with_config/src/global_library_requirer.luau b/tests/require/with_config/src/global_library_requirer.luau deleted file mode 100644 index 747e14f5..00000000 --- a/tests/require/with_config/src/global_library_requirer.luau +++ /dev/null @@ -1,2 +0,0 @@ --- should be required using the paths array in the parent directory's .luaurc -return require("global_library") diff --git a/tests/require/with_config/src/requirer.luau b/tests/require/with_config/src/requirer.luau deleted file mode 100644 index 67028abb..00000000 --- a/tests/require/with_config/src/requirer.luau +++ /dev/null @@ -1,2 +0,0 @@ --- should be required using the paths array in .luaurc -return require("library") diff --git a/tests/require/without_config/ambiguous/directory/dependency.luau b/tests/require/without_config/ambiguous/directory/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/require/without_config/ambiguous/directory/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous/directory/dependency/init.luau b/tests/require/without_config/ambiguous/directory/dependency/init.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/require/without_config/ambiguous/directory/dependency/init.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous/file/dependency.lua b/tests/require/without_config/ambiguous/file/dependency.lua new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/require/without_config/ambiguous/file/dependency.lua @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous/file/dependency.luau b/tests/require/without_config/ambiguous/file/dependency.luau new file mode 100644 index 00000000..07466f42 --- /dev/null +++ b/tests/require/without_config/ambiguous/file/dependency.luau @@ -0,0 +1 @@ +return {"result from dependency"} diff --git a/tests/require/without_config/ambiguous_directory_requirer.luau b/tests/require/without_config/ambiguous_directory_requirer.luau new file mode 100644 index 00000000..e46be806 --- /dev/null +++ b/tests/require/without_config/ambiguous_directory_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./ambiguous/directory/dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/require/without_config/ambiguous_file_requirer.luau b/tests/require/without_config/ambiguous_file_requirer.luau new file mode 100644 index 00000000..8e3a576d --- /dev/null +++ b/tests/require/without_config/ambiguous_file_requirer.luau @@ -0,0 +1,3 @@ +local result = require("./ambiguous/file/dependency") +result[#result+1] = "required into module" +return result diff --git a/tests/require/without_config/module.luau b/tests/require/without_config/module.luau index 94826b66..1d1393ff 100644 --- a/tests/require/without_config/module.luau +++ b/tests/require/without_config/module.luau @@ -1,3 +1,3 @@ -local result = require("dependency") +local result = require("./dependency") result[#result+1] = "required into module" return result diff --git a/tests/require/without_config/validate_cache.luau b/tests/require/without_config/validate_cache.luau new file mode 100644 index 00000000..8e729af1 --- /dev/null +++ b/tests/require/without_config/validate_cache.luau @@ -0,0 +1,4 @@ +local result1 = require("./dependency") +local result2 = require("./dependency") +assert(result1 == result2) +return {} \ No newline at end of file diff --git a/tools/flag-bisect.py b/tools/flag-bisect.py index 01f3ef7c..55663a78 100644 --- a/tools/flag-bisect.py +++ b/tools/flag-bisect.py @@ -135,7 +135,7 @@ def add_argument_parsers(parser): interestness_parser.add_argument('--auto', dest='mode', action='store_const', const=InterestnessMode.AUTO, default=InterestnessMode.AUTO, help='Automatically figure out which one of --pass or --fail should be used') interestness_parser.add_argument('--fail', dest='mode', action='store_const', const=InterestnessMode.FAIL, - help='You want this if omitting --fflags=true causes tests to fail') + help='You want this if passing --fflags=true causes tests to fail') interestness_parser.add_argument('--pass', dest='mode', action='store_const', const=InterestnessMode.PASS, help='You want this if passing --fflags=true causes tests to pass') interestness_parser.add_argument('--timeout', dest='timeout', type=int, default=0, metavar='SECONDS', diff --git a/tools/fuzz/requirements.txt b/tools/fuzz/requirements.txt index 297ba324..9d7222f0 100644 --- a/tools/fuzz/requirements.txt +++ b/tools/fuzz/requirements.txt @@ -1,2 +1,2 @@ -Jinja2==3.1.4 +Jinja2==3.1.5 MarkupSafe==2.1.3 diff --git a/tools/heapsnapshot.py b/tools/heapsnapshot.py new file mode 100644 index 00000000..d3c0c92d --- /dev/null +++ b/tools/heapsnapshot.py @@ -0,0 +1,221 @@ +#!/usr/bin/python3 +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a Luau heap dump, this tool generates a heap snapshot which can be imported by Chrome's DevTools Memory panel +# To generate a snapshot, use luaC_dump, ideally preceded by luaC_fullgc +# To import in Chrome, ensure the snapshot has the .heapsnapshot extension and go to: Inspect -> Memory -> Load Profile +# A reference for the heap snapshot schema can be found here: https://learn.microsoft.com/en-us/microsoft-edge/devtools-guide-chromium/memory-problems/heap-snapshot-schema + +# Usage: python3 heapsnapshot.py luauDump.json heapSnapshot.heapsnapshot + +import json +import sys + +# Header describing the snapshot format, copied from a real Chrome heap snapshot +snapshotMeta = { + "node_fields": ["type", "name", "id", "self_size", "edge_count", "trace_node_id", "detachedness"], + "node_types": [ + ["hidden", "array", "string", "object", "code", "closure", "regexp", "number", "native", "synthetic", "concatenated string", "sliced string", "symbol", "bigint", "object shape"], + "string", "number", "number", "number", "number", "number" + ], + "edge_fields": ["type", "name_or_index", "to_node"], + "edge_types": [ + ["context", "element", "property", "internal", "hidden", "shortcut", "weak"], + "string_or_number", "node" + ], + "trace_function_info_fields": ["function_id", "name", "script_name", "script_id", "line", "column"], + "trace_node_fields": ["id", "function_info_index", "count", "size", "children"], + "sample_fields": ["timestamp_us", "last_assigned_id"], + "location_fields": ["object_index", "script_id", "line", "column"], +} + +# These indices refer to the index in the snapshot's metadata header +nodeTypeToMetaIndex = {type: i for i, type in enumerate(snapshotMeta["node_types"][0])} +edgeTypeToMetaIndex = {type: i for i, type in enumerate(snapshotMeta["edge_types"][0])} + +nodeFieldCount = len(snapshotMeta["node_fields"]) +edgeFieldCount = len(snapshotMeta["edge_fields"]) + + +def readAddresses(data): + # Ordered list of addresses to ensure the registry is the first node, and also so we can process nodes in index order + addresses = [] + addressToNodeIndex = {} + + def addAddress(address): + assert address not in addressToNodeIndex, f"Address already exists in the snapshot: '{address}'" + addresses.append(address) + addressToNodeIndex[address] = len(addresses) - 1 + + # The registry is a special case that needs to be either the first or last node to ensure gc "distances" are calculated correctly + registryAddress = data["roots"]["registry"] + addAddress(registryAddress) + + for address, obj in data["objects"].items(): + if address == registryAddress: + continue + addAddress(address) + + return addresses, addressToNodeIndex + + +def convertToSnapshot(data): + addresses, addressToNodeIndex = readAddresses(data) + + # Some notable idiosyncrasies with the heap snapshot format: + # 1. The snapshot format contains a flat array of nodes and edges. Oddly, edges must reference the "absolute" index of a node's first element after flattening. + # 2. A node's outgoing edges are implicitly represented by a contiguous block of edges in the edges array which correspond to the node's position + # in the nodes array and its edge count. So if the first node has 3 edges, the first 3 edges in the edges array are its edges, and so on. + + nodes = [] + edges = [] + strings = [] + + stringToSnapshotIndex = {} + + def getUniqueId(address): + # TODO: we should hash this to an int32 instead of using the address directly + # Addresses are hexadecimal strings + return int(address, 16) + + def addNode(node): + assert len(node) == nodeFieldCount, f"Expected {nodeFieldCount} fields, got {len(node)}" + nodes.append(node) + + def addEdge(edge): + assert len(edge) == edgeFieldCount, f"Expected {edgeFieldCount} fields, got {len(edge)}" + edges.append(edge) + + def getStringSnapshotIndex(string): + assert isinstance(string, str), f"'{string}' is not of type string" + if string not in stringToSnapshotIndex: + strings.append(string) + stringToSnapshotIndex[string] = len(strings) - 1 + return stringToSnapshotIndex[string] + + def getNodeSnapshotIndex(address): + # This is the index of the first element of the node in the flattened nodes array + return addressToNodeIndex[address] * nodeFieldCount + + for address in addresses: + obj = data["objects"][address] + edgeCount = 0 + + if obj["type"] == "table": + # TODO: support weak references + name = f"Registry ({address})" if address == data["roots"]["registry"] else f"Luau table ({address})" + if "pairs" in obj: + for i in range(0, len(obj["pairs"]), 2): + key = obj["pairs"][i] + value = obj["pairs"][i + 1] + if key is None and value is None: + # Both the key and value are value types, nothing meaningful to add here + continue + elif key is None: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex("(Luau table key value type)"), getNodeSnapshotIndex(value)]) + elif value is None: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + elif data["objects"][key]["type"] == "string": + edgeCount += 2 + # This is a special case where the key is a string, so we can use it as the edge name + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex(data["objects"][key]["data"]), getNodeSnapshotIndex(value)]) + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + else: + edgeCount += 2 + addEdge([edgeTypeToMetaIndex["property"], getStringSnapshotIndex(f'{data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(value)]) + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'Luau table key ref: {data["objects"][key]["type"]} ({key})'), getNodeSnapshotIndex(key)]) + if "array" in obj: + for i, element in enumerate(obj["array"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["element"], i, getNodeSnapshotIndex(element)]) + if "metatable" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'metatable ({obj["metatable"]})'), getNodeSnapshotIndex(obj["metatable"])]) + # TODO: consider distinguishing "object" and "array" node types + addNode([nodeTypeToMetaIndex["object"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "thread": + name = f'Luau thread: {obj["source"]}:{obj["line"]} ({address})' if "source" in obj else f"Luau thread ({address})" + if address == data["roots"]["mainthread"]: + name += " (main thread)" + if "env" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'env ({obj["env"]})'), getNodeSnapshotIndex(obj["env"])]) + if "stack" in obj: + for i, frame in enumerate(obj["stack"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f"callstack[{i}]"), getNodeSnapshotIndex(frame)]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "function": + name = f'Luau function: {obj["name"]} ({address})' if "name" in obj else f"Luau anonymous function ({address})" + if "env" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'env ({obj["env"]})'), getNodeSnapshotIndex(obj["env"])]) + if "proto" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'proto ({obj["proto"]})'), getNodeSnapshotIndex(obj["proto"])]) + if "upvalues" in obj: + for i, upvalue in enumerate(obj["upvalues"]): + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f"up value ({upvalue})"), getNodeSnapshotIndex(upvalue)]) + addNode([nodeTypeToMetaIndex["closure"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "upvalue": + if "object" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(f'upvalue object ({obj["object"]})'), getNodeSnapshotIndex(obj["object"])]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f"Luau upvalue ({address})"), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "userdata": + if "metatable" in obj: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["internal"], getStringSnapshotIndex(f'metatable ({obj["metatable"]})'), getNodeSnapshotIndex(obj["metatable"])]) + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f"Luau userdata ({address})"), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "proto": + name = f'Luau proto: {obj["source"]}:{obj["line"]} ({address})' if "source" in obj else f"Luau proto ({address})" + if "constants" in obj: + for constant in obj["constants"]: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(constant), getNodeSnapshotIndex(constant)]) + if "protos" in obj: + for proto in obj["protos"]: + edgeCount += 1 + addEdge([edgeTypeToMetaIndex["context"], getStringSnapshotIndex(proto), getNodeSnapshotIndex(proto)]) + addNode([nodeTypeToMetaIndex["code"], getStringSnapshotIndex(name), getUniqueId(address), obj["size"], edgeCount, 0, 0]) + elif obj["type"] == "string": + addNode([nodeTypeToMetaIndex["string"], getStringSnapshotIndex(obj["data"]), getUniqueId(address), obj["size"], 0, 0, 0]) + elif obj["type"] == "buffer": + addNode([nodeTypeToMetaIndex["native"], getStringSnapshotIndex(f'buffer ({address})'), getUniqueId(address), obj["size"], 0, 0, 0]) + else: + raise Exception(f"Unknown object type: '{obj['type']}'") + + return { + "snapshot": { + "meta": snapshotMeta, + "node_count": len(nodes), + "edge_count": len(edges), + "trace_function_count": 0, + }, + # flatten the nodes and edges arrays + "nodes": [field for node in nodes for field in node], + "edges": [field for edge in edges for field in edge], + "trace_function_infos": [], + "trace_tree": [], + "samples": [], + "locations": [], + "strings": strings, + } + + +if __name__ == "__main__": + luauDump = sys.argv[1] + heapSnapshot = sys.argv[2] + + with open(luauDump, "r") as file: + dump = json.load(file) + + snapshot = convertToSnapshot(dump) + + with open(heapSnapshot, "w") as file: + json.dump(snapshot, file) + + print(f"Heap snapshot written to: '{heapSnapshot}'") diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index 59bc43c4..adf603eb 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -77,7 +77,7 @@ --- - + table metatable