diff --git a/.github/workflows/benchmark-dev.yml b/.github/workflows/benchmark-dev.yml deleted file mode 100644 index b6115acc..00000000 --- a/.github/workflows/benchmark-dev.yml +++ /dev/null @@ -1,185 +0,0 @@ -name: benchmark-dev - -on: - push: - branches: - - master - - paths-ignore: - - "docs/**" - - "papers/**" - - "rfcs/**" - - "*.md" - -jobs: - windows: - name: windows-${{matrix.arch}} - strategy: - fail-fast: false - matrix: - os: [windows-latest] - arch: [Win32, x64] - bench: - - { - script: "run-benchmarks", - timeout: 12, - title: "Luau Benchmarks", - } - benchResultsRepo: - - { name: "luau-lang/benchmark-data", branch: "main" } - - runs-on: ${{ matrix.os }} - steps: - - name: Checkout Luau repository - uses: actions/checkout@v3 - - - name: Build Luau - shell: bash # necessary for fail-fast - run: | - mkdir build && cd build - cmake .. -DCMAKE_BUILD_TYPE=Release - cmake --build . --target Luau.Repl.CLI --config Release - cmake --build . --target Luau.Analyze.CLI --config Release - - - name: Move build files to root - run: | - move build/Release/* . - - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - python -m pip install requests - python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - - name: Run benchmark - run: | - python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt - - - name: Push benchmark results - id: pushBenchmarkAttempt1 - continue-on-error: true - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 2) - id: pushBenchmarkAttempt2 - continue-on-error: true - if: steps.pushBenchmarkAttempt1.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 3) - id: pushBenchmarkAttempt3 - continue-on-error: true - if: steps.pushBenchmarkAttempt2.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: "${{ matrix.bench.title }} (Windows ${{matrix.arch}})" - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - unix: - name: ${{matrix.os}} - strategy: - fail-fast: false - matrix: - os: [ubuntu-20.04, macos-latest] - bench: - - { - script: "run-benchmarks", - timeout: 12, - title: "Luau Benchmarks", - } - benchResultsRepo: - - { name: "luau-lang/benchmark-data", branch: "main" } - - runs-on: ${{ matrix.os }} - steps: - - name: Checkout Luau repository - uses: actions/checkout@v3 - - - name: Build Luau - run: make config=release luau luau-analyze - - - uses: actions/setup-python@v3 - with: - python-version: "3.9" - architecture: "x64" - - - name: Install python dependencies - run: | - python -m pip install requests - python -m pip install --user numpy scipy matplotlib ipython jupyter pandas sympy nose - - - name: Run benchmark - run: | - python bench/bench.py | tee ${{ matrix.bench.script }}-output.txt - - - name: Push benchmark results - id: pushBenchmarkAttempt1 - continue-on-error: true - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: ${{ matrix.bench.title }} - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 2) - id: pushBenchmarkAttempt2 - continue-on-error: true - if: steps.pushBenchmarkAttempt1.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: ${{ matrix.bench.title }} - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" - - - name: Push benchmark results (Attempt 3) - id: pushBenchmarkAttempt3 - continue-on-error: true - if: steps.pushBenchmarkAttempt2.outcome == 'failure' - uses: ./.github/workflows/push-results - with: - repository: ${{ matrix.benchResultsRepo.name }} - branch: ${{ matrix.benchResultsRepo.branch }} - token: ${{ secrets.BENCH_GITHUB_TOKEN }} - path: "./gh-pages" - bench_name: ${{ matrix.bench.title }} - bench_tool: "benchmarkluau" - bench_output_file_path: "./${{ matrix.bench.script }}-output.txt" - bench_external_data_json_path: "./gh-pages/dev/bench/data-${{ matrix.os }}.json" diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index c7531608..7a11fbe1 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -25,6 +25,7 @@ jobs: - name: Install valgrind run: | + sudo apt-get update sudo apt-get install valgrind - name: Build Luau (gcc) diff --git a/.github/workflows/push-results/action.yml b/.github/workflows/push-results/action.yml deleted file mode 100644 index b5ffebee..00000000 --- a/.github/workflows/push-results/action.yml +++ /dev/null @@ -1,63 +0,0 @@ -name: Checkout & push results -description: Checkout a given repo and push results to GitHub -inputs: - repository: - required: true - type: string - description: The benchmark results repository to check out - branch: - required: true - type: string - description: The benchmark results repository's branch to check out - token: - required: true - type: string - description: The GitHub token to use for pushing results - path: - required: true - type: string - description: The path to check out the results repository to - bench_name: - required: true - type: string - bench_tool: - required: true - type: string - bench_output_file_path: - required: true - type: string - bench_external_data_json_path: - required: true - type: string - -runs: - using: "composite" - steps: - - name: Checkout repository - uses: actions/checkout@v3 - with: - repository: ${{ inputs.repository }} - ref: ${{ inputs.branch }} - token: ${{ inputs.token }} - path: ${{ inputs.path }} - - - name: Store results - uses: Roblox/rhysd-github-action-benchmark@v-luau - with: - name: ${{ inputs.bench_name }} - tool: ${{ inputs.bench_tool }} - gh-pages-branch: ${{ inputs.branch }} - output-file-path: ${{ inputs.bench_output_file_path }} - external-data-json-path: ${{ inputs.bench_external_data_json_path }} - - - name: Push benchmark results - shell: bash - run: | - echo "Pushing benchmark results..." - cd gh-pages - git config user.name github-actions - git config user.email github@users.noreply.github.com - git add *.json - git commit -m "Add benchmarks results for ${{ github.sha }}" - git push - cd .. diff --git a/.gitignore b/.gitignore index af77f73c..528ab204 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ /luau /luau-tests /luau-analyze +/luau-compile __pycache__ diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index d4457638..6154f3d1 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -14,8 +14,6 @@ struct GlobalTypes; struct TypeChecker; struct TypeArena; -void registerBuiltinTypes(GlobalTypes& globals); - void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete = false); TypeId makeUnion(TypeArena& arena, std::vector&& types); TypeId makeIntersection(TypeArena& arena, std::vector&& types); diff --git a/Analysis/include/Luau/Clone.h b/Analysis/include/Luau/Clone.h index b3cbe467..28ab9931 100644 --- a/Analysis/include/Luau/Clone.h +++ b/Analysis/include/Luau/Clone.h @@ -16,6 +16,8 @@ using SeenTypePacks = std::unordered_map; struct CloneState { + NotNull builtinTypes; + SeenTypes seenTypes; SeenTypePacks seenTypePacks; diff --git a/Analysis/include/Luau/ConstraintGraphBuilder.h b/Analysis/include/Luau/ConstraintGraphBuilder.h index 1d3f20ee..902da0d5 100644 --- a/Analysis/include/Luau/ConstraintGraphBuilder.h +++ b/Analysis/include/Luau/ConstraintGraphBuilder.h @@ -101,9 +101,10 @@ struct ConstraintGraphBuilder DcrLogger* logger; - ConstraintGraphBuilder(ModulePtr module, NotNull normalizer, NotNull moduleResolver, NotNull builtinTypes, - NotNull ice, const ScopePtr& globalScope, std::function prepareModuleScope, - DcrLogger* logger, NotNull dfg, std::vector requireCycles); + ConstraintGraphBuilder(ModulePtr module, NotNull normalizer, NotNull moduleResolver, + NotNull builtinTypes, NotNull ice, const ScopePtr& globalScope, + std::function prepareModuleScope, DcrLogger* logger, NotNull dfg, + std::vector requireCycles); /** * Fabricates a new free type belonging to a given scope. diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 47effcea..76520b2c 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -279,8 +279,6 @@ private: TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; - TypeId unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes); - TypePackId anyifyModuleReturnTypePackGenerics(TypePackId tp); void throwTimeLimitError(); diff --git a/Analysis/include/Luau/Differ.h b/Analysis/include/Luau/Differ.h index e9656ad4..43c44b71 100644 --- a/Analysis/include/Luau/Differ.h +++ b/Analysis/include/Luau/Differ.h @@ -128,10 +128,10 @@ struct DiffError checkValidInitialization(left, right); } - std::string toString() const; + std::string toString(bool multiLine = false) const; private: - std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const; + std::string toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const; void checkValidInitialization(const DiffPathNodeLeaf& left, const DiffPathNodeLeaf& right); void checkNonMissingPropertyLeavesHaveNulloptTableProperty() const; }; @@ -152,12 +152,17 @@ struct DifferEnvironment { TypeId rootLeft; TypeId rootRight; + std::optional externalSymbolLeft; + std::optional externalSymbolRight; DenseHashMap genericMatchedPairs; DenseHashMap genericTpMatchedPairs; - DifferEnvironment(TypeId rootLeft, TypeId rootRight) + DifferEnvironment( + TypeId rootLeft, TypeId rootRight, std::optional externalSymbolLeft, std::optional externalSymbolRight) : rootLeft(rootLeft) , rootRight(rootRight) + , externalSymbolLeft(externalSymbolLeft) + , externalSymbolRight(externalSymbolRight) , genericMatchedPairs(nullptr) , genericTpMatchedPairs(nullptr) { @@ -170,6 +175,8 @@ struct DifferEnvironment void popVisiting(); std::vector>::const_reverse_iterator visitingBegin() const; std::vector>::const_reverse_iterator visitingEnd() const; + std::string getDevFixFriendlyNameLeft() const; + std::string getDevFixFriendlyNameRight() const; private: // TODO: consider using DenseHashSet @@ -179,6 +186,7 @@ private: std::vector> visitingStack; }; DifferResult diff(TypeId ty1, TypeId ty2); +DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2); /** * True if ty is a "simple" type, i.e. cannot contain types. diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 13758b37..db3e8e55 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Location.h" +#include "Luau/NotNull.h" #include "Luau/Type.h" #include "Luau/Variant.h" @@ -432,7 +433,7 @@ std::string toString(const TypeError& error, TypeErrorToStringOptions options); bool containsParseErrorName(const TypeError& error); // Copy any types named in the error into destArena. -void copyErrors(ErrorVec& errors, struct TypeArena& destArena); +void copyErrors(ErrorVec& errors, struct TypeArena& destArena, NotNull builtinTypes); // Internal Compiler Error struct InternalErrorReporter diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 5853eb32..afe47322 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -2,12 +2,12 @@ #pragma once #include "Luau/Config.h" +#include "Luau/GlobalTypes.h" #include "Luau/Module.h" #include "Luau/ModuleResolver.h" #include "Luau/RequireTracer.h" #include "Luau/Scope.h" #include "Luau/TypeCheckLimits.h" -#include "Luau/TypeInfer.h" #include "Luau/Variant.h" #include @@ -100,6 +100,12 @@ struct FrontendOptions std::optional enabledLintWarnings; std::shared_ptr cancellationToken; + + // Time limit for typechecking a single module + std::optional moduleTimeLimitSec; + + // When true, some internal complexity limits will be scaled down for modules that miss the limit set by moduleTimeLimitSec + bool applyInternalLimitScaling = false; }; struct CheckResult diff --git a/Analysis/include/Luau/GlobalTypes.h b/Analysis/include/Luau/GlobalTypes.h new file mode 100644 index 00000000..7a34f935 --- /dev/null +++ b/Analysis/include/Luau/GlobalTypes.h @@ -0,0 +1,26 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#pragma once + +#include "Luau/Module.h" +#include "Luau/NotNull.h" +#include "Luau/Scope.h" +#include "Luau/TypeArena.h" + +namespace Luau +{ + +struct BuiltinTypes; + +struct GlobalTypes +{ + explicit GlobalTypes(NotNull builtinTypes); + + NotNull builtinTypes; // Global types are based on builtin types + + TypeArena globalTypes; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Instantiation.h b/Analysis/include/Luau/Instantiation.h index 642f2b9e..1dbf6b67 100644 --- a/Analysis/include/Luau/Instantiation.h +++ b/Analysis/include/Luau/Instantiation.h @@ -17,8 +17,8 @@ struct TypeCheckLimits; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { - ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, const std::vector& generics, - const std::vector& genericPacks) + ReplaceGenerics(const TxnLog* log, TypeArena* arena, NotNull builtinTypes, TypeLevel level, Scope* scope, + const std::vector& generics, const std::vector& genericPacks) : Substitution(log, arena) , builtinTypes(builtinTypes) , level(level) @@ -77,6 +77,7 @@ struct Instantiation : Substitution * Instantiation fails only when processing the type causes internal recursion * limits to be exceeded. */ -std::optional instantiate(NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty); +std::optional instantiate( + NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty); } // namespace Luau diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 6bbc3d66..303cd9e9 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/LinterConfig.h" #include "Luau/Location.h" #include @@ -15,86 +16,15 @@ class AstStat; class AstNameTable; struct TypeChecker; struct Module; -struct HotComment; using ScopePtr = std::shared_ptr; -struct LintWarning -{ - // Make sure any new lint codes are documented here: https://luau-lang.org/lint - // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints - enum Code - { - Code_Unknown = 0, - - Code_UnknownGlobal = 1, // superseded by type checker - Code_DeprecatedGlobal = 2, - Code_GlobalUsedAsLocal = 3, - Code_LocalShadow = 4, // disabled in Studio - Code_SameLineStatement = 5, // disabled in Studio - Code_MultiLineStatement = 6, - Code_LocalUnused = 7, // disabled in Studio - Code_FunctionUnused = 8, // disabled in Studio - Code_ImportUnused = 9, // disabled in Studio - Code_BuiltinGlobalWrite = 10, - Code_PlaceholderRead = 11, - Code_UnreachableCode = 12, - Code_UnknownType = 13, - Code_ForRange = 14, - Code_UnbalancedAssignment = 15, - Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode - Code_DuplicateLocal = 17, - Code_FormatString = 18, - Code_TableLiteral = 19, - Code_UninitializedLocal = 20, - Code_DuplicateFunction = 21, - Code_DeprecatedApi = 22, - Code_TableOperations = 23, - Code_DuplicateCondition = 24, - Code_MisleadingAndOr = 25, - Code_CommentDirective = 26, - Code_IntegerParsing = 27, - Code_ComparisonPrecedence = 28, - - Code__Count - }; - - Code code; - Location location; - std::string text; - - static const char* getName(Code code); - static Code parseName(const char* name); - static uint64_t parseMask(const std::vector& hotcomments); -}; - struct LintResult { std::vector errors; std::vector warnings; }; -struct LintOptions -{ - uint64_t warningMask = 0; - - void enableWarning(LintWarning::Code code) - { - warningMask |= 1ull << code; - } - void disableWarning(LintWarning::Code code) - { - warningMask &= ~(1ull << code); - } - - bool isEnabled(LintWarning::Code code) const - { - return 0 != (warningMask & (1ull << code)); - } - - void setDefaults(); -}; - std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const std::vector& hotcomments, const LintOptions& options); diff --git a/Analysis/include/Luau/Metamethods.h b/Analysis/include/Luau/Metamethods.h index 84b0092f..747b7201 100644 --- a/Analysis/include/Luau/Metamethods.h +++ b/Analysis/include/Luau/Metamethods.h @@ -19,6 +19,7 @@ static const std::unordered_map kBinaryOpMetamet {AstExprBinary::Op::Sub, "__sub"}, {AstExprBinary::Op::Mul, "__mul"}, {AstExprBinary::Op::Div, "__div"}, + {AstExprBinary::Op::FloorDiv, "__idiv"}, {AstExprBinary::Op::Pow, "__pow"}, {AstExprBinary::Op::Mod, "__mod"}, {AstExprBinary::Op::Concat, "__concat"}, diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index cb761714..d647750f 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -111,6 +111,7 @@ struct Module LintResult lintResult; Mode mode; SourceCode::Type type; + double checkDurationSec = 0.0; bool timeout = false; bool cancelled = false; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 398d0ab6..b18c6484 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -147,9 +147,6 @@ struct Tarjan void visitEdge(int index, int parentIndex); void visitSCC(int index); - TarjanResult loop_DEPRECATED(); - void visitSCC_DEPRECATED(int index); - // Each subclass can decide to ignore some nodes. virtual bool ignoreChildren(TypeId ty) { @@ -178,13 +175,6 @@ struct Tarjan virtual bool isDirty(TypePackId tp) = 0; virtual void foundDirty(TypeId ty) = 0; virtual void foundDirty(TypePackId tp) = 0; - - // TODO: remove with FFlagLuauTarjanSingleArr - std::vector indexToType; - std::vector indexToPack; - std::vector onStack; - std::vector lowlink; - std::vector dirty; }; // And finally substitution, which finds all the reachable dirty vertices diff --git a/Analysis/include/Luau/Subtyping.h b/Analysis/include/Luau/Subtyping.h new file mode 100644 index 00000000..de702d53 --- /dev/null +++ b/Analysis/include/Luau/Subtyping.h @@ -0,0 +1,146 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Type.h" +#include "Luau/TypePack.h" +#include "Luau/UnifierSharedState.h" + +#include +#include + +namespace Luau +{ + +template +struct TryPair; +struct InternalErrorReporter; + +class TypeIds; +class Normalizer; +struct NormalizedType; +struct NormalizedClassType; +struct NormalizedStringType; +struct NormalizedFunctionType; + + +struct SubtypingResult +{ + bool isSubtype = false; + bool isErrorSuppressing = false; + bool normalizationTooComplex = false; + + SubtypingResult& andAlso(const SubtypingResult& other); + SubtypingResult& orElse(const SubtypingResult& other); + + // Only negates the `isSubtype`. + static SubtypingResult negate(const SubtypingResult& result); + static SubtypingResult all(const std::vector& results); + static SubtypingResult any(const std::vector& results); +}; + +struct Subtyping +{ + NotNull builtinTypes; + NotNull arena; + NotNull normalizer; + NotNull iceReporter; + + NotNull scope; + + enum class Variance + { + Covariant, + Contravariant + }; + + Variance variance = Variance::Covariant; + + struct GenericBounds + { + DenseHashSet lowerBound{nullptr}; + DenseHashSet upperBound{nullptr}; + }; + + /* + * When we encounter a generic over the course of a subtyping test, we need + * to tentatively map that generic onto a type on the other side. + */ + DenseHashMap mappedGenerics{nullptr}; + DenseHashMap mappedGenericPacks{nullptr}; + + using SeenSet = std::unordered_set, TypeIdPairHash>; + + SeenSet seenTypes; + + Subtyping(const Subtyping&) = delete; + Subtyping& operator=(const Subtyping&) = delete; + + Subtyping(Subtyping&&) = default; + Subtyping& operator=(Subtyping&&) = default; + + // TODO cache + // TODO cyclic types + // TODO recursion limits + + SubtypingResult isSubtype(TypeId subTy, TypeId superTy); + SubtypingResult isSubtype(TypePackId subTy, TypePackId superTy); + +private: + SubtypingResult isCovariantWith(TypeId subTy, TypeId superTy); + SubtypingResult isCovariantWith(TypePackId subTy, TypePackId superTy); + + template + SubtypingResult isContravariantWith(SubTy&& subTy, SuperTy&& superTy); + + template + SubtypingResult isInvariantWith(SubTy&& subTy, SuperTy&& superTy); + + template + SubtypingResult isCovariantWith(const TryPair& pair); + + template + SubtypingResult isContravariantWith(const TryPair& pair); + + template + SubtypingResult isInvariantWith(const TryPair& pair); + + SubtypingResult isCovariantWith(TypeId subTy, const UnionType* superUnion); + SubtypingResult isCovariantWith(const UnionType* subUnion, TypeId superTy); + SubtypingResult isCovariantWith(TypeId subTy, const IntersectionType* superIntersection); + SubtypingResult isCovariantWith(const IntersectionType* subIntersection, TypeId superTy); + + SubtypingResult isCovariantWith(const NegationType* subNegation, TypeId superTy); + SubtypingResult isCovariantWith(const TypeId subTy, const NegationType* superNegation); + + SubtypingResult isCovariantWith(const PrimitiveType* subPrim, const PrimitiveType* superPrim); + SubtypingResult isCovariantWith(const SingletonType* subSingleton, const PrimitiveType* superPrim); + SubtypingResult isCovariantWith(const SingletonType* subSingleton, const SingletonType* superSingleton); + SubtypingResult isCovariantWith(const TableType* subTable, const TableType* superTable); + SubtypingResult isCovariantWith(const MetatableType* subMt, const MetatableType* superMt); + SubtypingResult isCovariantWith(const MetatableType* subMt, const TableType* superTable); + SubtypingResult isCovariantWith(const ClassType* subClass, const ClassType* superClass); + SubtypingResult isCovariantWith(const ClassType* subClass, const TableType* superTable); + SubtypingResult isCovariantWith(const FunctionType* subFunction, const FunctionType* superFunction); + SubtypingResult isCovariantWith(const PrimitiveType* subPrim, const TableType* superTable); + SubtypingResult isCovariantWith(const SingletonType* subSingleton, const TableType* superTable); + + SubtypingResult isCovariantWith(const NormalizedType* subNorm, const NormalizedType* superNorm); + SubtypingResult isCovariantWith(const NormalizedClassType& subClass, const NormalizedClassType& superClass); + SubtypingResult isCovariantWith(const NormalizedClassType& subClass, const TypeIds& superTables); + SubtypingResult isCovariantWith(const NormalizedStringType& subString, const NormalizedStringType& superString); + SubtypingResult isCovariantWith(const NormalizedStringType& subString, const TypeIds& superTables); + SubtypingResult isCovariantWith(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction); + SubtypingResult isCovariantWith(const TypeIds& subTypes, const TypeIds& superTypes); + + SubtypingResult isCovariantWith(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic); + + bool bindGeneric(TypeId subTp, TypeId superTp); + bool bindGeneric(TypePackId subTp, TypePackId superTp); + + template + TypeId makeAggregateType(const Container& container, TypeId orElse); + + [[noreturn]] void unexpected(TypePackId tp); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index c152fc02..d43266cb 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -733,9 +733,17 @@ struct Type final using SeenSet = std::set>; bool areEqual(SeenSet& seen, const Type& lhs, const Type& rhs); +enum class FollowOption +{ + Normal, + DisableLazyTypeThunks, +}; + // Follow BoundTypes until we get to something real TypeId follow(TypeId t); +TypeId follow(TypeId t, FollowOption followOption); TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)); +TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)); std::vector flattenIntersection(TypeId ty); @@ -790,12 +798,13 @@ struct BuiltinTypes TypeId errorRecoveryType() const; TypePackId errorRecoveryTypePack() const; + friend TypeId makeStringMetatable(NotNull builtinTypes); + friend struct GlobalTypes; + private: std::unique_ptr arena; bool debugFreezeArena = false; - TypeId makeStringMetatable(); - public: const TypeId nilType; const TypeId numberType; @@ -818,6 +827,7 @@ public: const TypeId optionalNumberType; const TypeId optionalStringType; + const TypePackId emptyTypePack; const TypePackId anyTypePack; const TypePackId neverTypePack; const TypePackId uninhabitableTypePack; @@ -839,6 +849,18 @@ bool isSubclass(const ClassType* cls, const ClassType* parent); Type* asMutable(TypeId ty); +template +bool is(T&& tv) +{ + if (!tv) + return false; + + if constexpr (std::is_same_v && !(std::is_same_v || ...)) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); + + return (get(tv) || ...); +} + template const T* get(TypeId tv) { diff --git a/Analysis/include/Luau/TypeChecker2.h b/Analysis/include/Luau/TypeChecker2.h index db81d9cf..aeeab0f8 100644 --- a/Analysis/include/Luau/TypeChecker2.h +++ b/Analysis/include/Luau/TypeChecker2.h @@ -14,7 +14,7 @@ struct DcrLogger; struct TypeCheckLimits; struct UnifierSharedState; -void check(NotNull builtinTypes, NotNull sharedState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, - Module* module); +void check(NotNull builtinTypes, NotNull sharedState, NotNull limits, DcrLogger* logger, + const SourceModule& sourceModule, Module* module); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9a44af49..abae8b92 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -57,17 +57,6 @@ struct HashBoolNamePair size_t operator()(const std::pair& pair) const; }; -struct GlobalTypes -{ - GlobalTypes(NotNull builtinTypes); - - NotNull builtinTypes; // Global types are based on builtin types - - TypeArena globalTypes; - SourceModule globalNames; // names for symbols entered into globalScope - ScopePtr globalScope; // shared by all modules -}; - // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index 793415ee..4d41926c 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -101,6 +101,31 @@ ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypeId ty1 */ ErrorSuppression shouldSuppressErrors(NotNull normalizer, TypePackId tp1, TypePackId tp2); +// Similar to `std::optional>`, but whose `sizeof()` is the same as `std::pair` +// and cooperates with C++'s `if (auto p = ...)` syntax without the extra fatness of `std::optional`. +template +struct TryPair +{ + A first; + B second; + + explicit operator bool() const + { + return bool(first) && bool(second); + } +}; + +template +TryPair get2(Ty one, Ty two) +{ + const A* a = get(one); + const B* b = get(two); + if (a && b) + return {a, b}; + else + return {nullptr, nullptr}; +} + template const T* get(std::optional ty) { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index f7c5c94c..1260ac93 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -105,10 +105,12 @@ struct Unifier * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); + void tryUnify( + TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); private: - void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); + void tryUnify_( + TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false, const LiteralProperties* aliasableMap = nullptr); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); // Traverse the two types provided and block on any BlockedTypes we find. diff --git a/Analysis/include/Luau/Unifier2.h b/Analysis/include/Luau/Unifier2.h index cf769da3..6d32e03f 100644 --- a/Analysis/include/Luau/Unifier2.h +++ b/Analysis/include/Luau/Unifier2.h @@ -55,8 +55,8 @@ struct Unifier2 bool unify(TypePackId subTp, TypePackId superTp); std::optional generalize(NotNull scope, TypeId ty); -private: +private: /** * @returns simplify(left | right) */ @@ -72,4 +72,4 @@ private: OccursCheckResult occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); }; -} +} // namespace Luau diff --git a/Analysis/include/Luau/VisitType.h b/Analysis/include/Luau/VisitType.h index a84fb48c..28dfffbf 100644 --- a/Analysis/include/Luau/VisitType.h +++ b/Analysis/include/Luau/VisitType.h @@ -10,6 +10,7 @@ LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauBoundLazyTypes2) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(DebugLuauReadWriteProperties) namespace Luau @@ -220,7 +221,21 @@ struct GenericTypeVisitor traverse(btv->boundTo); } else if (auto ftv = get(ty)) - visit(ty, *ftv); + { + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (visit(ty, *ftv)) + { + LUAU_ASSERT(ftv->lowerBound); + traverse(ftv->lowerBound); + + LUAU_ASSERT(ftv->upperBound); + traverse(ftv->upperBound); + } + } + else + visit(ty, *ftv); + } else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index f2943c4d..5c4f9504 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauFloorDivision) + namespace Luau { @@ -514,6 +516,9 @@ struct AstJsonEncoder : public AstVisitor return writeString("Mul"); case AstExprBinary::Div: return writeString("Div"); + case AstExprBinary::FloorDiv: + LUAU_ASSERT(FFlag::LuauFloorDivision); + return writeString("FloorDiv"); case AstExprBinary::Mod: return writeString("Mod"); case AstExprBinary::Pow: @@ -536,6 +541,8 @@ struct AstJsonEncoder : public AstVisitor return writeString("And"); case AstExprBinary::Or: return writeString("Or"); + default: + LUAU_ASSERT(!"Unknown Op"); } } diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index 6a6f10e8..a71dd592 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,6 +12,7 @@ #include LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(FixFindBindingAtFunctionName, false); namespace Luau { @@ -148,6 +149,23 @@ struct FindNode : public AstVisitor return false; } + bool visit(AstStatFunction* node) override + { + if (FFlag::FixFindBindingAtFunctionName) + { + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; + } + else + { + return AstVisitor::visit(node); + } + } + bool visit(AstStatBlock* block) override { visit(static_cast(block)); @@ -188,6 +206,23 @@ struct FindFullAncestry final : public AstVisitor return false; } + bool visit(AstStatFunction* node) override + { + if (FFlag::FixFindBindingAtFunctionName) + { + visit(static_cast(node)); + if (node->name->location.contains(pos)) + node->name->visit(this); + else if (node->func->location.contains(pos)) + node->func->visit(this); + return false; + } + else + { + return AstVisitor::visit(node); + } + } + bool visit(AstNode* node) override { if (node->location.contains(pos)) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index eaf47b77..3eba2e12 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,10 +13,8 @@ #include LUAU_FASTFLAG(DebugLuauReadWriteProperties) -LUAU_FASTFLAGVARIABLE(LuauDisableCompletionOutsideQuotes, false) -LUAU_FASTFLAGVARIABLE(LuauAnonymousAutofilled1, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteLastTypecheck, false) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteHideSelfArg, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteDoEnd, false) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteStringLiteralBounds, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -284,38 +282,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, NotNul ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); - if (FFlag::LuauAutocompleteHideSelfArg) - { - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens, - {}, - indexType == PropIndexType::Colon - }; - } - else - { - result[name] = AutocompleteEntry{ - AutocompleteEntryKind::Property, - type, - prop.deprecated, - isWrongIndexer(type), - typeCorrect, - containingClass, - &prop, - prop.documentationSymbol, - {}, - parens - }; - } + result[name] = AutocompleteEntry{AutocompleteEntryKind::Property, type, prop.deprecated, isWrongIndexer(type), typeCorrect, + containingClass, &prop, prop.documentationSymbol, {}, parens, {}, indexType == PropIndexType::Colon}; } } }; @@ -485,8 +453,19 @@ AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position posi return result; } -static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AutocompleteEntryMap& result) +static void autocompleteStringSingleton(TypeId ty, bool addQuotes, AstNode* node, Position position, AutocompleteEntryMap& result) { + if (FFlag::LuauAutocompleteStringLiteralBounds) + { + if (position == node->location.begin || position == node->location.end) + { + if (auto str = node->as(); str && str->quoteStyle == AstExprConstantString::Quoted) + return; + else if (node->is()) + return; + } + } + auto formatKey = [addQuotes](const std::string& key) { if (addQuotes) return "\"" + escape(key) + "\""; @@ -615,10 +594,9 @@ std::optional getLocalTypeInScopeAt(const Module& module, Position posit return {}; } -template +template static std::optional tryToStringDetailed(const ScopePtr& scope, T ty, bool functionTypeArguments) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); ToStringOptions opts; opts.useLineBreaks = false; opts.hideTableKind = true; @@ -637,23 +615,7 @@ static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty, bool if (!canSuggestInferredType(scope, ty)) return std::nullopt; - if (FFlag::LuauAnonymousAutofilled1) - { - return tryToStringDetailed(scope, ty, functionTypeArguments); - } - else - { - ToStringOptions opts; - opts.useLineBreaks = false; - opts.hideTableKind = true; - opts.scope = scope; - ToStringResult name = toStringDetailed(ty, opts); - - if (name.error || name.invalid || name.cycle || name.truncated) - return std::nullopt; - - return name.name; - } + return tryToStringDetailed(scope, ty, functionTypeArguments); } static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) @@ -1097,14 +1059,19 @@ static AutocompleteEntryMap autocompleteStatement( { if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->hasEnd) + else if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->hasEnd) + else if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->hasEnd) + else if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); - if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->hasEnd) + else if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->hasEnd) result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (FFlag::LuauAutocompleteDoEnd) + { + if (AstStatBlock* exprBlock = (*it)->as(); exprBlock && !exprBlock->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } } if (ancestry.size() >= 2) @@ -1239,7 +1206,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; if (auto ty = findExpectedTypeAt(module, node, position)) - autocompleteStringSingleton(*ty, true, result); + autocompleteStringSingleton(*ty, true, node, position, result); } return AutocompleteContext::Expression; @@ -1345,7 +1312,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (FFlag::LuauDisableCompletionOutsideQuotes && !nodes.back()->is()) + if (!nodes.back()->is()) { if (nodes.back()->location.end == position || nodes.back()->location.begin == position) { @@ -1419,7 +1386,6 @@ static AutocompleteResult autocompleteWhileLoopKeywords(std::vector an static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& funcTy) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); std::string result = "function("; auto [args, tail] = Luau::flatten(funcTy.argTypes); @@ -1440,7 +1406,7 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func name = "a" + std::to_string(argIdx); if (std::optional type = tryGetTypeNameInScope(scope, args[argIdx], true)) - result += name + ": " + *type; + result += name + ": " + *type; else result += name; } @@ -1456,7 +1422,7 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func if (std::optional res = tryToStringDetailed(scope, pack->ty, true)) varArgType = std::move(res); } - + if (varArgType) result += "...: " + *varArgType; else @@ -1483,9 +1449,9 @@ static std::string makeAnonymous(const ScopePtr& scope, const FunctionType& func return result; } -static std::optional makeAnonymousAutofilled(const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) +static std::optional makeAnonymousAutofilled( + const ModulePtr& module, Position position, const AstNode* node, const std::vector& ancestry) { - LUAU_ASSERT(FFlag::LuauAnonymousAutofilled1); const AstExprCall* call = node->as(); if (!call && ancestry.size() > 1) call = ancestry[ancestry.size() - 2]->as(); @@ -1521,10 +1487,10 @@ static std::optional makeAnonymousAutofilled(const ModulePtr& auto [args, tail] = flatten(outerFunction->argTypes); if (argument < args.size()) argType = args[argument]; - + if (!argType) return std::nullopt; - + TypeId followed = follow(*argType); const FunctionType* type = get(followed); if (!type) @@ -1720,7 +1686,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M auto result = autocompleteProps(*module, typeArena, builtinTypes, *it, PropIndexType::Key, ancestry); if (auto nodeIt = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*nodeIt, !node->is(), result); + autocompleteStringSingleton(*nodeIt, !node->is(), node, position, result); if (!key) { @@ -1732,7 +1698,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // suggest those too. if (auto ttv = get(follow(*it)); ttv && ttv->indexer) { - autocompleteStringSingleton(ttv->indexer->indexType, false, result); + autocompleteStringSingleton(ttv->indexer->indexType, false, node, position, result); } } @@ -1769,7 +1735,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteEntryMap result; if (auto it = module->astExpectedTypes.find(node->asExpr())) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); if (ancestry.size() >= 2) { @@ -1783,7 +1749,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (binExpr->op == AstExprBinary::CompareEq || binExpr->op == AstExprBinary::CompareNe) { if (auto it = module->astTypes.find(node == binExpr->left ? binExpr->right : binExpr->left)) - autocompleteStringSingleton(*it, false, result); + autocompleteStringSingleton(*it, false, node, position, result); } } } @@ -1803,17 +1769,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (node->asExpr()) { - if (FFlag::LuauAnonymousAutofilled1) - { - AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) - ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); - return ret; - } - else - { - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - } + AutocompleteResult ret = autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (std::optional generated = makeAnonymousAutofilled(module, position, node, ancestry)) + ret.entryMap[kGeneratedAnonymousFunctionEntryName] = std::move(*generated); + return ret; } else if (node->asStat()) return {autocompleteStatement(sourceModule, *module, ancestry, position), ancestry, AutocompleteContext::Statement}; @@ -1823,15 +1782,6 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) { - if (!FFlag::LuauAutocompleteLastTypecheck) - { - // FIXME: We can improve performance here by parsing without checking. - // The old type graph is probably fine. (famous last words!) - FrontendOptions opts; - opts.forAutocomplete = true; - frontend.check(moduleName, opts); - } - const SourceModule* sourceModule = frontend.getSourceModule(moduleName); if (!sourceModule) return {}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index c55a88eb..0200ee3e 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -201,18 +201,6 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& } } -void registerBuiltinTypes(GlobalTypes& globals) -{ - globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType}); - globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType}); - globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType}); - globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType}); - globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType}); - globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType}); - globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType}); - globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); -} - void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeCheckForAutocomplete) { LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); @@ -310,6 +298,520 @@ void registerBuiltinGlobals(Frontend& frontend, GlobalTypes& globals, bool typeC attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } +static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) +{ + const char* options = "cdiouxXeEfgGqs*"; + + std::vector result; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i < size && data[i] == '%') + continue; + + // we just ignore all characters (including flags/precision) up until first alphabetic character + while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) + i++; + + if (i == size) + break; + + if (data[i] == 'q' || data[i] == 's') + result.push_back(builtinTypes->stringType); + else if (data[i] == '*') + result.push_back(builtinTypes->unknownType); + else if (strchr(options, data[i])) + result.push_back(builtinTypes->numberType); + else + result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); + } + } + + return result; +} + +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* fmt = nullptr; + if (auto index = expr.func->as(); index && expr.self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!expr.self && expr.args.size > 0) + fmt = expr.args.data[0]->as(); + + if (!fmt) + return std::nullopt; + + std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(paramPack); + + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; + + typechecker.unify(params[i + paramOffset], expected[i], scope, location); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + return WithPredicate{arena.addTypePack({typechecker.stringType})}; +} + +static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) +{ + TypeArena* arena = context.solver->arena; + + AstExprConstantString* fmt = nullptr; + if (auto index = context.callSite->func->as(); index && context.callSite->self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!context.callSite->self && context.callSite->args.size > 0) + fmt = context.callSite->args.data[0]->as(); + + if (!fmt) + return false; + + std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(context.arguments); + + size_t paramOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) + { + context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string + + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); + + TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); + asMutable(context.result)->ty.emplace(resultPack); + + return true; +} + +static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) +{ + std::vector result; + int depth = 0; + bool parsingSet = false; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + ++i; + if (!parsingSet && i < size && data[i] == 'b') + i += 2; + } + else if (!parsingSet && data[i] == '[') + { + parsingSet = true; + if (i + 1 < size && data[i + 1] == ']') + i += 1; + } + else if (parsingSet && data[i] == ']') + { + parsingSet = false; + } + else if (data[i] == '(') + { + if (parsingSet) + continue; + + if (i + 1 < size && data[i + 1] == ')') + { + i++; + result.push_back(builtinTypes->optionalNumberType); + continue; + } + + ++depth; + result.push_back(builtinTypes->optionalStringType); + } + else if (data[i] == ')') + { + if (parsingSet) + continue; + + --depth; + + if (depth < 0) + break; + } + } + + if (depth != 0 || parsingSet) + return std::vector(); + + if (result.empty()) + result.push_back(builtinTypes->optionalStringType); + + return result; +} + +static std::optional> magicFunctionGmatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() != 2) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t index = expr.self ? 0 : 1; + if (expr.args.size > index) + pattern = expr.args.data[index]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypePackId emptyPack = arena.addTypePack({}); + const TypePackId returnList = arena.addTypePack(returnTypes); + const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); + return WithPredicate{arena.addTypePack({iteratorType})}; +} + +static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() != 2) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t index = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > index) + pattern = context.callSite->args.data[index]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId returnList = arena->addTypePack(returnTypes); + const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); + const TypePackId resTypePack = arena->addTypePack({iteratorType}); + asMutable(context.result)->ty.emplace(resTypePack); + + return true; +} + +static std::optional> magicFunctionMatch( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 3) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() == 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 3) + return false; + + TypeArena* arena = context.solver->arena; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() == 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + + return true; +} + +static std::optional> magicFunctionFind( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) +{ + auto [paramPack, _predicates] = withPredicate; + const auto& [params, tail] = flatten(paramPack); + + if (params.size() < 2 || params.size() > 4) + return std::nullopt; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = expr.self ? 0 : 1; + if (expr.args.size > patternIndex) + pattern = expr.args.data[patternIndex]->as(); + + if (!pattern) + return std::nullopt; + + bool plain = false; + size_t plainIndex = expr.self ? 2 : 3; + if (expr.args.size > plainIndex) + { + AstExprConstantBool* p = expr.args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return std::nullopt; + } + + typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); + + const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); + const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); + + size_t initIndex = expr.self ? 1 : 2; + if (params.size() >= 3 && expr.args.size > initIndex) + typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); + + if (params.size() == 4 && expr.args.size > plainIndex) + typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena.addTypePack(returnTypes); + return WithPredicate{returnList}; +} + +static bool dcrMagicFunctionFind(MagicFunctionCallContext context) +{ + const auto& [params, tail] = flatten(context.arguments); + + if (params.size() < 2 || params.size() > 4) + return false; + + TypeArena* arena = context.solver->arena; + NotNull builtinTypes = context.solver->builtinTypes; + + AstExprConstantString* pattern = nullptr; + size_t patternIndex = context.callSite->self ? 0 : 1; + if (context.callSite->args.size > patternIndex) + pattern = context.callSite->args.data[patternIndex]->as(); + + if (!pattern) + return false; + + bool plain = false; + size_t plainIndex = context.callSite->self ? 2 : 3; + if (context.callSite->args.size > plainIndex) + { + AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); + plain = p && p->value; + } + + std::vector returnTypes; + if (!plain) + { + returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); + + if (returnTypes.empty()) + return false; + } + + context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType); + + const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); + + size_t initIndex = context.callSite->self ? 1 : 2; + if (params.size() >= 3 && context.callSite->args.size > initIndex) + context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); + + if (params.size() == 4 && context.callSite->args.size > plainIndex) + context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean); + + returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); + + const TypePackId returnList = arena->addTypePack(returnTypes); + asMutable(context.result)->ty.emplace(returnList); + return true; +} + +TypeId makeStringMetatable(NotNull builtinTypes) +{ + NotNull arena{builtinTypes->arena.get()}; + + const TypeId nilType = builtinTypes->nilType; + const TypeId numberType = builtinTypes->numberType; + const TypeId booleanType = builtinTypes->booleanType; + const TypeId stringType = builtinTypes->stringType; + const TypeId anyType = builtinTypes->anyType; + + const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); + const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); + const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); + + const TypePackId oneStringPack = arena->addTypePack({stringType}); + const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); + + FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + const TypeId formatFn = arena->addType(formatFTV); + attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); + const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); + + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); + + const TypeId replArgType = + arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); + const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); + attachMagicFunction(gmatchFunc, magicFunctionGmatch); + attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); + + const TypeId matchFunc = arena->addType( + FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); + attachMagicFunction(matchFunc, magicFunctionMatch); + attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); + + const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), + arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); + attachMagicFunction(findFunc, magicFunctionFind); + attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); + + TableType::Props stringLib = { + {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, + {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, + {"find", {findFunc}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"lower", {stringToStringType}}, + {"match", {matchFunc}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, + {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, + {"pack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType}, anyTypePack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"unpack", {arena->addType(FunctionType{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + anyTypePack, + })}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + if (TableType* ttv = getMutable(tableType)) + ttv->name = "typeof(string)"; + + return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); +} + static std::optional> magicFunctionSelect( TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) { diff --git a/Analysis/src/Clone.cpp b/Analysis/src/Clone.cpp index 9080b1fc..c28e1678 100644 --- a/Analysis/src/Clone.cpp +++ b/Analysis/src/Clone.cpp @@ -1,8 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" +#include "Luau/NotNull.h" #include "Luau/RecursionCounter.h" #include "Luau/TxnLog.h" +#include "Luau/Type.h" #include "Luau/TypePack.h" #include "Luau/Unifiable.h" @@ -13,12 +15,416 @@ LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(LuauCloneCyclicUnions, false) +LUAU_FASTFLAGVARIABLE(LuauStacklessTypeClone2, false) +LUAU_FASTINTVARIABLE(LuauTypeCloneIterationLimit, 100'000) + namespace Luau { namespace { +using Kind = Variant; + +template +const T* get(const Kind& kind) +{ + return get_if(&kind); +} + +class TypeCloner2 +{ + NotNull arena; + NotNull builtinTypes; + + // A queue of kinds where we cloned it, but whose interior types hasn't + // been updated to point to new clones. Once all of its interior types + // has been updated, it gets removed from the queue. + std::vector queue; + + NotNull types; + NotNull packs; + + int steps = 0; + +public: + TypeCloner2(NotNull arena, NotNull builtinTypes, NotNull types, NotNull packs) + : arena(arena) + , builtinTypes(builtinTypes) + , types(types) + , packs(packs) + { + } + + TypeId clone(TypeId ty) + { + shallowClone(ty); + run(); + + if (hasExceededIterationLimit()) + { + TypeId error = builtinTypes->errorRecoveryType(); + (*types)[ty] = error; + return error; + } + + return find(ty).value_or(builtinTypes->errorRecoveryType()); + } + + TypePackId clone(TypePackId tp) + { + shallowClone(tp); + run(); + + if (hasExceededIterationLimit()) + { + TypePackId error = builtinTypes->errorRecoveryTypePack(); + (*packs)[tp] = error; + return error; + } + + return find(tp).value_or(builtinTypes->errorRecoveryTypePack()); + } + +private: + bool hasExceededIterationLimit() const + { + if (FInt::LuauTypeCloneIterationLimit == 0) + return false; + + return steps + queue.size() >= size_t(FInt::LuauTypeCloneIterationLimit); + } + + void run() + { + while (!queue.empty()) + { + ++steps; + + if (hasExceededIterationLimit()) + break; + + Kind kind = queue.back(); + queue.pop_back(); + + if (find(kind)) + continue; + + cloneChildren(kind); + } + } + + std::optional find(TypeId ty) const + { + ty = follow(ty, FollowOption::DisableLazyTypeThunks); + if (auto it = types->find(ty); it != types->end()) + return it->second; + return std::nullopt; + } + + std::optional find(TypePackId tp) const + { + tp = follow(tp); + if (auto it = packs->find(tp); it != packs->end()) + return it->second; + return std::nullopt; + } + + std::optional find(Kind kind) const + { + if (auto ty = get(kind)) + return find(*ty); + else if (auto tp = get(kind)) + return find(*tp); + else + { + LUAU_ASSERT(!"Unknown kind?"); + return std::nullopt; + } + } + +private: + TypeId shallowClone(TypeId ty) + { + // We want to [`Luau::follow`] but without forcing the expansion of [`LazyType`]s. + ty = follow(ty, FollowOption::DisableLazyTypeThunks); + + if (auto clone = find(ty)) + return *clone; + else if (ty->persistent) + return ty; + + TypeId target = arena->addType(ty->ty); + asMutable(target)->documentationSymbol = ty->documentationSymbol; + + (*types)[ty] = target; + queue.push_back(target); + return target; + } + + TypePackId shallowClone(TypePackId tp) + { + tp = follow(tp); + + if (auto clone = find(tp)) + return *clone; + else if (tp->persistent) + return tp; + + TypePackId target = arena->addTypePack(tp->ty); + + (*packs)[tp] = target; + queue.push_back(target); + return target; + } + + Property shallowClone(const Property& p) + { + if (FFlag::DebugLuauReadWriteProperties) + { + std::optional cloneReadTy; + if (auto ty = p.readType()) + cloneReadTy = shallowClone(*ty); + + std::optional cloneWriteTy; + if (auto ty = p.writeType()) + cloneWriteTy = shallowClone(*ty); + + std::optional cloned = Property::create(cloneReadTy, cloneWriteTy); + LUAU_ASSERT(cloned); + cloned->deprecated = p.deprecated; + cloned->deprecatedSuggestion = p.deprecatedSuggestion; + cloned->location = p.location; + cloned->tags = p.tags; + cloned->documentationSymbol = p.documentationSymbol; + return *cloned; + } + else + { + return Property{ + shallowClone(p.type()), + p.deprecated, + p.deprecatedSuggestion, + p.location, + p.tags, + p.documentationSymbol, + }; + } + } + + void cloneChildren(TypeId ty) + { + return visit( + [&](auto&& t) { + return cloneChildren(&t); + }, + asMutable(ty)->ty); + } + + void cloneChildren(TypePackId tp) + { + return visit( + [&](auto&& t) { + return cloneChildren(&t); + }, + asMutable(tp)->ty); + } + + void cloneChildren(Kind kind) + { + if (auto ty = get(kind)) + return cloneChildren(*ty); + else if (auto tp = get(kind)) + return cloneChildren(*tp); + else + LUAU_ASSERT(!"Item holds neither TypeId nor TypePackId when enqueuing its children?"); + } + + // ErrorType and ErrorTypePack is an alias to this type. + void cloneChildren(Unifiable::Error* t) + { + // noop. + } + + void cloneChildren(BoundType* t) + { + t->boundTo = shallowClone(t->boundTo); + } + + void cloneChildren(FreeType* t) + { + // TODO: clone lower and upper bounds. + // TODO: In the new solver, we should ice. + } + + void cloneChildren(GenericType* t) + { + // TOOD: clone upper bounds. + } + + void cloneChildren(PrimitiveType* t) + { + // noop. + } + + void cloneChildren(BlockedType* t) + { + // TODO: In the new solver, we should ice. + } + + void cloneChildren(PendingExpansionType* t) + { + // TODO: In the new solver, we should ice. + } + + void cloneChildren(SingletonType* t) + { + // noop. + } + + void cloneChildren(FunctionType* t) + { + for (TypeId& g : t->generics) + g = shallowClone(g); + + for (TypePackId& gp : t->genericPacks) + gp = shallowClone(gp); + + t->argTypes = shallowClone(t->argTypes); + t->retTypes = shallowClone(t->retTypes); + } + + void cloneChildren(TableType* t) + { + if (t->indexer) + { + t->indexer->indexType = shallowClone(t->indexer->indexType); + t->indexer->indexResultType = shallowClone(t->indexer->indexResultType); + } + + for (auto& [_, p] : t->props) + p = shallowClone(p); + + for (TypeId& ty : t->instantiatedTypeParams) + ty = shallowClone(ty); + + for (TypePackId& tp : t->instantiatedTypePackParams) + tp = shallowClone(tp); + } + + void cloneChildren(MetatableType* t) + { + t->table = shallowClone(t->table); + t->metatable = shallowClone(t->metatable); + } + + void cloneChildren(ClassType* t) + { + for (auto& [_, p] : t->props) + p = shallowClone(p); + + if (t->parent) + t->parent = shallowClone(*t->parent); + + if (t->metatable) + t->metatable = shallowClone(*t->metatable); + + if (t->indexer) + { + t->indexer->indexType = shallowClone(t->indexer->indexType); + t->indexer->indexResultType = shallowClone(t->indexer->indexResultType); + } + } + + void cloneChildren(AnyType* t) + { + // noop. + } + + void cloneChildren(UnionType* t) + { + for (TypeId& ty : t->options) + ty = shallowClone(ty); + } + + void cloneChildren(IntersectionType* t) + { + for (TypeId& ty : t->parts) + ty = shallowClone(ty); + } + + void cloneChildren(LazyType* t) + { + if (auto unwrapped = t->unwrapped.load()) + t->unwrapped.store(shallowClone(unwrapped)); + } + + void cloneChildren(UnknownType* t) + { + // noop. + } + + void cloneChildren(NeverType* t) + { + // noop. + } + + void cloneChildren(NegationType* t) + { + t->ty = shallowClone(t->ty); + } + + void cloneChildren(TypeFamilyInstanceType* t) + { + // TODO: In the new solver, we should ice. + } + + void cloneChildren(FreeTypePack* t) + { + // TODO: clone lower and upper bounds. + // TODO: In the new solver, we should ice. + } + + void cloneChildren(GenericTypePack* t) + { + // TOOD: clone upper bounds. + } + + void cloneChildren(BlockedTypePack* t) + { + // TODO: In the new solver, we should ice. + } + + void cloneChildren(BoundTypePack* t) + { + t->boundTo = shallowClone(t->boundTo); + } + + void cloneChildren(VariadicTypePack* t) + { + t->ty = shallowClone(t->ty); + } + + void cloneChildren(TypePack* t) + { + for (TypeId& ty : t->head) + ty = shallowClone(ty); + + if (t->tail) + t->tail = shallowClone(*t->tail); + } + + void cloneChildren(TypeFamilyInstanceTypePack* t) + { + // TODO: In the new solver, we should ice. + } +}; + +} // namespace + +namespace +{ + Property clone(const Property& prop, TypeArena& dest, CloneState& cloneState) { if (FFlag::DebugLuauReadWriteProperties) @@ -374,7 +780,7 @@ void TypeCloner::operator()(const UnionType& t) // We're just using this FreeType as a placeholder until we've finished // cloning the parts of this union so it is okay that its bounds are // nullptr. We'll never indirect them. - TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/nullptr, /*upperBound*/nullptr}); + TypeId result = dest.addType(FreeType{nullptr, /*lowerBound*/ nullptr, /*upperBound*/ nullptr}); seenTypes[typeId] = result; std::vector options; @@ -470,17 +876,25 @@ TypePackId clone(TypePackId tp, TypeArena& dest, CloneState& cloneState) if (tp->persistent) return tp; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypePackId& res = cloneState.seenTypePacks[tp]; - - if (res == nullptr) + if (FFlag::LuauStacklessTypeClone2) { - TypePackCloner cloner{dest, tp, cloneState}; - Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(tp); } + else + { + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - return res; + TypePackId& res = cloneState.seenTypePacks[tp]; + + if (res == nullptr) + { + TypePackCloner cloner{dest, tp, cloneState}; + Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + } + + return res; + } } TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) @@ -488,54 +902,91 @@ TypeId clone(TypeId typeId, TypeArena& dest, CloneState& cloneState) if (typeId->persistent) return typeId; - RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - - TypeId& res = cloneState.seenTypes[typeId]; - - if (res == nullptr) + if (FFlag::LuauStacklessTypeClone2) { - TypeCloner cloner{dest, typeId, cloneState}; - Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - { - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } + TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; + return cloner.clone(typeId); } + else + { + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); - return res; + TypeId& res = cloneState.seenTypes[typeId]; + + if (res == nullptr) + { + TypeCloner cloner{dest, typeId, cloneState}; + Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + { + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + } + + return res; + } } TypeFun clone(const TypeFun& typeFun, TypeArena& dest, CloneState& cloneState) { - TypeFun result; - - for (auto param : typeFun.typeParams) + if (FFlag::LuauStacklessTypeClone2) { - TypeId ty = clone(param.ty, dest, cloneState); - std::optional defaultValue; + TypeCloner2 cloner{NotNull{&dest}, cloneState.builtinTypes, NotNull{&cloneState.seenTypes}, NotNull{&cloneState.seenTypePacks}}; - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); + TypeFun copy = typeFun; - result.typeParams.push_back({ty, defaultValue}); + for (auto& param : copy.typeParams) + { + param.ty = cloner.clone(param.ty); + + if (param.defaultValue) + param.defaultValue = cloner.clone(*param.defaultValue); + } + + for (auto& param : copy.typePackParams) + { + param.tp = cloner.clone(param.tp); + + if (param.defaultValue) + param.defaultValue = cloner.clone(*param.defaultValue); + } + + copy.type = cloner.clone(copy.type); + + return copy; } - - for (auto param : typeFun.typePackParams) + else { - TypePackId tp = clone(param.tp, dest, cloneState); - std::optional defaultValue; + TypeFun result; - if (param.defaultValue) - defaultValue = clone(*param.defaultValue, dest, cloneState); + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, cloneState); + std::optional defaultValue; - result.typePackParams.push_back({tp, defaultValue}); + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, cloneState); + std::optional defaultValue; + + if (param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } + + result.type = clone(typeFun.type, dest, cloneState); + + return result; } - - result.type = clone(typeFun.type, dest, cloneState); - - return result; } } // namespace Luau diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 10cd05fa..3b15ec62 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -25,6 +25,7 @@ LUAU_FASTFLAG(DebugLuauLogSolverToJson); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauParseDeclareClassIndexer); LUAU_FASTFLAG(LuauLoopControlFlowAnalysis); +LUAU_FASTFLAG(LuauFloorDivision); namespace Luau { @@ -420,17 +421,17 @@ void ConstraintGraphBuilder::applyRefinements(const ScopePtr& scope, Location lo { switch (shouldSuppressErrors(normalizer, ty)) { - case ErrorSuppression::DoNotSuppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - break; - case ErrorSuppression::Suppress: - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; - break; - case ErrorSuppression::NormalizationFailed: - reportError(location, NormalizationTooComplex{}); - ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; - break; + case ErrorSuppression::DoNotSuppress: + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + break; + case ErrorSuppression::Suppress: + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + ty = simplifyUnion(builtinTypes, arena, ty, builtinTypes->errorType).result; + break; + case ErrorSuppression::NormalizationFailed: + reportError(location, NormalizationTooComplex{}); + ty = simplifyIntersection(builtinTypes, arena, ty, dt).result; + break; } } } @@ -1188,7 +1189,8 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" || + (FFlag::LuauFloorDivision && name == "__idiv"); } ControlFlow ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareClass* declaredClass) diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index e1291eeb..6faea1d2 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -22,6 +22,7 @@ #include "Luau/VisitType.h" LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); +LUAU_FASTFLAG(LuauFloorDivision); namespace Luau { @@ -429,6 +430,35 @@ bool ConstraintSolver::isDone() return unsolvedConstraints.empty(); } +namespace +{ + +struct TypeAndLocation +{ + TypeId typeId; + Location location; +}; + +struct FreeTypeSearcher : TypeOnceVisitor +{ + std::deque* result; + Location location; + + FreeTypeSearcher(std::deque* result, Location location) + : result(result) + , location(location) + { + } + + bool visit(TypeId ty, const FreeType&) override + { + result->push_back({ty, location}); + return false; + } +}; + +} // namespace + void ConstraintSolver::finalizeModule() { Anyification a{arena, rootScope, builtinTypes, &iceReporter, builtinTypes->anyType, builtinTypes->anyTypePack}; @@ -445,12 +475,28 @@ void ConstraintSolver::finalizeModule() Unifier2 u2{NotNull{arena}, builtinTypes, NotNull{&iceReporter}}; + std::deque queue; for (auto& [name, binding] : rootScope->bindings) + queue.push_back({binding.typeId, binding.location}); + + DenseHashSet seen{nullptr}; + + while (!queue.empty()) { - auto generalizedTy = u2.generalize(rootScope, binding.typeId); - if (generalizedTy) - binding.typeId = *generalizedTy; - else + TypeAndLocation binding = queue.front(); + queue.pop_front(); + + TypeId ty = follow(binding.typeId); + + if (seen.find(ty)) + continue; + seen.insert(ty); + + FreeTypeSearcher fts{&queue, binding.location}; + fts.traverse(ty); + + auto result = u2.generalize(rootScope, ty); + if (!result) reportError(CodeTooComplex{}, binding.location); } } @@ -719,6 +765,8 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNullnormalize(leftType); if (hasTypeInIntersection(leftType) && force) asMutable(leftType)->ty.emplace(anyPresent ? builtinTypes->anyType : builtinTypes->numberType); @@ -2636,20 +2687,14 @@ ErrorVec ConstraintSolver::unify(NotNull scope, Location location, TypeId ErrorVec ConstraintSolver::unify(NotNull scope, Location location, TypePackId subPack, TypePackId superPack) { - UnifierSharedState sharedState{&iceReporter}; - Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableNewSolver(); + Unifier2 u{arena, builtinTypes, NotNull{&iceReporter}}; - u.tryUnify(subPack, superPack); + u.unify(subPack, superPack); - const auto [changedTypes, changedPacks] = u.log.getChanges(); + unblock(subPack, Location{}); + unblock(superPack, Location{}); - u.log.commit(); - - unblock(changedTypes, Location{}); - unblock(changedPacks, Location{}); - - return std::move(u.errors); + return {}; } NotNull ConstraintSolver::pushConstraint(NotNull scope, const Location& location, ConstraintV cv) @@ -2727,41 +2772,6 @@ TypePackId ConstraintSolver::errorRecoveryTypePack() const return builtinTypes->errorRecoveryTypePack(); } -TypeId ConstraintSolver::unionOfTypes(TypeId a, TypeId b, NotNull scope, bool unifyFreeTypes) -{ - a = follow(a); - b = follow(b); - - if (unifyFreeTypes && (get(a) || get(b))) - { - Unifier u{normalizer, scope, Location{}, Covariant}; - u.enableNewSolver(); - u.tryUnify(b, a); - - if (u.errors.empty()) - { - u.log.commit(); - return a; - } - else - { - return builtinTypes->errorRecoveryType(builtinTypes->anyType); - } - } - - if (*a == *b) - return a; - - std::vector types = reduceUnion({a, b}); - if (types.empty()) - return builtinTypes->neverType; - - if (types.size() == 1) - return types[0]; - - return arena->addType(UnionType{types}); -} - TypePackId ConstraintSolver::anyifyModuleReturnTypePackGenerics(TypePackId tp) { tp = follow(tp); diff --git a/Analysis/src/Differ.cpp b/Analysis/src/Differ.cpp index 20f059d3..5b614cef 100644 --- a/Analysis/src/Differ.cpp +++ b/Analysis/src/Differ.cpp @@ -107,15 +107,17 @@ std::string DiffPath::toString(bool prependDot) const } return pathStr; } -std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf) const +std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLeaf& leaf, const DiffPathNodeLeaf& otherLeaf, bool multiLine) const { + std::string conditionalNewline = multiLine ? "\n" : " "; + std::string conditionalIndent = multiLine ? " " : ""; std::string pathStr{rootName + diffPath.toString(true)}; switch (kind) { case DiffError::Kind::Normal: { checkNonMissingPropertyLeavesHaveNulloptTableProperty(); - return pathStr + " has type " + Luau::toString(*leaf.ty); + return pathStr + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); } case DiffError::Kind::MissingTableProperty: { @@ -123,13 +125,14 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.tableProperty.has_value()) throw InternalCompilerError{"leaf.tableProperty is nullopt"}; - return pathStr + "." + *leaf.tableProperty + " has type " + Luau::toString(*leaf.ty); + return pathStr + "." + *leaf.tableProperty + conditionalNewline + "has type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); } else if (otherLeaf.ty.has_value()) { if (!otherLeaf.tableProperty.has_value()) throw InternalCompilerError{"otherLeaf.tableProperty is nullopt"}; - return pathStr + " is missing the property " + *otherLeaf.tableProperty; + return pathStr + conditionalNewline + "is missing the property" + conditionalNewline + conditionalIndent + *otherLeaf.tableProperty; } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } @@ -140,11 +143,11 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.unionIndex.has_value()) throw InternalCompilerError{"leaf.unionIndex is nullopt"}; - return pathStr + " is a union containing type " + Luau::toString(*leaf.ty); + return pathStr + conditionalNewline + "is a union containing type" + conditionalNewline + conditionalIndent + Luau::toString(*leaf.ty); } else if (otherLeaf.ty.has_value()) { - return pathStr + " is a union missing type " + Luau::toString(*otherLeaf.ty); + return pathStr + conditionalNewline + "is a union missing type" + conditionalNewline + conditionalIndent + Luau::toString(*otherLeaf.ty); } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } @@ -157,11 +160,13 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.unionIndex.has_value()) throw InternalCompilerError{"leaf.unionIndex is nullopt"}; - return pathStr + " is an intersection containing type " + Luau::toString(*leaf.ty); + return pathStr + conditionalNewline + "is an intersection containing type" + conditionalNewline + conditionalIndent + + Luau::toString(*leaf.ty); } else if (otherLeaf.ty.has_value()) { - return pathStr + " is an intersection missing type " + Luau::toString(*otherLeaf.ty); + return pathStr + conditionalNewline + "is an intersection missing type" + conditionalNewline + conditionalIndent + + Luau::toString(*otherLeaf.ty); } throw InternalCompilerError{"Both leaf.ty and otherLeaf.ty is nullopt"}; } @@ -169,13 +174,13 @@ std::string DiffError::toStringALeaf(std::string rootName, const DiffPathNodeLea { if (!leaf.minLength.has_value()) throw InternalCompilerError{"leaf.minLength is nullopt"}; - return pathStr + " takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; + return pathStr + conditionalNewline + "takes " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " arguments"; } case DiffError::Kind::LengthMismatchInFnRets: { if (!leaf.minLength.has_value()) throw InternalCompilerError{"leaf.minLength is nullopt"}; - return pathStr + " returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; + return pathStr + conditionalNewline + "returns " + std::to_string(*leaf.minLength) + (leaf.isVariadic ? " or more" : "") + " values"; } default: { @@ -190,8 +195,11 @@ void DiffError::checkNonMissingPropertyLeavesHaveNulloptTableProperty() const throw InternalCompilerError{"Non-MissingProperty DiffError should have nullopt tableProperty in both leaves"}; } -std::string getDevFixFriendlyName(TypeId ty) +std::string getDevFixFriendlyName(const std::optional& maybeSymbol, TypeId ty) { + if (maybeSymbol.has_value()) + return *maybeSymbol; + if (auto table = get(ty)) { if (table->name.has_value()) @@ -206,27 +214,37 @@ std::string getDevFixFriendlyName(TypeId ty) return *metatable->syntheticName; } } - // else if (auto primitive = get(ty)) - //{ - // return ""; - //} return ""; } -std::string DiffError::toString() const +std::string DifferEnvironment::getDevFixFriendlyNameLeft() const { + return getDevFixFriendlyName(externalSymbolLeft, rootLeft); +} + +std::string DifferEnvironment::getDevFixFriendlyNameRight() const +{ + return getDevFixFriendlyName(externalSymbolRight, rootRight); +} + +std::string DiffError::toString(bool multiLine) const +{ + std::string conditionalNewline = multiLine ? "\n" : " "; + std::string conditionalIndent = multiLine ? " " : ""; switch (kind) { case DiffError::Kind::IncompatibleGeneric: { std::string diffPathStr{diffPath.toString(true)}; - return "DiffError: these two types are not equal because the left generic at " + leftRootName + diffPathStr + - " cannot be the same type parameter as the right generic at " + rightRootName + diffPathStr; + return "DiffError: these two types are not equal because the left generic at" + conditionalNewline + conditionalIndent + leftRootName + + diffPathStr + conditionalNewline + "cannot be the same type parameter as the right generic at" + conditionalNewline + + conditionalIndent + rightRootName + diffPathStr; } default: { - return "DiffError: these two types are not equal because the left type at " + toStringALeaf(leftRootName, left, right) + - ", while the right type at " + toStringALeaf(rightRootName, right, left); + return "DiffError: these two types are not equal because the left type at" + conditionalNewline + conditionalIndent + + toStringALeaf(leftRootName, left, right, multiLine) + "," + conditionalNewline + "while the right type at" + conditionalNewline + + conditionalIndent + toStringALeaf(rightRootName, right, left, multiLine); } } } @@ -296,8 +314,8 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::detailsTableProperty(value.type(), field), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } } @@ -307,8 +325,7 @@ static DifferResult diffTable(DifferEnvironment& env, TypeId left, TypeId right) { // right has a field the left doesn't return DifferResult{DiffError{DiffError::Kind::MissingTableProperty, DiffPathNodeLeaf::nullopts(), - DiffPathNodeLeaf::detailsTableProperty(value.type(), field), getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight)}}; + DiffPathNodeLeaf::detailsTableProperty(value.type(), field), env.getDevFixFriendlyNameLeft(), env.getDevFixFriendlyNameRight()}}; } } // left and right have the same set of keys @@ -360,8 +377,8 @@ static DifferResult diffPrimitive(DifferEnvironment& env, TypeId left, TypeId ri DiffError::Kind::Normal, DiffPathNodeLeaf::detailsNormal(left), DiffPathNodeLeaf::detailsNormal(right), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } return DifferResult{}; @@ -380,8 +397,8 @@ static DifferResult diffSingleton(DifferEnvironment& env, TypeId left, TypeId ri DiffError::Kind::Normal, DiffPathNodeLeaf::detailsNormal(left), DiffPathNodeLeaf::detailsNormal(right), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } return DifferResult{}; @@ -419,8 +436,8 @@ static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId righ DiffError::Kind::IncompatibleGeneric, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -432,8 +449,8 @@ static DifferResult diffGeneric(DifferEnvironment& env, TypeId left, TypeId righ DiffError::Kind::IncompatibleGeneric, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -468,8 +485,8 @@ static DifferResult diffClass(DifferEnvironment& env, TypeId left, TypeId right) DiffError::Kind::Normal, DiffPathNodeLeaf::detailsNormal(left), DiffPathNodeLeaf::detailsNormal(right), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -521,16 +538,16 @@ static DifferResult diffUnion(DifferEnvironment& env, TypeId left, TypeId right) DiffError::Kind::MissingUnionMember, DiffPathNodeLeaf::detailsUnionIndex(leftUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; else return DifferResult{DiffError{ DiffError::Kind::MissingUnionMember, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::detailsUnionIndex(rightUnion->options[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -554,16 +571,16 @@ static DifferResult diffIntersection(DifferEnvironment& env, TypeId left, TypeId DiffError::Kind::MissingIntersectionMember, DiffPathNodeLeaf::detailsUnionIndex(leftIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; else return DifferResult{DiffError{ DiffError::Kind::MissingIntersectionMember, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::detailsUnionIndex(rightIntersection->parts[*findSeteqCexResult.mismatchIdx], *findSeteqCexResult.mismatchIdx), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -583,8 +600,8 @@ static DifferResult diffUsingEnv(DifferEnvironment& env, TypeId left, TypeId rig DiffError::Kind::Normal, DiffPathNodeLeaf::detailsNormal(left), DiffPathNodeLeaf::detailsNormal(right), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -753,8 +770,8 @@ static DifferResult diffCanonicalTpShape(DifferEnvironment& env, DiffError::Kind possibleNonNormalErrorKind, DiffPathNodeLeaf::detailsLength(int(left.first.size()), left.second.has_value()), DiffPathNodeLeaf::detailsLength(int(right.first.size()), right.second.has_value()), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -769,8 +786,8 @@ static DifferResult diffHandleFlattenedTail(DifferEnvironment& env, DiffError::K DiffError::Kind::Normal, DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->first), DiffPathNodeLeaf::detailsNormal(env.visitingBegin()->second), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -847,8 +864,8 @@ static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypeP DiffError::Kind::IncompatibleGeneric, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -860,8 +877,8 @@ static DifferResult diffGenericTp(DifferEnvironment& env, TypePackId left, TypeP DiffError::Kind::IncompatibleGeneric, DiffPathNodeLeaf::nullopts(), DiffPathNodeLeaf::nullopts(), - getDevFixFriendlyName(env.rootLeft), - getDevFixFriendlyName(env.rootRight), + env.getDevFixFriendlyNameLeft(), + env.getDevFixFriendlyNameRight(), }}; } @@ -910,7 +927,13 @@ std::vector>::const_reverse_iterator DifferEnvironment DifferResult diff(TypeId ty1, TypeId ty2) { - DifferEnvironment differEnv{ty1, ty2}; + DifferEnvironment differEnv{ty1, ty2, std::nullopt, std::nullopt}; + return diffUsingEnv(differEnv, ty1, ty2); +} + +DifferResult diffWithSymbols(TypeId ty1, TypeId ty2, std::optional symbol1, std::optional symbol2) +{ + DifferEnvironment differEnv{ty1, ty2, symbol1, symbol2}; return diffUsingEnv(differEnv, ty1, ty2); } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index dfc6ff07..65b04d62 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -148,8 +148,7 @@ declare coroutine: { resume: (co: thread, A...) -> (boolean, R...), running: () -> thread, status: (co: thread) -> "dead" | "running" | "normal" | "suspended", - -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: (f: (A...) -> R...) -> any, + wrap: (f: (A...) -> R...) -> ((A...) -> R...), yield: (A...) -> R..., isyieldable: () -> boolean, close: (co: thread) -> (boolean, any) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 1a690a50..4505f627 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -4,12 +4,15 @@ #include "Luau/Clone.h" #include "Luau/Common.h" #include "Luau/FileResolver.h" +#include "Luau/NotNull.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" +#include #include #include +LUAU_FASTINTVARIABLE(LuauIndentTypeMismatchMaxTypeLength, 10) static std::string wrongNumberOfArgsString( size_t expectedCount, std::optional maximumCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) @@ -66,6 +69,20 @@ struct ErrorConverter std::string result; + auto quote = [&](std::string s) { + return "'" + s + "'"; + }; + + auto constructErrorMessage = [&](std::string givenType, std::string wantedType, std::optional givenModule, + std::optional wantedModule) -> std::string { + std::string given = givenModule ? quote(givenType) + " from " + quote(*givenModule) : quote(givenType); + std::string wanted = wantedModule ? quote(wantedType) + " from " + quote(*wantedModule) : quote(wantedType); + size_t luauIndentTypeMismatchMaxTypeLength = size_t(FInt::LuauIndentTypeMismatchMaxTypeLength); + if (givenType.length() <= luauIndentTypeMismatchMaxTypeLength || wantedType.length() <= luauIndentTypeMismatchMaxTypeLength) + return "Type " + given + " could not be converted into " + wanted; + return "Type\n " + given + "\ncould not be converted into\n " + wanted; + }; + if (givenTypeName == wantedTypeName) { if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) @@ -76,20 +93,18 @@ struct ErrorConverter { std::string givenModuleName = fileResolver->getHumanReadableModuleName(*givenDefinitionModule); std::string wantedModuleName = fileResolver->getHumanReadableModuleName(*wantedDefinitionModule); - result = "Type '" + givenTypeName + "' from '" + givenModuleName + "' could not be converted into '" + wantedTypeName + - "' from '" + wantedModuleName + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, givenModuleName, wantedModuleName); } else { - result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + - "' from '" + *wantedDefinitionModule + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, *givenDefinitionModule, *wantedDefinitionModule); } } } } if (result.empty()) - result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + result = constructErrorMessage(givenTypeName, wantedTypeName, std::nullopt, std::nullopt); if (tm.error) @@ -97,7 +112,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + " "; + result += tm.reason + " \n"; result += Luau::toString(*tm.error, TypeErrorToStringOptions{fileResolver}); } @@ -845,7 +860,7 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, CloneState cloneState) +void copyError(T& e, TypeArena& destArena, CloneState& cloneState) { auto clone = [&](auto&& ty) { return ::Luau::clone(ty, destArena, cloneState); @@ -998,9 +1013,9 @@ void copyError(T& e, TypeArena& destArena, CloneState cloneState) static_assert(always_false_v, "Non-exhaustive type switch"); } -void copyErrors(ErrorVec& errors, TypeArena& destArena) +void copyErrors(ErrorVec& errors, TypeArena& destArena, NotNull builtinTypes) { - CloneState cloneState; + CloneState cloneState{builtinTypes}; auto visitErrorData = [&](auto&& e) { copyError(e, destArena, cloneState); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 52eedece..677458ab 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -31,11 +31,12 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTarjanChildLimit) LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) +LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) // TODO: Remove with FFlagLuauTypecheckLimitControls LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false) LUAU_FASTFLAGVARIABLE(DebugLuauReadWriteProperties, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckCancellation, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckLimitControls, false) +LUAU_FASTFLAGVARIABLE(CorrectEarlyReturnInMarkDirty, false) namespace Luau { @@ -126,7 +127,7 @@ static ParseResult parseSourceForModule(std::string_view source, Luau::SourceMod static void persistCheckedTypes(ModulePtr checkedModule, GlobalTypes& globals, ScopePtr targetScope, const std::string& packageName) { - CloneState cloneState; + CloneState cloneState{globals.builtinTypes}; std::vector typesToPersist; typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); @@ -462,7 +463,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optionalcancelled) + if (item.module->cancelled) return {}; checkResult.errors.insert(checkResult.errors.end(), item.module->errors.begin(), item.module->errors.end()); @@ -635,7 +636,7 @@ std::vector Frontend::checkQueuedModules(std::optionalcancelled) + if (item.module && item.module->cancelled) cancelled = true; if (itemWithException || cancelled) @@ -677,7 +678,7 @@ std::vector Frontend::checkQueuedModules(std::optional& items, std::vecto } } +static void applyInternalLimitScaling(SourceNode& sourceNode, const ModulePtr module, double limit) +{ + if (module->timeout) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + else if (module->checkDurationSec < limit / 2.0) + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); +} + void Frontend::checkBuildQueueItem(BuildQueueItem& item) { SourceNode& sourceNode = *item.sourceNode; @@ -884,44 +893,84 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) double timestamp = getTimestamp(); const std::vector& requireCycles = item.requireCycles; - if (item.options.forAutocomplete) + TypeCheckLimits typeCheckLimits; + + if (FFlag::LuauTypecheckLimitControls) { - double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; - - // The autocomplete typecheck is always in strict mode with DM awareness - // to provide better type information for IDE features - TypeCheckLimits typeCheckLimits; - - if (autocompleteTimeLimit != 0.0) - typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + if (item.options.moduleTimeLimitSec) + typeCheckLimits.finishTime = TimeTrace::getClock() + *item.options.moduleTimeLimitSec; else typeCheckLimits.finishTime = std::nullopt; // TODO: This is a dirty ad hoc solution for autocomplete timeouts // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle - if (FInt::LuauTarjanChildLimit > 0) - typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.instantiationChildLimit = std::nullopt; + if (item.options.applyInternalLimitScaling) + { + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; - if (FInt::LuauTypeInferIterationLimit > 0) - typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); - else - typeCheckLimits.unifierIterationLimit = std::nullopt; + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; + } + + typeCheckLimits.cancellationToken = item.options.cancellationToken; + } + + if (item.options.forAutocomplete) + { + double autocompleteTimeLimit = FInt::LuauAutocompleteCheckTimeoutMs / 1000.0; + + if (!FFlag::LuauTypecheckLimitControls) + { + // The autocomplete typecheck is always in strict mode with DM awareness + // to provide better type information for IDE features + + if (autocompleteTimeLimit != 0.0) + typeCheckLimits.finishTime = TimeTrace::getClock() + autocompleteTimeLimit; + else + typeCheckLimits.finishTime = std::nullopt; + + // TODO: This is a dirty ad hoc solution for autocomplete timeouts + // We are trying to dynamically adjust our existing limits to lower total typechecking time under the limit + // so that we'll have type information for the whole file at lower quality instead of a full abort in the middle + if (FInt::LuauTarjanChildLimit > 0) + typeCheckLimits.instantiationChildLimit = std::max(1, int(FInt::LuauTarjanChildLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.instantiationChildLimit = std::nullopt; + + if (FInt::LuauTypeInferIterationLimit > 0) + typeCheckLimits.unifierIterationLimit = std::max(1, int(FInt::LuauTypeInferIterationLimit * sourceNode.autocompleteLimitsMult)); + else + typeCheckLimits.unifierIterationLimit = std::nullopt; - if (FFlag::LuauTypecheckCancellation) typeCheckLimits.cancellationToken = item.options.cancellationToken; + } + // The autocomplete typecheck is always in strict mode with DM awareness to provide better type information for IDE features ModulePtr moduleForAutocomplete = check(sourceModule, Mode::Strict, requireCycles, environmentScope, /*forAutocomplete*/ true, /*recordJsonLog*/ false, typeCheckLimits); double duration = getTimestamp() - timestamp; - if (moduleForAutocomplete->timeout) - sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; - else if (duration < autocompleteTimeLimit / 2.0) - sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + if (FFlag::LuauTypecheckLimitControls) + { + moduleForAutocomplete->checkDurationSec = duration; + + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, moduleForAutocomplete, *item.options.moduleTimeLimitSec); + } + else + { + if (moduleForAutocomplete->timeout) + sourceNode.autocompleteLimitsMult = sourceNode.autocompleteLimitsMult / 2.0; + else if (duration < autocompleteTimeLimit / 2.0) + sourceNode.autocompleteLimitsMult = std::min(sourceNode.autocompleteLimitsMult * 2.0, 1.0); + } item.stats.timeCheck += duration; item.stats.filesStrict += 1; @@ -930,14 +979,29 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) return; } - TypeCheckLimits typeCheckLimits; - - if (FFlag::LuauTypecheckCancellation) + if (!FFlag::LuauTypecheckLimitControls) + { typeCheckLimits.cancellationToken = item.options.cancellationToken; + } ModulePtr module = check(sourceModule, mode, requireCycles, environmentScope, /*forAutocomplete*/ false, item.recordJsonLog, typeCheckLimits); - item.stats.timeCheck += getTimestamp() - timestamp; + if (FFlag::LuauTypecheckLimitControls) + { + double duration = getTimestamp() - timestamp; + + module->checkDurationSec = duration; + + if (item.options.moduleTimeLimitSec && item.options.applyInternalLimitScaling) + applyInternalLimitScaling(sourceNode, module, *item.options.moduleTimeLimitSec); + + item.stats.timeCheck += duration; + } + else + { + item.stats.timeCheck += getTimestamp() - timestamp; + } + item.stats.filesStrict += mode == Mode::Strict; item.stats.filesNonstrict += mode == Mode::Nonstrict; @@ -969,7 +1033,7 @@ void Frontend::checkBuildQueueItem(BuildQueueItem& item) // copyErrors needs to allocate into interfaceTypes as it copies // types out of internalTypes, so we unfreeze it here. unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); freeze(module->interfaceTypes); module->internalTypes.clear(); @@ -1014,7 +1078,7 @@ void Frontend::checkBuildQueueItems(std::vector& items) { checkBuildQueueItem(item); - if (FFlag::LuauTypecheckCancellation && item.module && item.module->cancelled) + if (item.module && item.module->cancelled) break; recordItemResult(item); @@ -1085,8 +1149,16 @@ bool Frontend::isDirty(const ModuleName& name, bool forAutocomplete) const */ void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) { - if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) - return; + if (FFlag::CorrectEarlyReturnInMarkDirty) + { + if (sourceNodes.count(name) == 0) + return; + } + else + { + if (!moduleResolver.getModule(name) && !moduleResolverForAutocomplete.getModule(name)) + return; + } std::unordered_map> reverseDeps; for (const auto& module : sourceNodes) @@ -1217,7 +1289,8 @@ ModulePtr check(const SourceModule& sourceModule, const std::vectortimeout || result->cancelled) { - // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending types + // If solver was interrupted, skip typechecking and replace all module results with error-supressing types to avoid leaking blocked/pending + // types ScopePtr moduleScope = result->getModuleScope(); moduleScope->returnType = builtinTypes->errorRecoveryTypePack(); @@ -1295,9 +1368,7 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect typeChecker.finishTime = typeCheckLimits.finishTime; typeChecker.instantiationChildLimit = typeCheckLimits.instantiationChildLimit; typeChecker.unifierIterationLimit = typeCheckLimits.unifierIterationLimit; - - if (FFlag::LuauTypecheckCancellation) - typeChecker.cancellationToken = typeCheckLimits.cancellationToken; + typeChecker.cancellationToken = typeCheckLimits.cancellationToken; return typeChecker.check(sourceModule, mode, environmentScope); } diff --git a/Analysis/src/GlobalTypes.cpp b/Analysis/src/GlobalTypes.cpp new file mode 100644 index 00000000..654cfa5d --- /dev/null +++ b/Analysis/src/GlobalTypes.cpp @@ -0,0 +1,34 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/GlobalTypes.h" + +LUAU_FASTFLAG(LuauInitializeStringMetatableInGlobalTypes) + +namespace Luau +{ + +GlobalTypes::GlobalTypes(NotNull builtinTypes) + : builtinTypes(builtinTypes) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); + + if (FFlag::LuauInitializeStringMetatableInGlobalTypes) + { + unfreeze(*builtinTypes->arena); + TypeId stringMetatableTy = makeStringMetatable(builtinTypes); + asMutable(builtinTypes->stringType)->ty.emplace(PrimitiveType::String, stringMetatableTy); + persist(stringMetatableTy); + freeze(*builtinTypes->arena); + } +} + +} // namespace Luau diff --git a/Analysis/src/Instantiation.cpp b/Analysis/src/Instantiation.cpp index 0c8bc1ec..e74ece06 100644 --- a/Analysis/src/Instantiation.cpp +++ b/Analysis/src/Instantiation.cpp @@ -174,7 +174,8 @@ struct Replacer : Substitution } }; -std::optional instantiate(NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty) +std::optional instantiate( + NotNull builtinTypes, NotNull arena, NotNull limits, NotNull scope, TypeId ty) { ty = follow(ty); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 4abc1aa1..d4a16a75 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,48 +14,9 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) -LUAU_FASTFLAGVARIABLE(LuauLintNativeComment, false) - namespace Luau { -// clang-format off -static const char* kWarningNames[] = { - "Unknown", - - "UnknownGlobal", - "DeprecatedGlobal", - "GlobalUsedAsLocal", - "LocalShadow", - "SameLineStatement", - "MultiLineStatement", - "LocalUnused", - "FunctionUnused", - "ImportUnused", - "BuiltinGlobalWrite", - "PlaceholderRead", - "UnreachableCode", - "UnknownType", - "ForRange", - "UnbalancedAssignment", - "ImplicitReturn", - "DuplicateLocal", - "FormatString", - "TableLiteral", - "UninitializedLocal", - "DuplicateFunction", - "DeprecatedApi", - "TableOperations", - "DuplicateCondition", - "MisleadingAndOr", - "CommentDirective", - "IntegerParsing", - "ComparisonPrecedence", -}; -// clang-format on - -static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?"); - struct LintContext { struct Global @@ -2827,11 +2788,11 @@ static void lintComments(LintContext& context, const std::vector& ho "optimize directive uses unknown optimization level '%s', 0..2 expected", level); } } - else if (FFlag::LuauLintNativeComment && first == "native") + else if (first == "native") { if (space != std::string::npos) - emitWarning(context, LintWarning::Code_CommentDirective, hc.location, - "native directive has extra symbols at the end of the line"); + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "native directive has extra symbols at the end of the line"); } else { @@ -2855,12 +2816,6 @@ static void lintComments(LintContext& context, const std::vector& ho } } -void LintOptions::setDefaults() -{ - // By default, we enable all warnings - warningMask = ~0ull; -} - std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const std::vector& hotcomments, const LintOptions& options) { @@ -2952,54 +2907,6 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc return context.result; } -const char* LintWarning::getName(Code code) -{ - LUAU_ASSERT(unsigned(code) < Code__Count); - - return kWarningNames[code]; -} - -LintWarning::Code LintWarning::parseName(const char* name) -{ - for (int code = Code_Unknown; code < Code__Count; ++code) - if (strcmp(name, getName(Code(code))) == 0) - return Code(code); - - return Code_Unknown; -} - -uint64_t LintWarning::parseMask(const std::vector& hotcomments) -{ - uint64_t result = 0; - - for (const HotComment& hc : hotcomments) - { - if (!hc.header) - continue; - - if (hc.content.compare(0, 6, "nolint") != 0) - continue; - - size_t name = hc.content.find_first_not_of(" \t", 6); - - // --!nolint disables everything - if (name == std::string::npos) - return ~0ull; - - // --!nolint needs to be followed by a whitespace character - if (name == 6) - continue; - - // --!nolint name disables the specific lint - LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); - - if (code != LintWarning::Code_Unknown) - result |= 1ull << int(code); - } - - return result; -} - std::vector getDeprecatedGlobals(const AstNameTable& names) { LintContext context; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index cb2114ab..580f59f3 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -199,7 +199,7 @@ void Module::clonePublicInterface(NotNull builtinTypes, InternalEr LUAU_ASSERT(interfaceTypes.types.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - CloneState cloneState; + CloneState cloneState{builtinTypes}; ScopePtr moduleScope = getModuleScope(); diff --git a/Analysis/src/Normalize.cpp b/Analysis/src/Normalize.cpp index bcad75b0..ac58c7f6 100644 --- a/Analysis/src/Normalize.cpp +++ b/Analysis/src/Normalize.cpp @@ -176,7 +176,7 @@ const NormalizedStringType NormalizedStringType::never; bool isSubtype(const NormalizedStringType& subStr, const NormalizedStringType& superStr) { - if (subStr.isUnion() && superStr.isUnion()) + if (subStr.isUnion() && (superStr.isUnion() && !superStr.isNever())) { for (auto [name, ty] : subStr.singletons) { @@ -1983,18 +1983,68 @@ void Normalizer::intersectClassesWithClass(NormalizedClassType& heres, TypeId th void Normalizer::intersectStrings(NormalizedStringType& here, const NormalizedStringType& there) { + /* There are 9 cases to worry about here + Normalized Left | Normalized Right + C1 string | string ===> trivial + C2 string - {u_1,..} | string ===> trivial + C3 {u_1, ..} | string ===> trivial + C4 string | string - {v_1, ..} ===> string - {v_1, ..} + C5 string - {u_1,..} | string - {v_1, ..} ===> string - ({u_s} U {v_s}) + C6 {u_1, ..} | string - {v_1, ..} ===> {u_s} - {v_s} + C7 string | {v_1, ..} ===> {v_s} + C8 string - {u_1,..} | {v_1, ..} ===> {v_s} - {u_s} + C9 {u_1, ..} | {v_1, ..} ===> {u_s} ∩ {v_s} + */ + // Case 1,2,3 if (there.isString()) return; - if (here.isString()) - here.resetToNever(); - - for (auto it = here.singletons.begin(); it != here.singletons.end();) + // Case 4, Case 7 + else if (here.isString()) { - if (there.singletons.count(it->first)) - it++; - else - it = here.singletons.erase(it); + here.singletons.clear(); + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + here.isCofinite = here.isCofinite && there.isCofinite; } + // Case 5 + else if (here.isIntersection() && there.isIntersection()) + { + here.isCofinite = true; + for (const auto& [key, type] : there.singletons) + here.singletons[key] = type; + } + // Case 6 + else if (here.isUnion() && there.isIntersection()) + { + here.isCofinite = false; + for (const auto& [key, _] : there.singletons) + here.singletons.erase(key); + } + // Case 8 + else if (here.isIntersection() && there.isUnion()) + { + here.isCofinite = false; + std::map result(there.singletons); + for (const auto& [key, _] : here.singletons) + result.erase(key); + here.singletons = result; + } + // Case 9 + else if (here.isUnion() && there.isUnion()) + { + here.isCofinite = false; + std::map result; + result.insert(here.singletons.begin(), here.singletons.end()); + result.insert(there.singletons.begin(), there.singletons.end()); + for (auto it = result.begin(); it != result.end();) + if (!here.singletons.count(it->first) || !there.singletons.count(it->first)) + it = result.erase(it); + else + ++it; + here.singletons = result; + } + else + LUAU_ASSERT(0 && "Internal Error - unrecognized case"); } std::optional Normalizer::intersectionOfTypePacks(TypePackId here, TypePackId there) diff --git a/Analysis/src/Simplify.cpp b/Analysis/src/Simplify.cpp index 15b3b2c6..6519c6ff 100644 --- a/Analysis/src/Simplify.cpp +++ b/Analysis/src/Simplify.cpp @@ -2,10 +2,12 @@ #include "Luau/Simplify.h" +#include "Luau/Normalize.h" // TypeIds #include "Luau/RecursionCounter.h" #include "Luau/ToString.h" #include "Luau/TypeArena.h" -#include "Luau/Normalize.h" // TypeIds +#include "Luau/TypeUtils.h" + #include LUAU_FASTINT(LuauTypeReductionRecursionLimit) @@ -47,14 +49,6 @@ struct TypeSimplifier TypeId simplify(TypeId ty, DenseHashSet& seen); }; -template -static std::pair get2(TID one, TID two) -{ - const A* a = get(one); - const B* b = get(two); - return a && b ? std::make_pair(a, b) : std::make_pair(nullptr, nullptr); -} - // Match the exact type false|nil static bool isFalsyType(TypeId ty) { diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 4c6c35b0..a34a45c8 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -10,7 +10,6 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 10000) LUAU_FASTFLAG(DebugLuauReadWriteProperties) -LUAU_FASTFLAGVARIABLE(LuauTarjanSingleArr, false) namespace Luau { @@ -269,67 +268,30 @@ std::pair Tarjan::indexify(TypeId ty) { ty = log->follow(ty); - if (FFlag::LuauTarjanSingleArr) + auto [index, fresh] = typeToIndex.try_insert(ty, false); + + if (fresh) { - auto [index, fresh] = typeToIndex.try_insert(ty, false); - - if (fresh) - { - index = int(nodes.size()); - nodes.push_back({ty, nullptr, false, false, index}); - } - - return {index, fresh}; + index = int(nodes.size()); + nodes.push_back({ty, nullptr, false, false, index}); } - else - { - bool fresh = !typeToIndex.contains(ty); - int& index = typeToIndex[ty]; - if (fresh) - { - index = int(indexToType.size()); - indexToType.push_back(ty); - indexToPack.push_back(nullptr); - onStack.push_back(false); - lowlink.push_back(index); - } - return {index, fresh}; - } + return {index, fresh}; } std::pair Tarjan::indexify(TypePackId tp) { tp = log->follow(tp); - if (FFlag::LuauTarjanSingleArr) + auto [index, fresh] = packToIndex.try_insert(tp, false); + + if (fresh) { - auto [index, fresh] = packToIndex.try_insert(tp, false); - - if (fresh) - { - index = int(nodes.size()); - nodes.push_back({nullptr, tp, false, false, index}); - } - - return {index, fresh}; + index = int(nodes.size()); + nodes.push_back({nullptr, tp, false, false, index}); } - else - { - bool fresh = !packToIndex.contains(tp); - int& index = packToIndex[tp]; - - if (fresh) - { - index = int(indexToPack.size()); - indexToType.push_back(nullptr); - indexToPack.push_back(tp); - onStack.push_back(false); - lowlink.push_back(index); - } - return {index, fresh}; - } + return {index, fresh}; } void Tarjan::visitChild(TypeId ty) @@ -350,9 +312,6 @@ void Tarjan::visitChild(TypePackId tp) TarjanResult Tarjan::loop() { - if (!FFlag::LuauTarjanSingleArr) - return loop_DEPRECATED(); - // Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing while (!worklist.empty()) { @@ -476,27 +435,11 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) void Tarjan::clearTarjan() { - if (FFlag::LuauTarjanSingleArr) - { - typeToIndex.clear(); - packToIndex.clear(); - nodes.clear(); + typeToIndex.clear(); + packToIndex.clear(); + nodes.clear(); - stack.clear(); - } - else - { - dirty.clear(); - - typeToIndex.clear(); - packToIndex.clear(); - indexToType.clear(); - indexToPack.clear(); - - stack.clear(); - onStack.clear(); - lowlink.clear(); - } + stack.clear(); edgesTy.clear(); edgesTp.clear(); @@ -505,32 +448,14 @@ void Tarjan::clearTarjan() bool Tarjan::getDirty(int index) { - if (FFlag::LuauTarjanSingleArr) - { - LUAU_ASSERT(size_t(index) < nodes.size()); - return nodes[index].dirty; - } - else - { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - return dirty[index]; - } + LUAU_ASSERT(size_t(index) < nodes.size()); + return nodes[index].dirty; } void Tarjan::setDirty(int index, bool d) { - if (FFlag::LuauTarjanSingleArr) - { - LUAU_ASSERT(size_t(index) < nodes.size()); - nodes[index].dirty = d; - } - else - { - if (dirty.size() <= size_t(index)) - dirty.resize(index + 1, false); - dirty[index] = d; - } + LUAU_ASSERT(size_t(index) < nodes.size()); + nodes[index].dirty = d; } void Tarjan::visitEdge(int index, int parentIndex) @@ -541,9 +466,6 @@ void Tarjan::visitEdge(int index, int parentIndex) void Tarjan::visitSCC(int index) { - if (!FFlag::LuauTarjanSingleArr) - return visitSCC_DEPRECATED(index); - bool d = getDirty(index); for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) @@ -588,132 +510,6 @@ TarjanResult Tarjan::findDirty(TypePackId tp) return visitRoot(tp); } -TarjanResult Tarjan::loop_DEPRECATED() -{ - // Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing - while (!worklist.empty()) - { - auto [index, currEdge, lastEdge] = worklist.back(); - - // First visit - if (currEdge == -1) - { - ++childCount; - if (childLimit > 0 && childLimit <= childCount) - return TarjanResult::TooManyChildren; - - stack.push_back(index); - onStack[index] = true; - - currEdge = int(edgesTy.size()); - - // Fill in edge list of this vertex - if (TypeId ty = indexToType[index]) - visitChildren(ty, index); - else if (TypePackId tp = indexToPack[index]) - visitChildren(tp, index); - - lastEdge = int(edgesTy.size()); - } - - // Visit children - bool foundFresh = false; - - for (; currEdge < lastEdge; currEdge++) - { - int childIndex = -1; - bool fresh = false; - - if (auto ty = edgesTy[currEdge]) - std::tie(childIndex, fresh) = indexify(ty); - else if (auto tp = edgesTp[currEdge]) - std::tie(childIndex, fresh) = indexify(tp); - else - LUAU_ASSERT(false); - - if (fresh) - { - // Original recursion point, update the parent continuation point and start the new element - worklist.back() = {index, currEdge + 1, lastEdge}; - worklist.push_back({childIndex, -1, -1}); - - // We need to continue the top-level loop from the start with the new worklist element - foundFresh = true; - break; - } - else if (onStack[childIndex]) - { - lowlink[index] = std::min(lowlink[index], childIndex); - } - - visitEdge(childIndex, index); - } - - if (foundFresh) - continue; - - if (lowlink[index] == index) - { - visitSCC(index); - while (!stack.empty()) - { - int popped = stack.back(); - stack.pop_back(); - onStack[popped] = false; - if (popped == index) - break; - } - } - - worklist.pop_back(); - - // Original return from recursion into a child - if (!worklist.empty()) - { - auto [parentIndex, _, parentEndEdge] = worklist.back(); - - // No need to keep child edges around - edgesTy.resize(parentEndEdge); - edgesTp.resize(parentEndEdge); - - lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]); - visitEdge(index, parentIndex); - } - } - - return TarjanResult::Ok; -} - - -void Tarjan::visitSCC_DEPRECATED(int index) -{ - bool d = getDirty(index); - - for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) - { - if (TypeId ty = indexToType[*it]) - d = isDirty(ty); - else if (TypePackId tp = indexToPack[*it]) - d = isDirty(tp); - if (*it == index) - break; - } - - if (!d) - return; - - for (auto it = stack.rbegin(); it != stack.rend(); it++) - { - setDirty(*it, true); - if (TypeId ty = indexToType[*it]) - foundDirty(ty); - else if (TypePackId tp = indexToPack[*it]) - foundDirty(tp); - if (*it == index) - return; - } -} - std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); diff --git a/Analysis/src/Subtyping.cpp b/Analysis/src/Subtyping.cpp new file mode 100644 index 00000000..e216d623 --- /dev/null +++ b/Analysis/src/Subtyping.cpp @@ -0,0 +1,1050 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Subtyping.h" + +#include "Luau/Common.h" +#include "Luau/Error.h" +#include "Luau/Normalize.h" +#include "Luau/Scope.h" +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" +#include "Luau/Type.h" +#include "Luau/TypeArena.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" + +#include + +namespace Luau +{ + +struct VarianceFlipper +{ + Subtyping::Variance* variance; + Subtyping::Variance oldValue; + + VarianceFlipper(Subtyping::Variance* v) + : variance(v) + , oldValue(*v) + { + switch (oldValue) + { + case Subtyping::Variance::Covariant: + *variance = Subtyping::Variance::Contravariant; + break; + case Subtyping::Variance::Contravariant: + *variance = Subtyping::Variance::Covariant; + break; + } + } + + ~VarianceFlipper() + { + *variance = oldValue; + } +}; + +SubtypingResult& SubtypingResult::andAlso(const SubtypingResult& other) +{ + isSubtype &= other.isSubtype; + // `|=` is intentional here, we want to preserve error related flags. + isErrorSuppressing |= other.isErrorSuppressing; + normalizationTooComplex |= other.normalizationTooComplex; + return *this; +} + +SubtypingResult& SubtypingResult::orElse(const SubtypingResult& other) +{ + isSubtype |= other.isSubtype; + isErrorSuppressing |= other.isErrorSuppressing; + normalizationTooComplex |= other.normalizationTooComplex; + return *this; +} + +SubtypingResult SubtypingResult::negate(const SubtypingResult& result) +{ + return SubtypingResult{ + !result.isSubtype, + result.isErrorSuppressing, + result.normalizationTooComplex, + }; +} + +SubtypingResult SubtypingResult::all(const std::vector& results) +{ + SubtypingResult acc{true, false}; + for (const SubtypingResult& current : results) + acc.andAlso(current); + return acc; +} + +SubtypingResult SubtypingResult::any(const std::vector& results) +{ + SubtypingResult acc{false, false}; + for (const SubtypingResult& current : results) + acc.orElse(current); + return acc; +} + +SubtypingResult Subtyping::isSubtype(TypeId subTy, TypeId superTy) +{ + mappedGenerics.clear(); + mappedGenericPacks.clear(); + + SubtypingResult result = isCovariantWith(subTy, superTy); + + for (const auto& [subTy, bounds] : mappedGenerics) + { + const auto& lb = bounds.lowerBound; + const auto& ub = bounds.upperBound; + + TypeId lowerBound = makeAggregateType(lb, builtinTypes->neverType); + TypeId upperBound = makeAggregateType(ub, builtinTypes->unknownType); + + result.andAlso(isCovariantWith(lowerBound, upperBound)); + } + + return result; +} + +SubtypingResult Subtyping::isSubtype(TypePackId subTp, TypePackId superTp) +{ + return isCovariantWith(subTp, superTp); +} + +namespace +{ +struct SeenSetPopper +{ + Subtyping::SeenSet* seenTypes; + std::pair pair; + + SeenSetPopper(Subtyping::SeenSet* seenTypes, std::pair pair) + : seenTypes(seenTypes) + , pair(pair) + { + } + + ~SeenSetPopper() + { + seenTypes->erase(pair); + } +}; +} // namespace + +SubtypingResult Subtyping::isCovariantWith(TypeId subTy, TypeId superTy) +{ + subTy = follow(subTy); + superTy = follow(superTy); + + // TODO: Do we care about returning a proof that this is error-suppressing? + // e.g. given `a | error <: a | error` where both operands are pointer equal, + // then should it also carry the information that it's error-suppressing? + // If it should, then `error <: error` should also do the same. + if (subTy == superTy) + return {true}; + + std::pair typePair{subTy, superTy}; + if (!seenTypes.insert(typePair).second) + return {true}; + + SeenSetPopper ssp{&seenTypes, typePair}; + + // Within the scope to which a generic belongs, that generic should be + // tested as though it were its upper bounds. We do not yet support bounded + // generics, so the upper bound is always unknown. + if (auto subGeneric = get(subTy); subGeneric && subsumes(subGeneric->scope, scope)) + return isCovariantWith(builtinTypes->unknownType, superTy); + if (auto superGeneric = get(superTy); superGeneric && subsumes(superGeneric->scope, scope)) + return isCovariantWith(subTy, builtinTypes->unknownType); + + if (auto subUnion = get(subTy)) + return isCovariantWith(subUnion, superTy); + else if (auto superUnion = get(superTy)) + return isCovariantWith(subTy, superUnion); + else if (auto superIntersection = get(superTy)) + return isCovariantWith(subTy, superIntersection); + else if (auto subIntersection = get(subTy)) + { + SubtypingResult result = isCovariantWith(subIntersection, superTy); + if (result.isSubtype || result.isErrorSuppressing || result.normalizationTooComplex) + return result; + else + return isCovariantWith(normalizer->normalize(subTy), normalizer->normalize(superTy)); + } + else if (get(superTy)) + return {true}; // This is always true. + else if (get(subTy)) + { + // any = unknown | error, so we rewrite this to match. + // As per TAPL: A | B <: T iff A <: T && B <: T + return isCovariantWith(builtinTypes->unknownType, superTy).andAlso(isCovariantWith(builtinTypes->errorType, superTy)); + } + else if (get(superTy)) + { + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + LUAU_ASSERT(!get(subTy)); // TODO: replace with ice. + + bool errorSuppressing = get(subTy); + return {!errorSuppressing, errorSuppressing}; + } + else if (get(subTy)) + return {true}; + else if (get(superTy)) + return {false, true}; + else if (get(subTy)) + return {false, true}; + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p.first->ty, p.second->ty); + else if (auto subNegation = get(subTy)) + return isCovariantWith(subNegation, superTy); + else if (auto superNegation = get(superTy)) + return isCovariantWith(subTy, superNegation); + else if (auto subGeneric = get(subTy); subGeneric && variance == Variance::Covariant) + { + bool ok = bindGeneric(subTy, superTy); + return {ok}; + } + else if (auto superGeneric = get(superTy); superGeneric && variance == Variance::Contravariant) + { + bool ok = bindGeneric(subTy, superTy); + return {ok}; + } + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + else if (auto p = get2(subTy, superTy)) + return isCovariantWith(p); + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(TypePackId subTp, TypePackId superTp) +{ + subTp = follow(subTp); + superTp = follow(superTp); + + auto [subHead, subTail] = flatten(subTp); + auto [superHead, superTail] = flatten(superTp); + + const size_t headSize = std::min(subHead.size(), superHead.size()); + + std::vector results; + results.reserve(std::max(subHead.size(), superHead.size()) + 1); + + if (subTp == superTp) + return {true}; + + // Match head types pairwise + + for (size_t i = 0; i < headSize; ++i) + { + results.push_back(isCovariantWith(subHead[i], superHead[i])); + if (!results.back().isSubtype) + return {false}; + } + + // Handle mismatched head sizes + + if (subHead.size() < superHead.size()) + { + if (subTail) + { + if (auto vt = get(*subTail)) + { + for (size_t i = headSize; i < superHead.size(); ++i) + results.push_back(isCovariantWith(vt->ty, superHead[i])); + } + else if (auto gt = get(*subTail)) + { + if (variance == Variance::Covariant) + { + // For any non-generic type T: + // + // (X) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(superHead), begin(superHead) + headSize); + TypePackId superTailPack = arena->addTypePack(std::move(headSlice), superTail); + + if (TypePackId* other = mappedGenericPacks.find(*subTail)) + results.push_back(isCovariantWith(*other, superTailPack)); + else + mappedGenericPacks.try_insert(*subTail, superTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // (T) -> () (X) -> () + // + return {false}; + } + } + else + unexpected(*subTail); + } + else + return {false}; + } + else if (subHead.size() > superHead.size()) + { + if (superTail) + { + if (auto vt = get(*superTail)) + { + for (size_t i = headSize; i < subHead.size(); ++i) + results.push_back(isCovariantWith(subHead[i], vt->ty)); + } + else if (auto gt = get(*superTail)) + { + if (variance == Variance::Contravariant) + { + // For any non-generic type T: + // + // (X...) -> () <: (T) -> () + + // Possible optimization: If headSize == 0 then we can just use subTp as-is. + std::vector headSlice(begin(subHead), begin(subHead) + headSize); + TypePackId subTailPack = arena->addTypePack(std::move(headSlice), subTail); + + if (TypePackId* other = mappedGenericPacks.find(*superTail)) + results.push_back(isCovariantWith(*other, subTailPack)); + else + mappedGenericPacks.try_insert(*superTail, subTailPack); + + // FIXME? Not a fan of the early return here. It makes the + // control flow harder to reason about. + return SubtypingResult::all(results); + } + else + { + // For any non-generic type T: + // + // () -> T () -> X... + return {false}; + } + } + else + unexpected(*superTail); + } + else + return {false}; + } + + // Handle tails + + if (subTail && superTail) + { + if (auto p = get2(*subTail, *superTail)) + { + results.push_back(isCovariantWith(p)); + } + else if (auto p = get2(*subTail, *superTail)) + { + bool ok = bindGeneric(*subTail, *superTail); + results.push_back({ok}); + } + else if (get2(*subTail, *superTail)) + { + if (variance == Variance::Contravariant) + { + // (A...) -> number <: (...number) -> number + bool ok = bindGeneric(*subTail, *superTail); + results.push_back({ok}); + } + else + { + // (number) -> ...number (number) -> A... + results.push_back({false}); + } + } + else if (get2(*subTail, *superTail)) + { + if (variance == Variance::Contravariant) + { + // (...number) -> number (A...) -> number + results.push_back({false}); + } + else + { + // () -> A... <: () -> ...number + bool ok = bindGeneric(*subTail, *superTail); + results.push_back({ok}); + } + } + else + iceReporter->ice( + format("Subtyping::isSubtype got unexpected type packs %s and %s", toString(*subTail).c_str(), toString(*superTail).c_str())); + } + else if (subTail) + { + if (get(*subTail)) + { + return {false}; + } + else if (get(*subTail)) + { + bool ok = bindGeneric(*subTail, builtinTypes->emptyTypePack); + return {ok}; + } + else + unexpected(*subTail); + } + else if (superTail) + { + if (get(*superTail)) + { + /* + * A variadic type pack ...T can be thought of as an infinite union of finite type packs. + * () | (T) | (T, T) | (T, T, T) | ... + * + * And, per TAPL: + * T <: A | B iff T <: A or T <: B + * + * All variadic type packs are therefore supertypes of the empty type pack. + */ + } + else if (get(*superTail)) + { + if (variance == Variance::Contravariant) + { + bool ok = bindGeneric(builtinTypes->emptyTypePack, *superTail); + results.push_back({ok}); + } + else + results.push_back({false}); + } + else + LUAU_ASSERT(0); // TODO + } + + return SubtypingResult::all(results); +} + +template +SubtypingResult Subtyping::isContravariantWith(SubTy&& subTy, SuperTy&& superTy) +{ + return isCovariantWith(superTy, subTy); +} + +template +SubtypingResult Subtyping::isInvariantWith(SubTy&& subTy, SuperTy&& superTy) +{ + return isCovariantWith(subTy, superTy).andAlso(isContravariantWith(subTy, superTy)); +} + +template +SubtypingResult Subtyping::isCovariantWith(const TryPair& pair) +{ + return isCovariantWith(pair.first, pair.second); +} + +template +SubtypingResult Subtyping::isContravariantWith(const TryPair& pair) +{ + return isCovariantWith(pair.second, pair.first); +} + +template +SubtypingResult Subtyping::isInvariantWith(const TryPair& pair) +{ + return isCovariantWith(pair).andAlso(isContravariantWith(pair)); +} + +/* + * This is much simpler than the Unifier implementation because we don't + * actually care about potential "cross-talk" between union parts that match the + * left side. + * + * In fact, we're very limited in what we can do: If multiple choices match, but + * all of them have non-overlapping constraints, then we're stuck with an "or" + * conjunction of constraints. Solving this in the general case is quite + * difficult. + * + * For example, we cannot dispatch anything from this constraint: + * + * {x: number, y: string} <: {x: number, y: 'a} | {x: 'b, y: string} + * + * From this constraint, we can know that either string <: 'a or number <: 'b, + * but we don't know which! + * + * However: + * + * {x: number, y: string} <: {x: number, y: 'a} | {x: number, y: string} + * + * We can dispatch this constraint because there is no 'or' conjunction. One of + * the arms requires 0 matches. + * + * {x: number, y: string, z: boolean} | {x: number, y: 'a, z: 'b} | {x: number, + * y: string, z: 'b} + * + * Here, we have two matches. One asks for string ~ 'a and boolean ~ 'b. The + * other just asks for boolean ~ 'b. We can dispatch this and only commit + * boolean ~ 'b. This constraint does not teach us anything about 'a. + */ +SubtypingResult Subtyping::isCovariantWith(TypeId subTy, const UnionType* superUnion) +{ + // As per TAPL: T <: A | B iff T <: A || T <: B + std::vector subtypings; + for (TypeId ty : superUnion) + subtypings.push_back(isCovariantWith(subTy, ty)); + return SubtypingResult::any(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(const UnionType* subUnion, TypeId superTy) +{ + // As per TAPL: A | B <: T iff A <: T && B <: T + std::vector subtypings; + for (TypeId ty : subUnion) + subtypings.push_back(isCovariantWith(ty, superTy)); + return SubtypingResult::all(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(TypeId subTy, const IntersectionType* superIntersection) +{ + // As per TAPL: T <: A & B iff T <: A && T <: B + std::vector subtypings; + for (TypeId ty : superIntersection) + subtypings.push_back(isCovariantWith(subTy, ty)); + return SubtypingResult::all(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(const IntersectionType* subIntersection, TypeId superTy) +{ + // As per TAPL: A & B <: T iff A <: T || B <: T + std::vector subtypings; + for (TypeId ty : subIntersection) + subtypings.push_back(isCovariantWith(ty, superTy)); + return SubtypingResult::any(subtypings); +} + +SubtypingResult Subtyping::isCovariantWith(const NegationType* subNegation, TypeId superTy) +{ + TypeId negatedTy = follow(subNegation->ty); + + // In order to follow a consistent codepath, rather than folding the + // isCovariantWith test down to its conclusion here, we test the subtyping test + // of the result of negating the type for never, unknown, any, and error. + if (is(negatedTy)) + { + // ¬never ~ unknown + return isCovariantWith(builtinTypes->unknownType, superTy); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + return isCovariantWith(builtinTypes->neverType, superTy); + } + else if (is(negatedTy)) + { + // ¬any ~ any + return isCovariantWith(negatedTy, superTy); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(&negatedTmp, superTy)); + } + + return SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(negatedPart->ty, superTy)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(&negatedTmp, superTy)); + } + } + + return SubtypingResult::any(subtypings); + } + else if (is(negatedTy)) + { + iceReporter->ice("attempting to negate a non-testable type"); + } + // negating a different subtype will get you a very wide type that's not a + // subtype of other stuff. + else + { + return {false}; + } +} + +SubtypingResult Subtyping::isCovariantWith(const TypeId subTy, const NegationType* superNegation) +{ + TypeId negatedTy = follow(superNegation->ty); + + if (is(negatedTy)) + { + // ¬never ~ unknown + return isCovariantWith(subTy, builtinTypes->unknownType); + } + else if (is(negatedTy)) + { + // ¬unknown ~ never + return isCovariantWith(subTy, builtinTypes->neverType); + } + else if (is(negatedTy)) + { + // ¬any ~ any + return isSubtype(subTy, negatedTy); + } + else if (auto u = get(negatedTy)) + { + // ¬(A ∪ B) ~ ¬A ∩ ¬B + // follow intersection rules: A & B <: T iff A <: T && B <: T + std::vector subtypings; + + for (TypeId ty : u) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(subTy, negatedPart->ty)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(subTy, &negatedTmp)); + } + } + + return SubtypingResult::all(subtypings); + } + else if (auto i = get(negatedTy)) + { + // ¬(A ∩ B) ~ ¬A ∪ ¬B + // follow union rules: A | B <: T iff A <: T || B <: T + std::vector subtypings; + + for (TypeId ty : i) + { + if (auto negatedPart = get(follow(ty))) + subtypings.push_back(isCovariantWith(subTy, negatedPart->ty)); + else + { + NegationType negatedTmp{ty}; + subtypings.push_back(isCovariantWith(subTy, &negatedTmp)); + } + } + + return SubtypingResult::any(subtypings); + } + else if (auto p = get2(subTy, negatedTy)) + { + // number <: ¬boolean + // number type != p.second->type}; + } + else if (auto p = get2(subTy, negatedTy)) + { + // "foo" (p.first) && p.second->type == PrimitiveType::String) + return {false}; + // false (p.first) && p.second->type == PrimitiveType::Boolean) + return {false}; + // other cases are true + else + return {true}; + } + else if (auto p = get2(subTy, negatedTy)) + { + if (p.first->type == PrimitiveType::String && get(p.second)) + return {false}; + else if (p.first->type == PrimitiveType::Boolean && get(p.second)) + return {false}; + else + return {true}; + } + // the top class type is not actually a primitive type, so the negation of + // any one of them includes the top class type. + else if (auto p = get2(subTy, negatedTy)) + return {true}; + else if (auto p = get(negatedTy); p && is(subTy)) + return {p->type != PrimitiveType::Table}; + else if (auto p = get2(subTy, negatedTy)) + return {p.second->type != PrimitiveType::Function}; + else if (auto p = get2(subTy, negatedTy)) + return {*p.first != *p.second}; + else if (auto p = get2(subTy, negatedTy)) + return SubtypingResult::negate(isCovariantWith(p.first, p.second)); + else if (get2(subTy, negatedTy)) + return {true}; + else if (is(negatedTy)) + iceReporter->ice("attempting to negate a non-testable type"); + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(const PrimitiveType* subPrim, const PrimitiveType* superPrim) +{ + return {subPrim->type == superPrim->type}; +} + +SubtypingResult Subtyping::isCovariantWith(const SingletonType* subSingleton, const PrimitiveType* superPrim) +{ + if (get(subSingleton) && superPrim->type == PrimitiveType::String) + return {true}; + else if (get(subSingleton) && superPrim->type == PrimitiveType::Boolean) + return {true}; + else + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(const SingletonType* subSingleton, const SingletonType* superSingleton) +{ + return {*subSingleton == *superSingleton}; +} + +SubtypingResult Subtyping::isCovariantWith(const TableType* subTable, const TableType* superTable) +{ + SubtypingResult result{true}; + + for (const auto& [name, prop] : superTable->props) + { + auto it = subTable->props.find(name); + if (it != subTable->props.end()) + result.andAlso(isInvariantWith(prop.type(), it->second.type())); + else + return SubtypingResult{false}; + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(const MetatableType* subMt, const MetatableType* superMt) +{ + return SubtypingResult::all({ + isCovariantWith(subMt->table, superMt->table), + isCovariantWith(subMt->metatable, superMt->metatable), + }); +} + +SubtypingResult Subtyping::isCovariantWith(const MetatableType* subMt, const TableType* superTable) +{ + if (auto subTable = get(subMt->table)) + { + // Metatables cannot erase properties from the table they're attached to, so + // the subtyping rule for this is just if the table component is a subtype + // of the supertype table. + // + // There's a flaw here in that if the __index metamethod contributes a new + // field that would satisfy the subtyping relationship, we'll erronously say + // that the metatable isn't a subtype of the table, even though they have + // compatible properties/shapes. We'll revisit this later when we have a + // better understanding of how important this is. + return isCovariantWith(subTable, superTable); + } + else + { + // TODO: This may be a case we actually hit? + return {false}; + } +} + +SubtypingResult Subtyping::isCovariantWith(const ClassType* subClass, const ClassType* superClass) +{ + return {isSubclass(subClass, superClass)}; +} + +SubtypingResult Subtyping::isCovariantWith(const ClassType* subClass, const TableType* superTable) +{ + SubtypingResult result{true}; + + for (const auto& [name, prop] : superTable->props) + { + if (auto classProp = lookupClassProp(subClass, name)) + result.andAlso(isInvariantWith(prop.type(), classProp->type())); + else + return SubtypingResult{false}; + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(const FunctionType* subFunction, const FunctionType* superFunction) +{ + SubtypingResult result; + { + VarianceFlipper vf{&variance}; + result.orElse(isContravariantWith(subFunction->argTypes, superFunction->argTypes)); + } + + result.andAlso(isCovariantWith(subFunction->retTypes, superFunction->retTypes)); + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(const PrimitiveType* subPrim, const TableType* superTable) +{ + SubtypingResult result{false}; + if (subPrim->type == PrimitiveType::String) + { + if (auto metatable = getMetatable(builtinTypes->stringType, builtinTypes)) + { + if (auto mttv = get(follow(metatable))) + { + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + if (auto stringTable = get(it->second.type())) + result.orElse(isCovariantWith(stringTable, superTable)); + } + } + } + } + + return result; +} + +SubtypingResult Subtyping::isCovariantWith(const SingletonType* subSingleton, const TableType* superTable) +{ + SubtypingResult result{false}; + if (auto stringleton = get(subSingleton)) + { + if (auto metatable = getMetatable(builtinTypes->stringType, builtinTypes)) + { + if (auto mttv = get(follow(metatable))) + { + if (auto it = mttv->props.find("__index"); it != mttv->props.end()) + { + if (auto stringTable = get(it->second.type())) + result.orElse(isCovariantWith(stringTable, superTable)); + } + } + } + } + return result; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedType* subNorm, const NormalizedType* superNorm) +{ + if (!subNorm || !superNorm) + return {false, true, true}; + + SubtypingResult result = isCovariantWith(subNorm->tops, superNorm->tops); + result.andAlso(isCovariantWith(subNorm->booleans, superNorm->booleans)); + result.andAlso(isCovariantWith(subNorm->classes, superNorm->classes).orElse(isCovariantWith(subNorm->classes, superNorm->tables))); + result.andAlso(isCovariantWith(subNorm->errors, superNorm->errors)); + result.andAlso(isCovariantWith(subNorm->nils, superNorm->nils)); + result.andAlso(isCovariantWith(subNorm->numbers, superNorm->numbers)); + result.andAlso(isCovariantWith(subNorm->strings, superNorm->strings)); + result.andAlso(isCovariantWith(subNorm->strings, superNorm->tables)); + result.andAlso(isCovariantWith(subNorm->threads, superNorm->threads)); + result.andAlso(isCovariantWith(subNorm->tables, superNorm->tables)); + result.andAlso(isCovariantWith(subNorm->functions, superNorm->functions)); + // isCovariantWith(subNorm->tyvars, superNorm->tyvars); + return result; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedClassType& subClass, const NormalizedClassType& superClass) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (const auto& [superClassTy, superNegations] : superClass.classes) + { + result.orElse(isCovariantWith(subClassTy, superClassTy)); + if (!result.isSubtype) + continue; + + for (TypeId negation : superNegations) + { + result.andAlso(SubtypingResult::negate(isCovariantWith(subClassTy, negation))); + if (result.isSubtype) + break; + } + } + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedClassType& subClass, const TypeIds& superTables) +{ + for (const auto& [subClassTy, _] : subClass.classes) + { + SubtypingResult result; + + for (TypeId superTableTy : superTables) + result.orElse(isCovariantWith(subClassTy, superTableTy)); + + if (!result.isSubtype) + return result; + } + + return {true}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedStringType& subString, const NormalizedStringType& superString) +{ + bool isSubtype = Luau::isSubtype(subString, superString); + return {isSubtype}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedStringType& subString, const TypeIds& superTables) +{ + if (subString.isNever()) + return {true}; + + if (subString.isCofinite) + { + SubtypingResult result; + for (const auto& superTable : superTables) + { + result.orElse(isCovariantWith(builtinTypes->stringType, superTable)); + if (result.isSubtype) + return result; + } + return result; + } + + // Finite case + // S = s1 | s2 | s3 ... sn <: t1 | t2 | ... | tn + // iff for some ti, S <: ti + // iff for all sj, sj <: ti + for (const auto& superTable : superTables) + { + SubtypingResult result{true}; + for (const auto& [_, subString] : subString.singletons) + { + result.andAlso(isCovariantWith(subString, superTable)); + if (!result.isSubtype) + break; + } + + if (!result.isSubtype) + continue; + else + return result; + } + + return {false}; +} + +SubtypingResult Subtyping::isCovariantWith(const NormalizedFunctionType& subFunction, const NormalizedFunctionType& superFunction) +{ + if (subFunction.isNever()) + return {true}; + else if (superFunction.isTop) + return {true}; + else + return isCovariantWith(subFunction.parts, superFunction.parts); +} + +SubtypingResult Subtyping::isCovariantWith(const TypeIds& subTypes, const TypeIds& superTypes) +{ + std::vector results; + + for (TypeId subTy : subTypes) + { + results.emplace_back(); + for (TypeId superTy : superTypes) + results.back().orElse(isCovariantWith(subTy, superTy)); + } + + return SubtypingResult::all(results); +} + +SubtypingResult Subtyping::isCovariantWith(const VariadicTypePack* subVariadic, const VariadicTypePack* superVariadic) +{ + return isCovariantWith(subVariadic->ty, superVariadic->ty); +} + +bool Subtyping::bindGeneric(TypeId subTy, TypeId superTy) +{ + if (variance == Variance::Covariant) + { + if (!get(subTy)) + return false; + + mappedGenerics[subTy].upperBound.insert(superTy); + } + else + { + if (!get(superTy)) + return false; + + mappedGenerics[superTy].lowerBound.insert(subTy); + } + + return true; +} + +/* + * If, when performing a subtyping test, we encounter a generic on the left + * side, it is permissible to tentatively bind that generic to the right side + * type. + */ +bool Subtyping::bindGeneric(TypePackId subTp, TypePackId superTp) +{ + if (variance == Variance::Contravariant) + std::swap(superTp, subTp); + + if (!get(subTp)) + return false; + + if (TypePackId* m = mappedGenericPacks.find(subTp)) + return *m == superTp; + + mappedGenericPacks[subTp] = superTp; + + return true; +} + +template +TypeId Subtyping::makeAggregateType(const Container& container, TypeId orElse) +{ + if (container.empty()) + return orElse; + else if (container.size() == 1) + return *begin(container); + else + return arena->addType(T{std::vector(begin(container), end(container))}); +} + +void Subtyping::unexpected(TypePackId tp) +{ + iceReporter->ice(format("Unexpected type pack %s", toString(tp).c_str())); +} + +} // namespace Luau diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index c3a1db4c..09851024 100644 --- a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -9,6 +9,8 @@ #include #include +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); + namespace Luau { @@ -52,7 +54,7 @@ bool StateDot::canDuplicatePrimitive(TypeId ty) if (get(ty)) return false; - return get(ty) || get(ty); + return get(ty) || get(ty) || get(ty) || get(ty); } void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) @@ -76,6 +78,10 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"unknown\"];\n", index); + else if (get(ty)) + formatAppend(result, "n%d [label=\"never\"];\n", index); } else { @@ -139,159 +145,214 @@ void StateDot::visitChildren(TypeId ty, int index) startNode(index); startNodeLabel(); - if (const BoundType* btv = get(ty)) - { - formatAppend(result, "BoundType %d", index); - finishNodeLabel(ty); - finishNode(); + auto go = [&](auto&& t) { + using T = std::decay_t; - visitChild(btv->boundTo, index); - } - else if (const FunctionType* ftv = get(ty)) - { - formatAppend(result, "FunctionType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retTypes, index, "ret"); - } - else if (const TableType* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableType %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableType %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableType %d", index); - finishNodeLabel(ty); - finishNode(); - - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); - - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type(), index, name.c_str()); - if (ttv->indexer) + if constexpr (std::is_same_v) { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); + formatAppend(result, "BoundType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.boundTo, index); } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); - - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - else if (const MetatableType* mtv = get(ty)) - { - formatAppend(result, "MetatableType %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionType* utv = get(ty)) - { - formatAppend(result, "UnionType %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionType* itv = get(ty)) - { - formatAppend(result, "IntersectionType %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericType* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericType %s", gtv->name.c_str()); - else - formatAppend(result, "GenericType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeType* ftv = get(ty)) - { - formatAppend(result, "FreeType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorType %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassType* ctv = get(ty)) - { - formatAppend(result, "ClassType %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); - - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type(), index, name.c_str()); - - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); - - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - - if (ctv->indexer) + else if constexpr (std::is_same_v) { - visitChild(ctv->indexer->indexType, index, "[index]"); - visitChild(ctv->indexer->indexResultType, index, "[value]"); + formatAppend(result, "BlockedType %d", index); + finishNodeLabel(ty); + finishNode(); } - } - else if (const SingletonType* stv = get(ty)) - { - std::string res; + else if constexpr (std::is_same_v) + { + formatAppend(result, "FunctionType %d", index); + finishNodeLabel(ty); + finishNode(); - if (const StringSingleton* ss = get(stv)) - { - // Don't put in quotes anywhere. If it's outside of the call to escape, - // then it's invalid syntax. If it's inside, then escaping is super noisy. - res = "string: " + escape(ss->value); + visitChild(t.argTypes, index, "arg"); + visitChild(t.retTypes, index, "ret"); } - else if (const BooleanSingleton* bs = get(stv)) + else if constexpr (std::is_same_v) { - res = "boolean: "; - res += bs->value ? "true" : "false"; + if (t.name) + formatAppend(result, "TableType %s", t.name->c_str()); + else if (t.syntheticName) + formatAppend(result, "TableType %s", t.syntheticName->c_str()); + else + formatAppend(result, "TableType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (t.boundTo) + return visitChild(*t.boundTo, index, "boundTo"); + + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : t.instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : t.instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "MetatableType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.table, index, "table"); + visitChild(t.metatable, index, "metatable"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnionType %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : t.options) + visitChild(opt, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "IntersectionType %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : t.parts) + visitChild(part, index); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "LazyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PendingExpansionType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + if (t.explicitName) + formatAppend(result, "GenericType %s", t.name.c_str()); + else + formatAppend(result, "GenericType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "FreeType %d", index); + finishNodeLabel(ty); + finishNode(); + + if (FFlag::DebugLuauDeferredConstraintResolution) + { + if (!get(t.lowerBound)) + visitChild(t.lowerBound, index, "[lowerBound]"); + + if (!get(t.upperBound)) + visitChild(t.upperBound, index, "[upperBound]"); + } + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "AnyType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "UnknownType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NeverType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "PrimitiveType %s", toString(ty).c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ErrorType %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "ClassType %s", t.name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : t.props) + visitChild(prop.type(), index, name.c_str()); + + if (t.parent) + visitChild(*t.parent, index, "[parent]"); + + if (t.metatable) + visitChild(*t.metatable, index, "[metatable]"); + + if (t.indexer) + { + visitChild(t.indexer->indexType, index, "[index]"); + visitChild(t.indexer->indexResultType, index, "[value]"); + } + } + else if constexpr (std::is_same_v) + { + std::string res; + + if (const StringSingleton* ss = get(&t)) + { + // Don't put in quotes anywhere. If it's outside of the call to escape, + // then it's invalid syntax. If it's inside, then escaping is super noisy. + res = "string: " + escape(ss->value); + } + else if (const BooleanSingleton* bs = get(&t)) + { + res = "boolean: "; + res += bs->value ? "true" : "false"; + } + else + LUAU_ASSERT(!"unknown singleton type"); + + formatAppend(result, "SingletonType %s", res.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "NegationType %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(t.ty, index, "[negated]"); + } + else if constexpr (std::is_same_v) + { + formatAppend(result, "TypeFamilyInstanceType %d", index); + finishNodeLabel(ty); + finishNode(); } else - LUAU_ASSERT(!"unknown singleton type"); + static_assert(always_false_v, "unknown type kind"); + }; - formatAppend(result, "SingletonType %s", res.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } + visit(go, ty->ty); } void StateDot::visitChildren(TypePackId tp, int index) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index cdfe6549..8fd40772 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauFloorDivision) + namespace { bool isIdentifierStartChar(char c) @@ -467,10 +469,13 @@ struct Printer case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: + LUAU_ASSERT(FFlag::LuauFloorDivision || a->op != AstExprBinary::FloorDiv); + writer.maybeSpace(a->right->location.begin, 2); writer.symbol(toString(a->op)); break; @@ -487,6 +492,8 @@ struct Printer writer.maybeSpace(a->right->location.begin, 4); writer.keyword(toString(a->op)); break; + default: + LUAU_ASSERT(!"Unknown Op"); } visualize(*a->right); @@ -753,6 +760,12 @@ struct Printer writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; + case AstExprBinary::FloorDiv: + LUAU_ASSERT(FFlag::LuauFloorDivision); + + writer.maybeSpace(a->value->location.begin, 2); + writer.symbol("//="); + break; case AstExprBinary::Mod: writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index fb72bc12..86564cf5 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -27,26 +27,11 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(DebugLuauReadWriteProperties) +LUAU_FASTFLAGVARIABLE(LuauInitializeStringMetatableInGlobalTypes, false) namespace Luau { -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context); - -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context); - -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context); - -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate); -static bool dcrMagicFunctionFind(MagicFunctionCallContext context); - // LUAU_NOINLINE prevents unwrapLazy from being inlined into advance below; advance is important to keep inlineable static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) { @@ -69,14 +54,24 @@ static LUAU_NOINLINE TypeId unwrapLazy(LazyType* ltv) TypeId follow(TypeId t) { - return follow(t, nullptr, [](const void*, TypeId t) -> TypeId { + return follow(t, FollowOption::Normal); +} + +TypeId follow(TypeId t, FollowOption followOption) +{ + return follow(t, followOption, nullptr, [](const void*, TypeId t) -> TypeId { return t; }); } TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeId)) { - auto advance = [context, mapper](TypeId ty) -> std::optional { + return follow(t, FollowOption::Normal, context, mapper); +} + +TypeId follow(TypeId t, FollowOption followOption, const void* context, TypeId (*mapper)(const void*, TypeId)) +{ + auto advance = [followOption, context, mapper](TypeId ty) -> std::optional { TypeId mapped = mapper(context, ty); if (auto btv = get>(mapped)) @@ -85,7 +80,7 @@ TypeId follow(TypeId t, const void* context, TypeId (*mapper)(const void*, TypeI if (auto ttv = get(mapped)) return ttv->boundTo; - if (auto ltv = getMutable(mapped)) + if (auto ltv = getMutable(mapped); ltv && followOption != FollowOption::DisableLazyTypeThunks) return unwrapLazy(ltv); return std::nullopt; @@ -923,6 +918,8 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); +TypeId makeStringMetatable(NotNull builtinTypes); // BuiltinDefinitions.cpp + BuiltinTypes::BuiltinTypes() : arena(new TypeArena) , debugFreezeArena(FFlag::DebugLuauFreezeArena) @@ -945,14 +942,18 @@ BuiltinTypes::BuiltinTypes() , truthyType(arena->addType(Type{NegationType{falsyType}, /*persistent*/ true})) , optionalNumberType(arena->addType(Type{UnionType{{numberType, nilType}}, /*persistent*/ true})) , optionalStringType(arena->addType(Type{UnionType{{stringType, nilType}}, /*persistent*/ true})) + , emptyTypePack(arena->addTypePack(TypePackVar{TypePack{{}}, /*persistent*/ true})) , anyTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, /*persistent*/ true})) , neverTypePack(arena->addTypePack(TypePackVar{VariadicTypePack{neverType}, /*persistent*/ true})) , uninhabitableTypePack(arena->addTypePack(TypePackVar{TypePack{{neverType}, neverTypePack}, /*persistent*/ true})) , errorTypePack(arena->addTypePack(TypePackVar{Unifiable::Error{}, /*persistent*/ true})) { - TypeId stringMetatable = makeStringMetatable(); - asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; - persist(stringMetatable); + if (!FFlag::LuauInitializeStringMetatableInGlobalTypes) + { + TypeId stringMetatable = makeStringMetatable(NotNull{this}); + asMutable(stringType)->ty = PrimitiveType{PrimitiveType::String, stringMetatable}; + persist(stringMetatable); + } freeze(*arena); } @@ -969,82 +970,6 @@ BuiltinTypes::~BuiltinTypes() FFlag::DebugLuauFreezeArena.value = prevFlag; } -TypeId BuiltinTypes::makeStringMetatable() -{ - const TypeId optionalNumber = arena->addType(UnionType{{nilType, numberType}}); - const TypeId optionalString = arena->addType(UnionType{{nilType, stringType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{nilType, booleanType}}); - - const TypePackId oneStringPack = arena->addTypePack({stringType}); - const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); - - FunctionType formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; - formatFTV.magicFunction = &magicFunctionFormat; - const TypeId formatFn = arena->addType(formatFTV); - attachDcrMagicFunction(formatFn, dcrMagicFunctionFormat); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); - const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); - - const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); - - const TypeId replArgType = - arena->addType(UnionType{{stringType, arena->addType(TableType({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), - makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); - const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); - const TypeId gmatchFunc = - makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionType{emptyPack, stringVariadicList})}); - attachMagicFunction(gmatchFunc, magicFunctionGmatch); - attachDcrMagicFunction(gmatchFunc, dcrMagicFunctionGmatch); - - const TypeId matchFunc = arena->addType( - FunctionType{arena->addTypePack({stringType, stringType, optionalNumber}), arena->addTypePack(TypePackVar{VariadicTypePack{stringType}})}); - attachMagicFunction(matchFunc, magicFunctionMatch); - attachDcrMagicFunction(matchFunc, dcrMagicFunctionMatch); - - const TypeId findFunc = arena->addType(FunctionType{arena->addTypePack({stringType, stringType, optionalNumber, optionalBoolean}), - arena->addTypePack(TypePack{{optionalNumber, optionalNumber}, stringVariadicList})}); - attachMagicFunction(findFunc, magicFunctionFind); - attachDcrMagicFunction(findFunc, dcrMagicFunctionFind); - - TableType::Props stringLib = { - {"byte", {arena->addType(FunctionType{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, - {"char", {arena->addType(FunctionType{numberVariadicList, arena->addTypePack({stringType})})}}, - {"find", {findFunc}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {matchFunc}}, - {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableType{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, - {"pack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, - {"unpack", {arena->addType(FunctionType{ - arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - - TypeId tableType = arena->addType(TableType{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); - - if (TableType* ttv = getMutable(tableType)) - ttv->name = "typeof(string)"; - - return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); -} - TypeId BuiltinTypes::errorRecoveryType() const { return errorType; @@ -1250,436 +1175,6 @@ IntersectionTypeIterator end(const IntersectionType* itv) return IntersectionTypeIterator{}; } -static std::vector parseFormatString(NotNull builtinTypes, const char* data, size_t size) -{ - const char* options = "cdiouxXeEfgGqs*"; - - std::vector result; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - i++; - - if (i < size && data[i] == '%') - continue; - - // we just ignore all characters (including flags/precision) up until first alphabetic character - while (i < size && !(data[i] > 0 && (isalpha(data[i]) || data[i] == '*'))) - i++; - - if (i == size) - break; - - if (data[i] == 'q' || data[i] == 's') - result.push_back(builtinTypes->stringType); - else if (data[i] == '*') - result.push_back(builtinTypes->unknownType); - else if (strchr(options, data[i])) - result.push_back(builtinTypes->numberType); - else - result.push_back(builtinTypes->errorRecoveryType(builtinTypes->anyType)); - } - } - - return result; -} - -std::optional> magicFunctionFormat( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* fmt = nullptr; - if (auto index = expr.func->as(); index && expr.self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!expr.self && expr.args.size > 0) - fmt = expr.args.data[0]->as(); - - if (!fmt) - return std::nullopt; - - std::vector expected = parseFormatString(typechecker.builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(paramPack); - - size_t paramOffset = 1; - size_t dataOffset = expr.self ? 0 : 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - - typechecker.unify(params[i + paramOffset], expected[i], scope, location); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - return WithPredicate{arena.addTypePack({typechecker.stringType})}; -} - -static bool dcrMagicFunctionFormat(MagicFunctionCallContext context) -{ - TypeArena* arena = context.solver->arena; - - AstExprConstantString* fmt = nullptr; - if (auto index = context.callSite->func->as(); index && context.callSite->self) - { - if (auto group = index->expr->as()) - fmt = group->expr->as(); - else - fmt = index->expr->as(); - } - - if (!context.callSite->self && context.callSite->args.size > 0) - fmt = context.callSite->args.data[0]->as(); - - if (!fmt) - return false; - - std::vector expected = parseFormatString(context.solver->builtinTypes, fmt->value.data, fmt->value.size); - const auto& [params, tail] = flatten(context.arguments); - - size_t paramOffset = 1; - - // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) - { - context.solver->unify(context.solver->rootScope, context.callSite->location, params[i + paramOffset], expected[i]); - } - - // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t numActualParams = params.size(); - size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - - if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) - context.solver->reportError(TypeError{context.callSite->location, CountMismatch{numExpectedParams, std::nullopt, numActualParams}}); - - TypePackId resultPack = arena->addTypePack({context.solver->builtinTypes->stringType}); - asMutable(context.result)->ty.emplace(resultPack); - - return true; -} - -static std::vector parsePatternString(NotNull builtinTypes, const char* data, size_t size) -{ - std::vector result; - int depth = 0; - bool parsingSet = false; - - for (size_t i = 0; i < size; ++i) - { - if (data[i] == '%') - { - ++i; - if (!parsingSet && i < size && data[i] == 'b') - i += 2; - } - else if (!parsingSet && data[i] == '[') - { - parsingSet = true; - if (i + 1 < size && data[i + 1] == ']') - i += 1; - } - else if (parsingSet && data[i] == ']') - { - parsingSet = false; - } - else if (data[i] == '(') - { - if (parsingSet) - continue; - - if (i + 1 < size && data[i + 1] == ')') - { - i++; - result.push_back(builtinTypes->optionalNumberType); - continue; - } - - ++depth; - result.push_back(builtinTypes->optionalStringType); - } - else if (data[i] == ')') - { - if (parsingSet) - continue; - - --depth; - - if (depth < 0) - break; - } - } - - if (depth != 0 || parsingSet) - return std::vector(); - - if (result.empty()) - result.push_back(builtinTypes->optionalStringType); - - return result; -} - -static std::optional> magicFunctionGmatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() != 2) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t index = expr.self ? 0 : 1; - if (expr.args.size > index) - pattern = expr.args.data[index]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypePackId emptyPack = arena.addTypePack({}); - const TypePackId returnList = arena.addTypePack(returnTypes); - const TypeId iteratorType = arena.addType(FunctionType{emptyPack, returnList}); - return WithPredicate{arena.addTypePack({iteratorType})}; -} - -static bool dcrMagicFunctionGmatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() != 2) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t index = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > index) - pattern = context.callSite->args.data[index]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); - - const TypePackId emptyPack = arena->addTypePack({}); - const TypePackId returnList = arena->addTypePack(returnTypes); - const TypeId iteratorType = arena->addType(FunctionType{emptyPack, returnList}); - const TypePackId resTypePack = arena->addTypePack({iteratorType}); - asMutable(context.result)->ty.emplace(resTypePack); - - return true; -} - -static std::optional> magicFunctionMatch( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 3) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - std::vector returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() == 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionMatch(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 3) - return false; - - TypeArena* arena = context.solver->arena; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - std::vector returnTypes = parsePatternString(context.solver->builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - - context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], context.solver->builtinTypes->stringType); - - const TypeId optionalNumber = arena->addType(UnionType{{context.solver->builtinTypes->nilType, context.solver->builtinTypes->numberType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() == 3 && context.callSite->args.size > initIndex) - context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - - return true; -} - -static std::optional> magicFunctionFind( - TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, WithPredicate withPredicate) -{ - auto [paramPack, _predicates] = withPredicate; - const auto& [params, tail] = flatten(paramPack); - - if (params.size() < 2 || params.size() > 4) - return std::nullopt; - - TypeArena& arena = typechecker.currentModule->internalTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = expr.self ? 0 : 1; - if (expr.args.size > patternIndex) - pattern = expr.args.data[patternIndex]->as(); - - if (!pattern) - return std::nullopt; - - bool plain = false; - size_t plainIndex = expr.self ? 2 : 3; - if (expr.args.size > plainIndex) - { - AstExprConstantBool* p = expr.args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(typechecker.builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return std::nullopt; - } - - typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location); - - const TypeId optionalNumber = arena.addType(UnionType{{typechecker.nilType, typechecker.numberType}}); - const TypeId optionalBoolean = arena.addType(UnionType{{typechecker.nilType, typechecker.booleanType}}); - - size_t initIndex = expr.self ? 1 : 2; - if (params.size() >= 3 && expr.args.size > initIndex) - typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location); - - if (params.size() == 4 && expr.args.size > plainIndex) - typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena.addTypePack(returnTypes); - return WithPredicate{returnList}; -} - -static bool dcrMagicFunctionFind(MagicFunctionCallContext context) -{ - const auto& [params, tail] = flatten(context.arguments); - - if (params.size() < 2 || params.size() > 4) - return false; - - TypeArena* arena = context.solver->arena; - NotNull builtinTypes = context.solver->builtinTypes; - - AstExprConstantString* pattern = nullptr; - size_t patternIndex = context.callSite->self ? 0 : 1; - if (context.callSite->args.size > patternIndex) - pattern = context.callSite->args.data[patternIndex]->as(); - - if (!pattern) - return false; - - bool plain = false; - size_t plainIndex = context.callSite->self ? 2 : 3; - if (context.callSite->args.size > plainIndex) - { - AstExprConstantBool* p = context.callSite->args.data[plainIndex]->as(); - plain = p && p->value; - } - - std::vector returnTypes; - if (!plain) - { - returnTypes = parsePatternString(builtinTypes, pattern->value.data, pattern->value.size); - - if (returnTypes.empty()) - return false; - } - - context.solver->unify(context.solver->rootScope, context.callSite->location, params[0], builtinTypes->stringType); - - const TypeId optionalNumber = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->numberType}}); - const TypeId optionalBoolean = arena->addType(UnionType{{builtinTypes->nilType, builtinTypes->booleanType}}); - - size_t initIndex = context.callSite->self ? 1 : 2; - if (params.size() >= 3 && context.callSite->args.size > initIndex) - context.solver->unify(context.solver->rootScope, context.callSite->location, params[2], optionalNumber); - - if (params.size() == 4 && context.callSite->args.size > plainIndex) - context.solver->unify(context.solver->rootScope, context.callSite->location, params[3], optionalBoolean); - - returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); - - const TypePackId returnList = arena->addTypePack(returnTypes); - asMutable(context.result)->ty.emplace(returnList); - return true; -} - TypeId freshType(NotNull arena, NotNull builtinTypes, Scope* scope) { return arena->addType(FreeType{scope, builtinTypes->neverType, builtinTypes->unknownType}); diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index 08b1ffbc..1daabda3 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -23,6 +23,7 @@ #include LUAU_FASTFLAG(DebugLuauMagicTypes) +LUAU_FASTFLAG(LuauFloorDivision); namespace Luau { @@ -241,8 +242,8 @@ struct TypeChecker2 Normalizer normalizer; - TypeChecker2(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule* sourceModule, - Module* module) + TypeChecker2(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, + const SourceModule* sourceModule, Module* module) : builtinTypes(builtinTypes) , logger(logger) , limits(limits) @@ -1294,13 +1295,8 @@ struct TypeChecker2 else if (auto assertion = expr->as()) return isLiteral(assertion->expr); - return - expr->is() || - expr->is() || - expr->is() || - expr->is() || - expr->is() || - expr->is(); + return expr->is() || expr->is() || expr->is() || + expr->is() || expr->is() || expr->is(); } static std::unique_ptr buildLiteralPropertiesSet(AstExpr* expr) @@ -1422,7 +1418,7 @@ struct TypeChecker2 LUAU_ASSERT(argOffset == args->head.size()); const Location argLoc = argExprs->empty() ? Location{} // TODO - : argExprs->at(argExprs->size() - 1)->location; + : argExprs->at(argExprs->size() - 1)->location; if (paramIter.tail() && args->tail) { @@ -1817,6 +1813,8 @@ struct TypeChecker2 bool typesHaveIntersection = normalizer.isIntersectionInhabited(leftType, rightType); if (auto it = kBinaryOpMetamethods.find(expr->op); it != kBinaryOpMetamethods.end()) { + LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::Op::FloorDiv); + std::optional leftMt = getMetatable(leftType, builtinTypes); std::optional rightMt = getMetatable(rightType, builtinTypes); bool matches = leftMt == rightMt; @@ -2002,8 +2000,11 @@ struct TypeChecker2 case AstExprBinary::Op::Sub: case AstExprBinary::Op::Mul: case AstExprBinary::Op::Div: + case AstExprBinary::Op::FloorDiv: case AstExprBinary::Op::Pow: case AstExprBinary::Op::Mod: + LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::Op::FloorDiv); + reportErrors(tryUnify(scope, expr->left->location, leftType, builtinTypes->numberType)); reportErrors(tryUnify(scope, expr->right->location, rightType, builtinTypes->numberType)); @@ -2680,15 +2681,15 @@ struct TypeChecker2 } }; -void check( - NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, const SourceModule& sourceModule, Module* module) +void check(NotNull builtinTypes, NotNull unifierState, NotNull limits, DcrLogger* logger, + const SourceModule& sourceModule, Module* module) { TypeChecker2 typeChecker{builtinTypes, unifierState, limits, logger, &sourceModule, module}; typeChecker.visit(sourceModule.root); unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); freeze(module->interfaceTypes); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f2a6c655..48adc292 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,14 +36,13 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false) LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) -LUAU_FASTFLAGVARIABLE(LuauFixCyclicModuleExports, false) LUAU_FASTFLAG(LuauOccursIsntAlwaysFailure) -LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) LUAU_FASTFLAGVARIABLE(LuauTinyControlFlowAnalysis, false) LUAU_FASTFLAGVARIABLE(LuauLoopControlFlowAnalysis, false) +LUAU_FASTFLAGVARIABLE(LuauVariadicOverloadFix, false) LUAU_FASTFLAGVARIABLE(LuauAlwaysCommitInferencesOfFunctionCalls, false) LUAU_FASTFLAG(LuauParseDeclareClassIndexer) -LUAU_FASTFLAGVARIABLE(LuauIndexTableIntersectionStringExpr, false) +LUAU_FASTFLAG(LuauFloorDivision); namespace Luau { @@ -204,7 +203,8 @@ static bool isMetamethod(const Name& name) { return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || - name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len"; + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode" || name == "__iter" || name == "__len" || + (FFlag::LuauFloorDivision && name == "__idiv"); } size_t HashBoolNamePair::operator()(const std::pair& pair) const @@ -212,21 +212,6 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -GlobalTypes::GlobalTypes(NotNull builtinTypes) - : builtinTypes(builtinTypes) -{ - globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - - globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); - globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); - globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); - globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); - globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); - globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); -} - TypeChecker::TypeChecker(const ScopePtr& globalScope, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) : globalScope(globalScope) , resolver(resolver) @@ -1215,16 +1200,13 @@ ControlFlow TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) scope->importedTypeBindings[name] = module->exportedTypeBindings; scope->importedModules[name] = moduleInfo->name; - if (FFlag::LuauFixCyclicModuleExports) + // Imported types of requires that transitively refer to current module have to be replaced with 'any' + for (const auto& [location, path] : requireCycles) { - // Imported types of requires that transitively refer to current module have to be replaced with 'any' - for (const auto& [location, path] : requireCycles) + if (!path.empty() && path.front() == moduleInfo->name) { - if (!path.empty() && path.front() == moduleInfo->name) - { - for (auto& [name, tf] : scope->importedTypeBindings[name]) - tf = TypeFun{{}, {}, anyType}; - } + for (auto& [name, tf] : scope->importedTypeBindings[name]) + tf = TypeFun{{}, {}, anyType}; } } } @@ -2595,6 +2577,9 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op) return "__mul"; case AstExprBinary::Div: return "__div"; + case AstExprBinary::FloorDiv: + LUAU_ASSERT(FFlag::LuauFloorDivision); + return "__idiv"; case AstExprBinary::Mod: return "__mod"; case AstExprBinary::Pow: @@ -3088,8 +3073,11 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: + LUAU_ASSERT(FFlag::LuauFloorDivision || expr.op != AstExprBinary::FloorDiv); + reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location)); reportErrors(tryUnify(rhsType, numberType, scope, expr.right->location)); return numberType; @@ -3128,22 +3116,13 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { - if (!FFlag::LuauTypecheckTypeguards) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - // For these, passing expectedType is worse than simply forcing them, because their implementation // may inadvertently check if expectedTypes exist first and use it, instead of forceSingleton first. WithPredicate lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/true); WithPredicate rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/true); - if (FFlag::LuauTypecheckTypeguards) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; PredicateVec predicates; @@ -3423,7 +3402,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); return errorRecoveryType(scope); } - else if (FFlag::LuauIndexTableIntersectionStringExpr && get(exprType)) + else if (get(exprType)) { Name name = std::string(value->value.data, value->value.size); @@ -4063,7 +4042,13 @@ void TypeChecker::checkArgumentList(const ScopePtr& scope, const AstExpr& funNam if (argIndex < argLocations.size()) location = argLocations[argIndex]; - unify(*argIter, vtp->ty, scope, location); + if (FFlag::LuauVariadicOverloadFix) + { + state.location = location; + state.tryUnify(*argIter, vtp->ty); + } + else + unify(*argIter, vtp->ty, scope, location); ++argIter; ++argIndex; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index db8e2008..4a27bbde 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -25,7 +25,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursIsntAlwaysFailure, false) LUAU_FASTFLAG(LuauNormalizeBlockedTypes) LUAU_FASTFLAG(LuauAlwaysCommitInferencesOfFunctionCalls) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) -LUAU_FASTFLAGVARIABLE(LuauTableUnifyRecursionLimit, false) +LUAU_FASTFLAGVARIABLE(LuauFixIndexerSubtypingOrdering, false) namespace Luau { @@ -605,6 +605,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { // TODO: there are probably cheaper ways to check if any <: T. const NormalizedType* superNorm = normalizer->normalize(superTy); + + if (!superNorm) + return reportError(location, UnificationTooComplex{}); + if (!log.get(superNorm->tops)) failure = true; } @@ -2256,23 +2260,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (superTable != newSuperTable || subTable != newSubTable) { - if (FFlag::LuauTableUnifyRecursionLimit) + if (errors.empty()) { - if (errors.empty()) - { - RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - tryUnifyTables(subTy, superTy, isIntersection); - } + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } - return; - } - else - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + return; } } @@ -2292,7 +2286,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, variance = Invariant; Unifier innerState = makeChildUnifier(); - if (useNewSolver) + if (useNewSolver || FFlag::LuauFixIndexerSubtypingOrdering) innerState.tryUnify_(prop.type(), superTable->indexer->indexResultType); else { @@ -2347,23 +2341,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection, if (superTable != newSuperTable || subTable != newSubTable) { - if (FFlag::LuauTableUnifyRecursionLimit) + if (errors.empty()) { - if (errors.empty()) - { - RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); - tryUnifyTables(subTy, superTy, isIntersection); - } + RecursionLimiter _ra(&sharedState.counters.recursionCount, sharedState.counters.recursionLimit); + tryUnifyTables(subTy, superTy, isIntersection); + } - return; - } - else - { - if (errors.empty()) - return tryUnifyTables(subTy, superTy, isIntersection); - else - return; - } + return; } } diff --git a/Analysis/src/Unifier2.cpp b/Analysis/src/Unifier2.cpp index 0be6941b..1f3330dc 100644 --- a/Analysis/src/Unifier2.cpp +++ b/Analysis/src/Unifier2.cpp @@ -26,7 +26,6 @@ Unifier2::Unifier2(NotNull arena, NotNull builtinTypes, , ice(ice) , recursionLimit(FInt::LuauTypeInferRecursionLimit) { - } bool Unifier2::unify(TypeId subTy, TypeId superTy) @@ -99,10 +98,7 @@ bool Unifier2::unify(TypePackId subTp, TypePackId superTp) return true; } - size_t maxLength = std::max( - flatten(subTp).first.size(), - flatten(superTp).first.size() - ); + size_t maxLength = std::max(flatten(subTp).first.size(), flatten(superTp).first.size()); auto [subTypes, subTail] = extendTypePack(*arena, builtinTypes, subTp, maxLength); auto [superTypes, superTail] = extendTypePack(*arena, builtinTypes, superTp, maxLength); @@ -123,16 +119,25 @@ struct FreeTypeSearcher : TypeVisitor explicit FreeTypeSearcher(NotNull scope) : TypeVisitor(/*skipBoundTypes*/ true) , scope(scope) - {} + { + } - enum { Positive, Negative } polarity = Positive; + enum + { + Positive, + Negative + } polarity = Positive; void flip() { switch (polarity) { - case Positive: polarity = Negative; break; - case Negative: polarity = Positive; break; + case Positive: + polarity = Negative; + break; + case Negative: + polarity = Positive; + break; } } @@ -152,8 +157,12 @@ struct FreeTypeSearcher : TypeVisitor switch (polarity) { - case Positive: positiveTypes.insert(ty); break; - case Negative: negativeTypes.insert(ty); break; + case Positive: + positiveTypes.insert(ty); + break; + case Negative: + negativeTypes.insert(ty); + break; } return true; @@ -180,13 +189,17 @@ struct MutatingGeneralizer : TypeOnceVisitor std::unordered_set negativeTypes; std::vector generics; - MutatingGeneralizer(NotNull builtinTypes, NotNull scope, std::unordered_set positiveTypes, std::unordered_set negativeTypes) + bool isWithinFunction = false; + + MutatingGeneralizer( + NotNull builtinTypes, NotNull scope, std::unordered_set positiveTypes, std::unordered_set negativeTypes) : TypeOnceVisitor(/* skipBoundTypes */ true) , builtinTypes(builtinTypes) , scope(scope) , positiveTypes(std::move(positiveTypes)) , negativeTypes(std::move(negativeTypes)) - {} + { + } static void replace(DenseHashSet& seen, TypeId haystack, TypeId needle, TypeId replacement) { @@ -211,10 +224,7 @@ struct MutatingGeneralizer : TypeOnceVisitor // FIXME: I bet this function has reentrancy problems option = follow(option); if (option == needle) - { - LUAU_ASSERT(!seen.find(option)); option = replacement; - } // TODO seen set else if (get(option)) @@ -224,7 +234,21 @@ struct MutatingGeneralizer : TypeOnceVisitor } } - bool visit (TypeId ty, const FreeType&) override + bool visit(TypeId ty, const FunctionType& ft) override + { + const bool oldValue = isWithinFunction; + + isWithinFunction = true; + + traverse(ft.argTypes); + traverse(ft.retTypes); + + isWithinFunction = oldValue; + + return false; + } + + bool visit(TypeId ty, const FreeType&) override { const FreeType* ft = get(ty); LUAU_ASSERT(ft); @@ -232,7 +256,8 @@ struct MutatingGeneralizer : TypeOnceVisitor traverse(ft->lowerBound); traverse(ft->upperBound); - // ft is potentially invalid now. + // It is possible for the above traverse() calls to cause ty to be + // transmuted. We must reaquire ft if this happens. ty = follow(ty); ft = get(ty); if (!ft) @@ -252,8 +277,13 @@ struct MutatingGeneralizer : TypeOnceVisitor if (!hasLowerBound && !hasUpperBound) { - emplaceType(asMutable(ty), scope); - generics.push_back(ty); + if (isWithinFunction) + { + emplaceType(asMutable(ty), scope); + generics.push_back(ty); + } + else + emplaceType(asMutable(ty), builtinTypes->unknownType); } // It is possible that this free type has other free types in its upper @@ -264,19 +294,27 @@ struct MutatingGeneralizer : TypeOnceVisitor // If we do not do this, we get tautological bounds like a <: a <: unknown. else if (isPositive && !hasUpperBound) { - if (FreeType* lowerFree = getMutable(ft->lowerBound); lowerFree && lowerFree->upperBound == ty) + TypeId lb = follow(ft->lowerBound); + if (FreeType* lowerFree = getMutable(lb); lowerFree && lowerFree->upperBound == ty) lowerFree->upperBound = builtinTypes->unknownType; else - replace(seen, ft->lowerBound, ty, builtinTypes->unknownType); - emplaceType(asMutable(ty), ft->lowerBound); + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, lb, ty, builtinTypes->unknownType); + } + emplaceType(asMutable(ty), lb); } else { - if (FreeType* upperFree = getMutable(ft->upperBound); upperFree && upperFree->lowerBound == ty) + TypeId ub = follow(ft->upperBound); + if (FreeType* upperFree = getMutable(ub); upperFree && upperFree->lowerBound == ty) upperFree->lowerBound = builtinTypes->neverType; else - replace(seen, ft->upperBound, ty, builtinTypes->neverType); - emplaceType(asMutable(ty), ft->upperBound); + { + DenseHashSet replaceSeen{nullptr}; + replace(replaceSeen, ub, ty, builtinTypes->neverType); + } + emplaceType(asMutable(ty), ub); } return false; @@ -363,4 +401,4 @@ OccursCheckResult Unifier2::occursCheck(DenseHashSet& seen, TypePack return OccursCheckResult::Pass; } -} +} // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index f9f9ab41..f3b5b149 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -272,11 +272,18 @@ class AstExprConstantString : public AstExpr public: LUAU_RTTI(AstExprConstantString) - AstExprConstantString(const Location& location, const AstArray& value); + enum QuoteStyle + { + Quoted, + Unquoted + }; + + AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle = Quoted); void visit(AstVisitor* visitor) override; AstArray value; + QuoteStyle quoteStyle = Quoted; }; class AstExprLocal : public AstExpr @@ -450,6 +457,7 @@ public: Sub, Mul, Div, + FloorDiv, Mod, Pow, Concat, @@ -460,7 +468,9 @@ public: CompareGt, CompareGe, And, - Or + Or, + + Op__Count }; AstExprBinary(const Location& location, Op op, AstExpr* left, AstExpr* right); @@ -524,11 +534,12 @@ class AstStatBlock : public AstStat public: LUAU_RTTI(AstStatBlock) - AstStatBlock(const Location& location, const AstArray& body); + AstStatBlock(const Location& location, const AstArray& body, bool hasEnd = true); void visit(AstVisitor* visitor) override; AstArray body; + bool hasEnd = false; }; class AstStatIf : public AstStat diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 929402b3..7d15212a 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -62,6 +62,7 @@ struct Lexeme Dot3, SkinnyArrow, DoubleColon, + FloorDiv, InterpStringBegin, InterpStringMid, @@ -73,6 +74,7 @@ struct Lexeme SubAssign, MulAssign, DivAssign, + FloorDivAssign, ModAssign, PowAssign, ConcatAssign, @@ -204,7 +206,9 @@ private: Position position() const; + // consume() assumes current character is not a newline for performance; when that is not known, consumeAny() should be used instead. void consume(); + void consumeAny(); Lexeme readCommentBody(); diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 3c87e36c..52b77de3 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -3,6 +3,8 @@ #include "Luau/Common.h" +LUAU_FASTFLAG(LuauFloorDivision) + namespace Luau { @@ -62,9 +64,10 @@ void AstExprConstantNumber::visit(AstVisitor* visitor) visitor->visit(this); } -AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value) +AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value, QuoteStyle quoteStyle) : AstExpr(ClassIndex(), location) , value(value) + , quoteStyle(quoteStyle) { } @@ -278,6 +281,9 @@ std::string toString(AstExprBinary::Op op) return "*"; case AstExprBinary::Div: return "/"; + case AstExprBinary::FloorDiv: + LUAU_ASSERT(FFlag::LuauFloorDivision); + return "//"; case AstExprBinary::Mod: return "%"; case AstExprBinary::Pow: @@ -374,9 +380,10 @@ void AstExprError::visit(AstVisitor* visitor) } } -AstStatBlock::AstStatBlock(const Location& location, const AstArray& body) +AstStatBlock::AstStatBlock(const Location& location, const AstArray& body, bool hasEnd) : AstStat(ClassIndex(), location) , body(body) + , hasEnd(hasEnd) { } diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index 75b4fe30..1795243c 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauFloorDivision, false) +LUAU_FASTFLAGVARIABLE(LuauLexerLookaheadRemembersBraceType, false) + namespace Luau { @@ -136,6 +139,9 @@ std::string Lexeme::toString() const case DoubleColon: return "'::'"; + case FloorDiv: + return FFlag::LuauFloorDivision ? "'//'" : ""; + case AddAssign: return "'+='"; @@ -148,6 +154,9 @@ std::string Lexeme::toString() const case DivAssign: return "'/='"; + case FloorDivAssign: + return FFlag::LuauFloorDivision ? "'//='" : ""; + case ModAssign: return "'%='"; @@ -373,7 +382,7 @@ const Lexeme& Lexer::next(bool skipComments, bool updatePrevLocation) { // consume whitespace before the token while (isSpace(peekch())) - consume(); + consumeAny(); if (updatePrevLocation) prevLocation = lexeme.location; @@ -400,6 +409,8 @@ Lexeme Lexer::lookahead() unsigned int currentLineOffset = lineOffset; Lexeme currentLexeme = lexeme; Location currentPrevLocation = prevLocation; + size_t currentBraceStackSize = braceStack.size(); + BraceType currentBraceType = braceStack.empty() ? BraceType::Normal : braceStack.back(); Lexeme result = next(); @@ -408,6 +419,13 @@ Lexeme Lexer::lookahead() lineOffset = currentLineOffset; lexeme = currentLexeme; prevLocation = currentPrevLocation; + if (FFlag::LuauLexerLookaheadRemembersBraceType) + { + if (braceStack.size() < currentBraceStackSize) + braceStack.push_back(currentBraceType); + else if (braceStack.size() > currentBraceStackSize) + braceStack.pop_back(); + } return result; } @@ -438,7 +456,17 @@ Position Lexer::position() const return Position(line, offset - lineOffset); } +LUAU_FORCEINLINE void Lexer::consume() +{ + // consume() assumes current character is known to not be a newline; use consumeAny if this is not guaranteed + LUAU_ASSERT(!isNewline(buffer[offset])); + + offset++; +} + +LUAU_FORCEINLINE +void Lexer::consumeAny() { if (isNewline(buffer[offset])) { @@ -524,7 +552,7 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le } else { - consume(); + consumeAny(); } } @@ -540,7 +568,7 @@ void Lexer::readBackslashInString() case '\r': consume(); if (peekch() == '\n') - consume(); + consumeAny(); break; case 0: @@ -549,11 +577,11 @@ void Lexer::readBackslashInString() case 'z': consume(); while (isSpace(peekch())) - consume(); + consumeAny(); break; default: - consume(); + consumeAny(); } } @@ -878,15 +906,46 @@ Lexeme Lexer::readNext() return Lexeme(Location(start, 1), '+'); case '/': - consume(); - - if (peekch() == '=') + { + if (FFlag::LuauFloorDivision) { consume(); - return Lexeme(Location(start, 2), Lexeme::DivAssign); + + char ch = peekch(); + + if (ch == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::DivAssign); + } + else if (ch == '/') + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 3), Lexeme::FloorDivAssign); + } + else + return Lexeme(Location(start, 2), Lexeme::FloorDiv); + } + else + return Lexeme(Location(start, 1), '/'); } else - return Lexeme(Location(start, 1), '/'); + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::DivAssign); + } + else + return Lexeme(Location(start, 1), '/'); + } + } case '*': consume(); @@ -939,6 +998,9 @@ Lexeme Lexer::readNext() case ';': case ',': case '#': + case '?': + case '&': + case '|': { char ch = peekch(); consume(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index cc5d7b38..d59b6b40 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,8 +14,7 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseDeclareClassIndexer, false) - -#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" +LUAU_FASTFLAG(LuauFloorDivision) namespace Luau { @@ -462,11 +461,11 @@ AstStat* Parser::parseDo() Lexeme matchDo = lexer.current(); nextLexeme(); // do - AstStat* body = parseBlock(); + AstStatBlock* body = parseBlock(); body->location.begin = start.begin; - expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + body->hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); return body; } @@ -899,13 +898,13 @@ AstStat* Parser::parseDeclaration(const Location& start) expectAndConsume(':', "property type annotation"); AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) props.push_back(AstDeclaredClassProp{AstName(chars->data), type, false}); else - report(begin.location, "String literal contains malformed escape sequence"); + report(begin.location, "String literal contains malformed escape sequence or \\0"); } else if (lexer.current().type == '[' && FFlag::LuauParseDeclareClassIndexer) { @@ -1328,13 +1327,13 @@ AstType* Parser::parseTableType() AstType* type = parseType(); - // TODO: since AstName conains a char*, it can't contain null + // since AstName contains a char*, it can't contain null bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); if (chars && !containsNull) props.push_back({AstName(chars->data), begin.location, type}); else - report(begin.location, "String literal contains malformed escape sequence"); + report(begin.location, "String literal contains malformed escape sequence or \\0"); } else if (lexer.current().type == '[') { @@ -1622,7 +1621,7 @@ AstTypeOrPack Parser::parseSimpleType(bool allowPack) else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return {reportTypeError(start, {}, "Malformed string")}; + return {reportTypeError(start, {}, "Malformed string; did you forget to finish it?")}; } else if (lexer.current().type == Lexeme::Name) { @@ -1741,7 +1740,8 @@ AstTypePack* Parser::parseTypePack() return allocator.alloc(Location(name.location, end), name.name); } - // No type pack annotation exists here. + // TODO: shouldParseTypePack can be removed and parseTypePack can be called unconditionally instead + LUAU_ASSERT(!"parseTypePack can't be called if shouldParseTypePack() returned false"); return nullptr; } @@ -1767,6 +1767,12 @@ std::optional Parser::parseBinaryOp(const Lexeme& l) return AstExprBinary::Mul; else if (l.type == '/') return AstExprBinary::Div; + else if (l.type == Lexeme::FloorDiv) + { + LUAU_ASSERT(FFlag::LuauFloorDivision); + + return AstExprBinary::FloorDiv; + } else if (l.type == '%') return AstExprBinary::Mod; else if (l.type == '^') @@ -1803,6 +1809,12 @@ std::optional Parser::parseCompoundOp(const Lexeme& l) return AstExprBinary::Mul; else if (l.type == Lexeme::DivAssign) return AstExprBinary::Div; + else if (l.type == Lexeme::FloorDivAssign) + { + LUAU_ASSERT(FFlag::LuauFloorDivision); + + return AstExprBinary::FloorDiv; + } else if (l.type == Lexeme::ModAssign) return AstExprBinary::Mod; else if (l.type == Lexeme::PowAssign) @@ -1826,7 +1838,7 @@ std::optional Parser::checkUnaryConfusables() if (curr.type == '!') { - report(start, "Unexpected '!', did you mean 'not'?"); + report(start, "Unexpected '!'; did you mean 'not'?"); return AstExprUnary::Not; } @@ -1848,20 +1860,20 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr if (curr.type == '&' && next.type == '&' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::And].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '&&', did you mean 'and'?"); + report(Location(start, next.location), "Unexpected '&&'; did you mean 'and'?"); return AstExprBinary::And; } else if (curr.type == '|' && next.type == '|' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::Or].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '||', did you mean 'or'?"); + report(Location(start, next.location), "Unexpected '||'; did you mean 'or'?"); return AstExprBinary::Or; } else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::CompareNe].left > limit) { nextLexeme(); - report(Location(start, next.location), "Unexpected '!=', did you mean '~='?"); + report(Location(start, next.location), "Unexpected '!='; did you mean '~='?"); return AstExprBinary::CompareNe; } @@ -1873,12 +1885,13 @@ std::optional Parser::checkBinaryConfusables(const BinaryOpPr AstExpr* Parser::parseExpr(unsigned int limit) { static const BinaryOpPriority binaryPriority[] = { - {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `%' - {10, 9}, {5, 4}, // power and concat (right associative) - {3, 3}, {3, 3}, // equality and inequality - {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order - {2, 2}, {1, 1} // logical (and/or) + {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `//' `%' + {10, 9}, {5, 4}, // power and concat (right associative) + {3, 3}, {3, 3}, // equality and inequality + {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order + {2, 2}, {1, 1} // logical (and/or) }; + static_assert(sizeof(binaryPriority) / sizeof(binaryPriority[0]) == size_t(AstExprBinary::Op__Count), "binaryPriority needs an entry per op"); unsigned int recursionCounterOld = recursionCounter; @@ -2169,12 +2182,12 @@ AstExpr* Parser::parseSimpleExpr() else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); - return reportExprError(start, {}, "Malformed string"); + return reportExprError(start, {}, "Malformed string; did you forget to finish it?"); } else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace) { nextLexeme(); - return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(start, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); } else if (lexer.current().type == Lexeme::Dot3) { @@ -2312,7 +2325,7 @@ AstExpr* Parser::parseTableConstructor() nameString.data = const_cast(name.name.value); nameString.size = strlen(name.name.value); - AstExpr* key = allocator.alloc(name.location, nameString); + AstExpr* key = allocator.alloc(name.location, nameString, AstExprConstantString::Unquoted); AstExpr* value = parseExpr(); if (AstExprFunction* func = value->as()) @@ -2661,7 +2674,7 @@ AstExpr* Parser::parseInterpString() { errorWhileChecking = true; nextLexeme(); - expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '`'?")); + expressions.push_back(reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '`'?")); break; } default: @@ -2681,10 +2694,10 @@ AstExpr* Parser::parseInterpString() break; case Lexeme::BrokenInterpDoubleBrace: nextLexeme(); - return reportExprError(endLocation, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + return reportExprError(endLocation, {}, "Double braces are not permitted within interpolated strings; did you mean '\\{'?"); case Lexeme::BrokenString: nextLexeme(); - return reportExprError(endLocation, {}, "Malformed interpolated string, did you forget to add a '}'?"); + return reportExprError(endLocation, {}, "Malformed interpolated string; did you forget to add a '}'?"); default: return reportExprError(endLocation, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); } diff --git a/CLI/Compile.cpp b/CLI/Compile.cpp index 6197f03e..c35f9c3d 100644 --- a/CLI/Compile.cpp +++ b/CLI/Compile.cpp @@ -89,13 +89,14 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static std::string getCodegenAssembly(const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options) +static std::string getCodegenAssembly( + const char* name, const std::string& bytecode, Luau::CodeGen::AssemblyOptions options, Luau::CodeGen::LoweringStats* stats) { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); if (luau_load(L, name, bytecode.data(), bytecode.size(), 0) == 0) - return Luau::CodeGen::getAssembly(L, -1, options); + return Luau::CodeGen::getAssembly(L, -1, options, stats); fprintf(stderr, "Error loading bytecode %s\n", name); return ""; @@ -119,6 +120,8 @@ struct CompileStats double parseTime; double compileTime; double codegenTime; + + Luau::CodeGen::LoweringStats lowerStats; }; static double recordDeltaTime(double& timer) @@ -213,10 +216,10 @@ static bool compileFile(const char* name, CompileFormat format, Luau::CodeGen::A case CompileFormat::CodegenAsm: case CompileFormat::CodegenIr: case CompileFormat::CodegenVerbose: - printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options).c_str()); + printf("%s", getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).c_str()); break; case CompileFormat::CodegenNull: - stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options).size(); + stats.codegen += getCodegenAssembly(name, bcb.getBytecode(), options, &stats.lowerStats).size(); stats.codegenTime += recordDeltaTime(currts); break; case CompileFormat::Null: @@ -355,13 +358,22 @@ int main(int argc, char** argv) failed += !compileFile(path.c_str(), compileFormat, assemblyTarget, stats); if (compileFormat == CompileFormat::Null) + { printf("Compiled %d KLOC into %d KB bytecode (read %.2fs, parse %.2fs, compile %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), stats.readTime, stats.parseTime, stats.compileTime); + } else if (compileFormat == CompileFormat::CodegenNull) + { printf("Compiled %d KLOC into %d KB bytecode => %d KB native code (%.2fx) (read %.2fs, parse %.2fs, compile %.2fs, codegen %.2fs)\n", int(stats.lines / 1000), int(stats.bytecode / 1024), int(stats.codegen / 1024), stats.bytecode == 0 ? 0.0 : double(stats.codegen) / double(stats.bytecode), stats.readTime, stats.parseTime, stats.compileTime, stats.codegenTime); + printf("Lowering stats:\n"); + printf("- spills to stack: %d, spills to restore: %d, max spill slot %u\n", stats.lowerStats.spillsToSlot, stats.lowerStats.spillsToRestore, + stats.lowerStats.maxSpillSlotsUsed); + printf("- regalloc failed: %d, lowering failed %d\n", stats.lowerStats.regAllocErrors, stats.lowerStats.loweringErrors); + } + return failed ? 1 : 0; } diff --git a/CLI/Reduce.cpp b/CLI/Reduce.cpp index 38133e04..721bb51c 100644 --- a/CLI/Reduce.cpp +++ b/CLI/Reduce.cpp @@ -15,7 +15,7 @@ #define VERBOSE 0 // 1 - print out commandline invocations. 2 - print out stdout -#ifdef _WIN32 +#if defined(_WIN32) && !defined(__MINGW32__) const auto popen = &_popen; const auto pclose = &_pclose; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index b2a523d9..19083bfc 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -757,14 +757,6 @@ int replMain(int argc, char** argv) } #endif -#if !LUA_CUSTOM_EXECUTION - if (codegen) - { - fprintf(stderr, "To run with --codegen, Luau has to be built with LUA_CUSTOM_EXECUTION enabled\n"); - return 1; - } -#endif - if (codegenPerf) { #if __linux__ @@ -784,10 +776,7 @@ int replMain(int argc, char** argv) } if (codegen && !Luau::CodeGen::isSupported()) - { - fprintf(stderr, "Cannot enable --codegen, native code generation is not supported in current configuration\n"); - return 1; - } + fprintf(stderr, "Warning: Native code generation is not supported in current configuration\n"); const std::vector files = getSourceFiles(argc, argv); diff --git a/CMakeLists.txt b/CMakeLists.txt index bc66a83d..15a1b8a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,7 +12,6 @@ option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) option(LUAU_EXTERN_C "Use extern C for all APIs" OFF) -option(LUAU_NATIVE "Enable support for native code generation" OFF) cmake_policy(SET CMP0054 NEW) cmake_policy(SET CMP0091 NEW) @@ -26,6 +25,7 @@ project(Luau LANGUAGES CXX C) add_library(Luau.Common INTERFACE) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) +add_library(Luau.Config STATIC) add_library(Luau.Analysis STATIC) add_library(Luau.CodeGen STATIC) add_library(Luau.VM STATIC) @@ -71,9 +71,13 @@ target_compile_features(Luau.Compiler PUBLIC cxx_std_17) target_include_directories(Luau.Compiler PUBLIC Compiler/include) target_link_libraries(Luau.Compiler PUBLIC Luau.Ast) +target_compile_features(Luau.Config PUBLIC cxx_std_17) +target_include_directories(Luau.Config PUBLIC Config/include) +target_link_libraries(Luau.Config PUBLIC Luau.Ast) + target_compile_features(Luau.Analysis PUBLIC cxx_std_17) target_include_directories(Luau.Analysis PUBLIC Analysis/include) -target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) +target_link_libraries(Luau.Analysis PUBLIC Luau.Ast Luau.Config) target_compile_features(Luau.CodeGen PRIVATE cxx_std_17) target_include_directories(Luau.CodeGen PUBLIC CodeGen/include) @@ -141,13 +145,7 @@ if(LUAU_EXTERN_C) target_compile_definitions(Luau.VM PUBLIC LUA_USE_LONGJMP=1) target_compile_definitions(Luau.VM PUBLIC LUA_API=extern\"C\") target_compile_definitions(Luau.Compiler PUBLIC LUACODE_API=extern\"C\") -endif() - -if(LUAU_NATIVE) - target_compile_definitions(Luau.VM PUBLIC LUA_CUSTOM_EXECUTION=1) - if(LUAU_EXTERN_C) - target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\") - endif() + target_compile_definitions(Luau.CodeGen PUBLIC LUACODEGEN_API=extern\"C\") endif() if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC" AND MSVC_VERSION GREATER_EQUAL 1924) diff --git a/CodeGen/include/Luau/AssemblyBuilderA64.h b/CodeGen/include/Luau/AssemblyBuilderA64.h index 3fc37d1d..931003b3 100644 --- a/CodeGen/include/Luau/AssemblyBuilderA64.h +++ b/CodeGen/include/Luau/AssemblyBuilderA64.h @@ -221,6 +221,7 @@ private: void placeFMOV(const char* name, RegisterA64 dst, double src, uint32_t op); void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op); void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms); + void placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift); void place(uint32_t word); diff --git a/CodeGen/include/Luau/CodeAllocator.h b/CodeGen/include/Luau/CodeAllocator.h index e0537b64..d7e43272 100644 --- a/CodeGen/include/Luau/CodeAllocator.h +++ b/CodeGen/include/Luau/CodeAllocator.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/CodeGen.h" + #include #include @@ -16,6 +18,7 @@ constexpr uint32_t kCodeAlignment = 32; struct CodeAllocator { CodeAllocator(size_t blockSize, size_t maxTotalSize); + CodeAllocator(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext); ~CodeAllocator(); // Places data and code into the executable page area @@ -24,7 +27,7 @@ struct CodeAllocator bool allocate( const uint8_t* data, size_t dataSize, const uint8_t* code, size_t codeSize, uint8_t*& result, size_t& resultSize, uint8_t*& resultCodeStart); - // Provided to callbacks + // Provided to unwind info callbacks void* context = nullptr; // Called when new block is created to create and setup the unwinding information for all the code in the block @@ -34,12 +37,16 @@ struct CodeAllocator // Called to destroy unwinding information returned by 'createBlockUnwindInfo' void (*destroyBlockUnwindInfo)(void* context, void* unwindData) = nullptr; +private: // Unwind information can be placed inside the block with some implementation-specific reservations at the beginning // But to simplify block space checks, we limit the max size of all that data static const size_t kMaxReservedDataSize = 256; bool allocateNewBlock(size_t& unwindInfoSize); + uint8_t* allocatePages(size_t size) const; + void freePages(uint8_t* mem, size_t size) const; + // Current block we use for allocations uint8_t* blockPos = nullptr; uint8_t* blockEnd = nullptr; @@ -50,6 +57,9 @@ struct CodeAllocator size_t blockSize = 0; size_t maxTotalSize = 0; + + AllocationCallback* allocationCallback = nullptr; + void* allocationCallbackContext = nullptr; }; } // namespace CodeGen diff --git a/CodeGen/include/Luau/CodeGen.h b/CodeGen/include/Luau/CodeGen.h index 002b4a99..031779f7 100644 --- a/CodeGen/include/Luau/CodeGen.h +++ b/CodeGen/include/Luau/CodeGen.h @@ -3,6 +3,7 @@ #include +#include #include struct lua_State; @@ -18,12 +19,35 @@ enum CodeGenFlags CodeGen_OnlyNativeModules = 1 << 0, }; +enum class CodeGenCompilationResult +{ + Success, // Successfully generated code for at least one function + NothingToCompile, // There were no new functions to compile + + CodeGenNotInitialized, // Native codegen system is not initialized + CodeGenFailed, // Native codegen failed due to an internal compiler error + AllocationFailed, // Native codegen failed due to an allocation error +}; + +struct CompilationStats +{ + size_t bytecodeSizeBytes = 0; + size_t nativeCodeSizeBytes = 0; + size_t nativeDataSizeBytes = 0; + size_t nativeMetadataSizeBytes = 0; + + uint32_t functionsCompiled = 0; +}; + +using AllocationCallback = void(void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize); + bool isSupported(); +void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext); void create(lua_State* L); // Builds target function and all inner functions -void compile(lua_State* L, int idx, unsigned int flags = 0); +CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags = 0, CompilationStats* stats = nullptr); using AnnotatorFn = void (*)(void* context, std::string& result, int fid, int instpos); @@ -51,8 +75,18 @@ struct AssemblyOptions void* annotatorContext = nullptr; }; +struct LoweringStats +{ + int spillsToSlot = 0; + int spillsToRestore = 0; + unsigned maxSpillSlotsUsed = 0; + + int regAllocErrors = 0; + int loweringErrors = 0; +}; + // Generates assembly for target function and all inner functions -std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}); +std::string getAssembly(lua_State* L, int idx, AssemblyOptions options = {}, LoweringStats* stats = nullptr); using PerfLogFn = void (*)(void* context, uintptr_t addr, unsigned size, const char* symbol); diff --git a/CodeGen/include/Luau/ConditionX64.h b/CodeGen/include/Luau/ConditionX64.h index 78c3fda2..4432641a 100644 --- a/CodeGen/include/Luau/ConditionX64.h +++ b/CodeGen/include/Luau/ConditionX64.h @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/Common.h" + namespace Luau { namespace CodeGen @@ -43,5 +45,68 @@ enum class ConditionX64 : uint8_t Count }; +inline ConditionX64 getReverseCondition(ConditionX64 cond) +{ + switch (cond) + { + case ConditionX64::Overflow: + return ConditionX64::NoOverflow; + case ConditionX64::NoOverflow: + return ConditionX64::Overflow; + case ConditionX64::Carry: + return ConditionX64::NoCarry; + case ConditionX64::NoCarry: + return ConditionX64::Carry; + case ConditionX64::Below: + return ConditionX64::NotBelow; + case ConditionX64::BelowEqual: + return ConditionX64::NotBelowEqual; + case ConditionX64::Above: + return ConditionX64::NotAbove; + case ConditionX64::AboveEqual: + return ConditionX64::NotAboveEqual; + case ConditionX64::Equal: + return ConditionX64::NotEqual; + case ConditionX64::Less: + return ConditionX64::NotLess; + case ConditionX64::LessEqual: + return ConditionX64::NotLessEqual; + case ConditionX64::Greater: + return ConditionX64::NotGreater; + case ConditionX64::GreaterEqual: + return ConditionX64::NotGreaterEqual; + case ConditionX64::NotBelow: + return ConditionX64::Below; + case ConditionX64::NotBelowEqual: + return ConditionX64::BelowEqual; + case ConditionX64::NotAbove: + return ConditionX64::Above; + case ConditionX64::NotAboveEqual: + return ConditionX64::AboveEqual; + case ConditionX64::NotEqual: + return ConditionX64::Equal; + case ConditionX64::NotLess: + return ConditionX64::Less; + case ConditionX64::NotLessEqual: + return ConditionX64::LessEqual; + case ConditionX64::NotGreater: + return ConditionX64::Greater; + case ConditionX64::NotGreaterEqual: + return ConditionX64::GreaterEqual; + case ConditionX64::Zero: + return ConditionX64::NotZero; + case ConditionX64::NotZero: + return ConditionX64::Zero; + case ConditionX64::Parity: + return ConditionX64::NotParity; + case ConditionX64::NotParity: + return ConditionX64::Parity; + case ConditionX64::Count: + LUAU_ASSERT(!"invalid ConditionX64 value"); + } + + return ConditionX64::Count; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index ca1eba62..5fcaf46b 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include +#include #include #include @@ -96,6 +97,46 @@ struct CfgInfo void computeCfgImmediateDominators(IrFunction& function); void computeCfgDominanceTreeChildren(IrFunction& function); +struct IdfContext +{ + struct BlockAndOrdering + { + uint32_t blockIdx; + BlockOrdering ordering; + + bool operator<(const BlockAndOrdering& rhs) const + { + if (ordering.depth != rhs.ordering.depth) + return ordering.depth < rhs.ordering.depth; + + return ordering.preOrder < rhs.ordering.preOrder; + } + }; + + // Using priority queue to work on nodes in the order from the bottom of the dominator tree to the top + // If the depth of keys is equal, DFS order is used to provide strong ordering + std::priority_queue queue; + std::vector worklist; + + struct IdfVisitMarks + { + bool seenInQueue = false; + bool seenInWorklist = false; + }; + + std::vector visits; + + std::vector idf; +}; + +// Compute iterated dominance frontier (IDF or DF+) for a variable, given the set of blocks where that variable is defined +// Providing a set of blocks where the variable is a live-in at the entry helps produce a pruned SSA form (inserted phi nodes will not be dead) +// +// 'Iterated' comes from the definition where we recompute the IDFn+1 = DF(S) while adding IDFn to S until a fixed point is reached +// Iterated dominance frontier has been shown to be equal to the set of nodes where phi instructions have to be inserted +void computeIteratedDominanceFrontierForDefs( + IdfContext& ctx, const IrFunction& function, const std::vector& defBlocks, const std::vector& liveInBlocks); + // Function used to update all CFG data void computeCfgInfo(IrFunction& function); diff --git a/CodeGen/include/Luau/IrBuilder.h b/CodeGen/include/Luau/IrBuilder.h index d854b400..b953c888 100644 --- a/CodeGen/include/Luau/IrBuilder.h +++ b/CodeGen/include/Luau/IrBuilder.h @@ -66,6 +66,8 @@ struct IrBuilder bool inTerminatedBlock = false; + bool interruptRequested = false; + bool activeFastcallFallback = false; IrOp fastcallFallbackReturn; int fastcallSkipTarget = -1; @@ -76,6 +78,8 @@ struct IrBuilder std::vector instIndexToBlock; // Block index at the bytecode instruction + std::vector loopStepStack; + // Similar to BytecodeBuilder, duplicate constants are removed used the same method struct ConstantKey { diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 67909b13..298258c1 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -53,12 +53,9 @@ enum class IrCmd : uint8_t // Load a TValue from memory // A: Rn or Kn or pointer (TValue) + // B: int (optional 'A' pointer offset) LOAD_TVALUE, - // Load a TValue from table node value - // A: pointer (LuaNode) - LOAD_NODE_VALUE_TV, // TODO: we should find a way to generalize LOAD_TVALUE - // Load current environment table LOAD_ENV, @@ -70,6 +67,7 @@ enum class IrCmd : uint8_t // Get pointer (LuaNode) to table node element at the active cached slot index // A: pointer (Table) // B: unsigned int (pcpos) + // C: Kn GET_SLOT_NODE_ADDR, // Get pointer (LuaNode) to table node element at the main position of the specified key hash @@ -113,12 +111,15 @@ enum class IrCmd : uint8_t // Store a TValue into memory // A: Rn or pointer (TValue) // B: TValue + // C: int (optional 'A' pointer offset) STORE_TVALUE, - // Store a TValue into table node value - // A: pointer (LuaNode) - // B: TValue - STORE_NODE_VALUE_TV, // TODO: we should find a way to generalize STORE_TVALUE + // Store a pair of tag and value into memory + // A: Rn or pointer (TValue) + // B: tag (must be a constant) + // C: int/double/pointer + // D: int (optional 'A' pointer offset) + STORE_SPLIT_TVALUE, // Add/Sub two integers together // A, B: int @@ -132,6 +133,7 @@ enum class IrCmd : uint8_t SUB_NUM, MUL_NUM, DIV_NUM, + IDIV_NUM, MOD_NUM, // Get the minimum/maximum of two numbers @@ -176,7 +178,7 @@ enum class IrCmd : uint8_t CMP_ANY, // Unconditional jump - // A: block/vmexit + // A: block/vmexit/undef JUMP, // Jump if TValue is truthy @@ -197,24 +199,12 @@ enum class IrCmd : uint8_t // D: block (if false) JUMP_EQ_TAG, - // Jump if two int numbers are equal - // A, B: int - // C: block (if true) - // D: block (if false) - JUMP_EQ_INT, - - // Jump if A < B - // A, B: int - // C: block (if true) - // D: block (if false) - JUMP_LT_INT, - - // Jump if unsigned(A) >= unsigned(B) + // Perform a conditional jump based on the result of integer comparison // A, B: int // C: condition // D: block (if true) // E: block (if false) - JUMP_GE_UINT, + JUMP_CMP_INT, // Jump if pointers are equal // A, B: pointer (*) @@ -245,14 +235,19 @@ enum class IrCmd : uint8_t STRING_LEN, // Allocate new table - // A: int (array element count) - // B: int (node element count) + // A: unsigned int (array element count) + // B: unsigned int (node element count) NEW_TABLE, // Duplicate a table // A: pointer (Table) DUP_TABLE, + // Insert an integer key into a table + // A: pointer (Table) + // B: int (key) + TABLE_SETNUM, + // Try to convert a double number into a table index (int) or jump if it's not an integer // A: double // B: block @@ -356,23 +351,16 @@ enum class IrCmd : uint8_t // Store TValue from stack slot into a function upvalue // A: UPn // B: Rn + // C: tag/undef (tag of the value that was written) SET_UPVALUE, - // Convert TValues into numbers for a numerical for loop - // A: Rn (start) - // B: Rn (end) - // C: Rn (step) - PREPARE_FORN, - // Guards and checks (these instructions are not block terminators even though they jump to fallback) // Guard against tag mismatch // A, B: tag // C: block/vmexit/undef - // D: bool (finish execution in VM on failure) // In final x64 lowering, A can also be Rn - // When undef is specified instead of a block, execution is aborted on check failure; if D is true, execution is continued in VM interpreter - // instead. + // When undef is specified instead of a block, execution is aborted on check failure CHECK_TAG, // Guard against a falsy tag+value @@ -418,6 +406,12 @@ enum class IrCmd : uint8_t // When undef is specified instead of a block, execution is aborted on check failure CHECK_NODE_NO_NEXT, + // Guard against table node with 'nil' value + // A: pointer (LuaNode) + // B: block/vmexit/undef + // When undef is specified instead of a block, execution is aborted on check failure + CHECK_NODE_VALUE, + // Special operations // Check interrupt handler @@ -464,6 +458,7 @@ enum class IrCmd : uint8_t // C: Rn (source start) // D: int (count or -1 to assign values up to stack top) // E: unsigned int (table index to start from) + // F: undef/unsigned int (target table known size) SETLIST, // Call specified function @@ -689,6 +684,10 @@ enum class IrOpKind : uint32_t VmExit, }; +// VmExit uses a special value to indicate that pcpos update should be skipped +// This is only used during type checking at function entry +constexpr uint32_t kVmExitEntryGuardPc = (1u << 28) - 1; + struct IrOp { IrOpKind kind : 4; @@ -834,6 +833,8 @@ struct IrBlock uint32_t finish = ~0u; uint32_t sortkey = ~0u; + uint32_t chainkey = 0; + uint32_t expectedNextBlock = ~0u; Label label; }; @@ -851,6 +852,8 @@ struct IrFunction std::vector constants; std::vector bcMapping; + uint32_t entryBlock = 0; + uint32_t entryLocation = 0; // For each instruction, an operand that can be used to recompute the value std::vector valueRestoreOps; @@ -993,23 +996,26 @@ struct IrFunction valueRestoreOps[instIdx] = location; } - IrOp findRestoreOp(uint32_t instIdx) const + IrOp findRestoreOp(uint32_t instIdx, bool limitToCurrentBlock) const { if (instIdx >= valueRestoreOps.size()) return {}; const IrBlock& block = blocks[validRestoreOpBlockIdx]; - // Values can only reference restore operands in the current block - if (instIdx < block.start || instIdx > block.finish) - return {}; + // When spilled, values can only reference restore operands in the current block + if (limitToCurrentBlock) + { + if (instIdx < block.start || instIdx > block.finish) + return {}; + } return valueRestoreOps[instIdx]; } - IrOp findRestoreOp(const IrInst& inst) const + IrOp findRestoreOp(const IrInst& inst, bool limitToCurrentBlock) const { - return findRestoreOp(getInstIndex(inst)); + return findRestoreOp(getInstIndex(inst), limitToCurrentBlock); } }; @@ -1037,5 +1043,11 @@ inline int vmUpvalueOp(IrOp op) return op.index; } +inline uint32_t vmExitOp(IrOp op) +{ + LUAU_ASSERT(op.kind == IrOpKind::VmExit); + return op.index; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrRegAllocX64.h b/CodeGen/include/Luau/IrRegAllocX64.h index 95930811..632499a4 100644 --- a/CodeGen/include/Luau/IrRegAllocX64.h +++ b/CodeGen/include/Luau/IrRegAllocX64.h @@ -12,6 +12,9 @@ namespace Luau { namespace CodeGen { + +struct LoweringStats; + namespace X64 { @@ -33,7 +36,7 @@ struct IrSpillX64 struct IrRegAllocX64 { - IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function); + IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function, LoweringStats* stats); RegisterX64 allocReg(SizeX64 size, uint32_t instIdx); RegisterX64 allocRegOrReuse(SizeX64 size, uint32_t instIdx, std::initializer_list oprefs); @@ -70,6 +73,7 @@ struct IrRegAllocX64 AssemblyBuilderX64& build; IrFunction& function; + LoweringStats* stats = nullptr; uint32_t currInstIdx = ~0u; @@ -77,6 +81,7 @@ struct IrRegAllocX64 std::array gprInstUsers; std::array freeXmmMap; std::array xmmInstUsers; + uint8_t usableXmmRegCount = 0; std::bitset<256> usedSpillSlots; unsigned maxUsedSlot = 0; diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index fe38cb90..50b3dc7c 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -94,9 +94,7 @@ inline bool isBlockTerminator(IrCmd cmd) case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_INT: - case IrCmd::JUMP_LT_INT: - case IrCmd::JUMP_GE_UINT: + case IrCmd::JUMP_CMP_INT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_SLOT_MATCH: @@ -128,6 +126,7 @@ inline bool isNonTerminatingJump(IrCmd cmd) case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_SLOT_MATCH: case IrCmd::CHECK_NODE_NO_NEXT: + case IrCmd::CHECK_NODE_VALUE: return true; default: break; @@ -145,7 +144,6 @@ inline bool hasResult(IrCmd cmd) case IrCmd::LOAD_DOUBLE: case IrCmd::LOAD_INT: case IrCmd::LOAD_TVALUE: - case IrCmd::LOAD_NODE_VALUE_TV: case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: case IrCmd::GET_SLOT_NODE_ADDR: @@ -157,6 +155,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: + case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: @@ -169,6 +168,7 @@ inline bool hasResult(IrCmd cmd) case IrCmd::NOT_ANY: case IrCmd::CMP_ANY: case IrCmd::TABLE_LEN: + case IrCmd::TABLE_SETNUM: case IrCmd::STRING_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: @@ -264,5 +264,14 @@ uint32_t getNativeContextOffset(int bfid); // Cleans up blocks that were created with no users void killUnusedBlocks(IrFunction& function); +// Get blocks in order that tries to maximize fallthrough between them during lowering +// We want to mostly preserve build order with fallbacks outlined +// But we also use hints from optimization passes that chain blocks together where there's only one out-in edge between them +std::vector getSortedBlockOrder(IrFunction& function); + +// Returns first non-dead block that comes after block at index 'i' in the sorted blocks array +// 'dummy' block is returned if the end of array was reached +IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/RegisterA64.h b/CodeGen/include/Luau/RegisterA64.h index d50369e3..beb34ca7 100644 --- a/CodeGen/include/Luau/RegisterA64.h +++ b/CodeGen/include/Luau/RegisterA64.h @@ -47,18 +47,6 @@ constexpr RegisterA64 castReg(KindA64 kind, RegisterA64 reg) return RegisterA64{kind, reg.index}; } -// This is equivalent to castReg(KindA64::x), but is separate because it implies different semantics -// Specifically, there are cases when it's useful to treat a wN register as an xN register *after* it has been assigned a value -// Since all A64 instructions that write to wN implicitly zero the top half, this works when we need zero extension semantics -// Crucially, this is *not* safe on an ABI boundary - an int parameter in wN register may have anything in its top half in certain cases -// However, as long as our codegen doesn't use 32-bit truncation by using castReg x=>w, we can safely rely on this. -constexpr RegisterA64 zextReg(RegisterA64 reg) -{ - LUAU_ASSERT(reg.kind == KindA64::w); - - return RegisterA64{KindA64::x, reg.index}; -} - constexpr RegisterA64 noreg{KindA64::none, 0}; constexpr RegisterA64 w0{KindA64::w, 0}; diff --git a/CodeGen/include/Luau/UnwindBuilder.h b/CodeGen/include/Luau/UnwindBuilder.h index 8a44629f..1ba377ba 100644 --- a/CodeGen/include/Luau/UnwindBuilder.h +++ b/CodeGen/include/Luau/UnwindBuilder.h @@ -5,6 +5,7 @@ #include "Luau/RegisterX64.h" #include +#include #include #include @@ -48,7 +49,8 @@ public: // mov rbp, rsp // push reg in the order specified in regs // sub rsp, stackSize - virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) = 0; + virtual void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) = 0; virtual size_t getSize() const = 0; virtual size_t getFunctionCount() const = 0; diff --git a/CodeGen/include/Luau/UnwindBuilderDwarf2.h b/CodeGen/include/Luau/UnwindBuilderDwarf2.h index 66749bfc..741aaed2 100644 --- a/CodeGen/include/Luau/UnwindBuilderDwarf2.h +++ b/CodeGen/include/Luau/UnwindBuilderDwarf2.h @@ -30,7 +30,8 @@ public: void finishInfo() override; void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; - void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) override; size_t getSize() const override; size_t getFunctionCount() const override; diff --git a/CodeGen/include/Luau/UnwindBuilderWin.h b/CodeGen/include/Luau/UnwindBuilderWin.h index 5afed693..3a7e1b5a 100644 --- a/CodeGen/include/Luau/UnwindBuilderWin.h +++ b/CodeGen/include/Luau/UnwindBuilderWin.h @@ -50,7 +50,8 @@ public: void finishInfo() override; void prologueA64(uint32_t prologueSize, uint32_t stackSize, std::initializer_list regs) override; - void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) override; + void prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) override; size_t getSize() const override; size_t getFunctionCount() const override; diff --git a/CodeGen/src/AssemblyBuilderA64.cpp b/CodeGen/src/AssemblyBuilderA64.cpp index c62d797a..15be6b88 100644 --- a/CodeGen/src/AssemblyBuilderA64.cpp +++ b/CodeGen/src/AssemblyBuilderA64.cpp @@ -5,6 +5,7 @@ #include "ByteUtils.h" #include +#include namespace Luau { @@ -104,7 +105,10 @@ void AssemblyBuilderA64::movk(RegisterA64 dst, uint16_t src, int shift) void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift) { - placeSR3("add", dst, src1, src2, 0b00'01011, shift); + if (src1.kind == KindA64::x && src2.kind == KindA64::w) + placeER("add", dst, src1, src2, 0b00'01011, shift); + else + placeSR3("add", dst, src1, src2, 0b00'01011, shift); } void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, uint16_t src2) @@ -114,7 +118,10 @@ void AssemblyBuilderA64::add(RegisterA64 dst, RegisterA64 src1, uint16_t src2) void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, int shift) { - placeSR3("sub", dst, src1, src2, 0b10'01011, shift); + if (src1.kind == KindA64::x && src2.kind == KindA64::w) + placeER("sub", dst, src1, src2, 0b10'01011, shift); + else + placeSR3("sub", dst, src1, src2, 0b10'01011, shift); } void AssemblyBuilderA64::sub(RegisterA64 dst, RegisterA64 src1, uint16_t src2) @@ -1074,6 +1081,22 @@ void AssemblyBuilderA64::placeBFM(const char* name, RegisterA64 dst, RegisterA64 commit(); } +void AssemblyBuilderA64::placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift) +{ + if (logText) + log(name, dst, src1, src2, shift); + + LUAU_ASSERT(dst.kind == KindA64::x && src1.kind == KindA64::x); + LUAU_ASSERT(src2.kind == KindA64::w); + LUAU_ASSERT(shift >= 0 && shift <= 4); + + uint32_t sf = (dst.kind == KindA64::x) ? 0x80000000 : 0; // could be useful in the future for byte->word extends + int option = 0b010; // UXTW + + place(dst.index | (src1.index << 5) | (shift << 10) | (option << 13) | (src2.index << 16) | (1 << 21) | (op << 24) | sf); + commit(); +} + void AssemblyBuilderA64::place(uint32_t word) { LUAU_ASSERT(codePos < codeEnd); @@ -1166,7 +1189,9 @@ void AssemblyBuilderA64::log(const char* opcode, RegisterA64 dst, RegisterA64 sr log(src1); text.append(","); log(src2); - if (shift > 0) + if (src1.kind == KindA64::x && src2.kind == KindA64::w) + logAppend(" UXTW #%d", shift); + else if (shift > 0) logAppend(" LSL #%d", shift); else if (shift < 0) logAppend(" LSR #%d", -shift); diff --git a/CodeGen/src/AssemblyBuilderX64.cpp b/CodeGen/src/AssemblyBuilderX64.cpp index 2a8bc92e..bd2568d2 100644 --- a/CodeGen/src/AssemblyBuilderX64.cpp +++ b/CodeGen/src/AssemblyBuilderX64.cpp @@ -5,7 +5,6 @@ #include #include -#include namespace Luau { diff --git a/CodeGen/src/CodeAllocator.cpp b/CodeGen/src/CodeAllocator.cpp index 880a3244..84e48a1b 100644 --- a/CodeGen/src/CodeAllocator.cpp +++ b/CodeGen/src/CodeAllocator.cpp @@ -33,13 +33,17 @@ static size_t alignToPageSize(size_t size) } #if defined(_WIN32) -static uint8_t* allocatePages(size_t size) +static uint8_t* allocatePagesImpl(size_t size) { - return (uint8_t*)VirtualAlloc(nullptr, alignToPageSize(size), MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); + LUAU_ASSERT(size == alignToPageSize(size)); + + return (uint8_t*)VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); } -static void freePages(uint8_t* mem, size_t size) +static void freePagesImpl(uint8_t* mem, size_t size) { + LUAU_ASSERT(size == alignToPageSize(size)); + if (VirtualFree(mem, 0, MEM_RELEASE) == 0) LUAU_ASSERT(!"failed to deallocate block memory"); } @@ -62,14 +66,24 @@ static void flushInstructionCache(uint8_t* mem, size_t size) #endif } #else -static uint8_t* allocatePages(size_t size) +static uint8_t* allocatePagesImpl(size_t size) { - return (uint8_t*)mmap(nullptr, alignToPageSize(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); + LUAU_ASSERT(size == alignToPageSize(size)); + +#ifdef __APPLE__ + void* result = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON | MAP_JIT, -1, 0); +#else + void* result = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); +#endif + + return (result == MAP_FAILED) ? nullptr : static_cast(result); } -static void freePages(uint8_t* mem, size_t size) +static void freePagesImpl(uint8_t* mem, size_t size) { - if (munmap(mem, alignToPageSize(size)) != 0) + LUAU_ASSERT(size == alignToPageSize(size)); + + if (munmap(mem, size) != 0) LUAU_ASSERT(!"Failed to deallocate block memory"); } @@ -94,8 +108,15 @@ namespace CodeGen { CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize) - : blockSize(blockSize) - , maxTotalSize(maxTotalSize) + : CodeAllocator(blockSize, maxTotalSize, nullptr, nullptr) +{ +} + +CodeAllocator::CodeAllocator(size_t blockSize, size_t maxTotalSize, AllocationCallback* allocationCallback, void* allocationCallbackContext) + : blockSize{blockSize} + , maxTotalSize{maxTotalSize} + , allocationCallback{allocationCallback} + , allocationCallbackContext{allocationCallbackContext} { LUAU_ASSERT(blockSize > kMaxReservedDataSize); LUAU_ASSERT(maxTotalSize >= blockSize); @@ -207,5 +228,29 @@ bool CodeAllocator::allocateNewBlock(size_t& unwindInfoSize) return true; } +uint8_t* CodeAllocator::allocatePages(size_t size) const +{ + const size_t pageAlignedSize = alignToPageSize(size); + + uint8_t* const mem = allocatePagesImpl(pageAlignedSize); + if (mem == nullptr) + return nullptr; + + if (allocationCallback) + allocationCallback(allocationCallbackContext, nullptr, 0, mem, pageAlignedSize); + + return mem; +} + +void CodeAllocator::freePages(uint8_t* mem, size_t size) const +{ + const size_t pageAlignedSize = alignToPageSize(size); + + if (allocationCallback) + allocationCallback(allocationCallbackContext, mem, pageAlignedSize, nullptr, 0); + + freePagesImpl(mem, pageAlignedSize); +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 602130f1..a25e1046 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -65,7 +65,7 @@ static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir) int sizecode = proto->sizecode; uint32_t* instOffsets = new uint32_t[sizecode]; - uint32_t instTarget = ir.function.bcMapping[0].asmLocation; + uint32_t instTarget = ir.function.entryLocation; for (int i = 0; i < sizecode; i++) { @@ -74,6 +74,9 @@ static NativeProto createNativeProto(Proto* proto, const IrBuilder& ir) instOffsets[i] = ir.function.bcMapping[i].asmLocation - instTarget; } + // Set first instruction offset to 0 so that entering this function still executes any generated entry code. + instOffsets[0] = 0; + // entry target will be relocated when assembly is finalized return {proto, instOffsets, instTarget}; } @@ -103,7 +106,7 @@ static std::optional createNativeFunction(AssemblyBuilder& build, M IrBuilder ir; ir.buildFunctionIr(proto); - if (!lowerFunction(ir, build, helpers, proto, {})) + if (!lowerFunction(ir, build, helpers, proto, {}, /* stats */ nullptr)) return std::nullopt; return createNativeProto(proto, ir); @@ -159,9 +162,6 @@ unsigned int getCpuFeaturesA64() bool isSupported() { - if (!LUA_CUSTOM_EXECUTION) - return false; - if (LUA_EXTRA_SIZE != 1) return false; @@ -202,11 +202,11 @@ bool isSupported() #endif } -void create(lua_State* L) +void create(lua_State* L, AllocationCallback* allocationCallback, void* allocationCallbackContext) { LUAU_ASSERT(isSupported()); - std::unique_ptr data = std::make_unique(); + std::unique_ptr data = std::make_unique(allocationCallback, allocationCallbackContext); #if defined(_WIN32) data->unwindBuilder = std::make_unique(); @@ -239,23 +239,38 @@ void create(lua_State* L) ecb->enter = onEnter; } -void compile(lua_State* L, int idx, unsigned int flags) +void create(lua_State* L) +{ + create(L, nullptr, nullptr); +} + +CodeGenCompilationResult compile(lua_State* L, int idx, unsigned int flags, CompilationStats* stats) { LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); + Proto* root = clvalue(func)->l.p; + if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) + return CodeGenCompilationResult::NothingToCompile; + // If initialization has failed, do not compile any functions NativeState* data = getNativeState(L); if (!data) - return; - - Proto* root = clvalue(func)->l.p; - if ((flags & CodeGen_OnlyNativeModules) != 0 && (root->flags & LPF_NATIVE_MODULE) == 0) - return; + return CodeGenCompilationResult::CodeGenNotInitialized; std::vector protos; gatherFunctions(protos, root); + // Skip protos that have been compiled during previous invocations of CodeGen::compile + protos.erase(std::remove_if(protos.begin(), protos.end(), + [](Proto* p) { + return p == nullptr || p->execdata != nullptr; + }), + protos.end()); + + if (protos.empty()) + return CodeGenCompilationResult::NothingToCompile; + #if defined(__aarch64__) static unsigned int cpuFeatures = getCpuFeaturesA64(); A64::AssemblyBuilderA64 build(/* logText= */ false, cpuFeatures); @@ -273,11 +288,9 @@ void compile(lua_State* L, int idx, unsigned int flags) std::vector results; results.reserve(protos.size()); - // Skip protos that have been compiled during previous invocations of CodeGen::compile for (Proto* p : protos) - if (p && p->execdata == nullptr) - if (std::optional np = createNativeFunction(build, helpers, p)) - results.push_back(*np); + if (std::optional np = createNativeFunction(build, helpers, p)) + results.push_back(*np); // Very large modules might result in overflowing a jump offset; in this case we currently abandon the entire module if (!build.finalize()) @@ -285,12 +298,12 @@ void compile(lua_State* L, int idx, unsigned int flags) for (NativeProto result : results) destroyExecData(result.execdata); - return; + return CodeGenCompilationResult::CodeGenFailed; } // If no functions were assembled, we don't need to allocate/copy executable pages for helpers if (results.empty()) - return; + return CodeGenCompilationResult::CodeGenFailed; uint8_t* nativeData = nullptr; size_t sizeNativeData = 0; @@ -301,7 +314,7 @@ void compile(lua_State* L, int idx, unsigned int flags) for (NativeProto result : results) destroyExecData(result.execdata); - return; + return CodeGenCompilationResult::AllocationFailed; } if (gPerfLogFn && results.size() > 0) @@ -318,13 +331,30 @@ void compile(lua_State* L, int idx, unsigned int flags) } } - for (NativeProto result : results) + for (const NativeProto& result : results) { // the memory is now managed by VM and will be freed via onDestroyFunction result.p->execdata = result.execdata; result.p->exectarget = uintptr_t(codeStart) + result.exectarget; result.p->codeentry = &kCodeEntryInsn; } + + if (stats != nullptr) + { + for (const NativeProto& result : results) + { + stats->bytecodeSizeBytes += result.p->sizecode * sizeof(Instruction); + + // Account for the native -> bytecode instruction offsets mapping: + stats->nativeMetadataSizeBytes += result.p->sizecode * sizeof(uint32_t); + } + + stats->functionsCompiled += uint32_t(results.size()); + stats->nativeCodeSizeBytes += build.code.size(); + stats->nativeDataSizeBytes += build.data.size(); + } + + return CodeGenCompilationResult::Success; } void setPerfLog(void* context, PerfLogFn logFn) diff --git a/CodeGen/src/CodeGenA64.cpp b/CodeGen/src/CodeGenA64.cpp index 37af59aa..2e268d26 100644 --- a/CodeGen/src/CodeGenA64.cpp +++ b/CodeGen/src/CodeGenA64.cpp @@ -24,15 +24,6 @@ struct EntryLocations Label epilogueStart; }; -static void emitClearNativeFlag(AssemblyBuilderA64& build) -{ - build.ldr(x0, mem(rState, offsetof(lua_State, ci))); - build.ldr(w1, mem(x0, offsetof(CallInfo, flags))); - build.mov(w2, ~LUA_CALLINFO_NATIVE); - build.and_(w1, w1, w2); - build.str(w1, mem(x0, offsetof(CallInfo, flags))); -} - static void emitExit(AssemblyBuilderA64& build, bool continueInVm) { build.mov(x0, continueInVm); @@ -40,14 +31,21 @@ static void emitExit(AssemblyBuilderA64& build, bool continueInVm) build.br(x1); } -static void emitUpdatePcAndContinueInVm(AssemblyBuilderA64& build) +static void emitUpdatePcForExit(AssemblyBuilderA64& build) { // x0 = pcpos * sizeof(Instruction) build.add(x0, rCode, x0); build.ldr(x1, mem(rState, offsetof(lua_State, ci))); build.str(x0, mem(x1, offsetof(CallInfo, savedpc))); +} - emitExit(build, /* continueInVm */ true); +static void emitClearNativeFlag(AssemblyBuilderA64& build) +{ + build.ldr(x0, mem(rState, offsetof(lua_State, ci))); + build.ldr(w1, mem(x0, offsetof(CallInfo, flags))); + build.mov(w2, ~LUA_CALLINFO_NATIVE); + build.and_(w1, w1, w2); + build.str(w1, mem(x0, offsetof(CallInfo, flags))); } static void emitInterrupt(AssemblyBuilderA64& build) @@ -227,6 +225,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde build.stp(x19, x20, mem(sp, 16)); build.stp(x21, x22, mem(sp, 32)); build.stp(x23, x24, mem(sp, 48)); + build.str(x25, mem(sp, 64)); build.mov(x29, sp); // this is only necessary if we maintain frame pointers, which we do in the JIT for now @@ -237,6 +236,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde // Setup native execution environment build.mov(rState, x0); build.mov(rNativeContext, x3); + build.ldr(rGlobalState, mem(x0, offsetof(lua_State, global))); build.ldr(rBase, mem(x0, offsetof(lua_State, base))); // L->base @@ -254,6 +254,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde locations.epilogueStart = build.setLabel(); // Cleanup and exit + build.ldr(x25, mem(sp, 64)); build.ldp(x23, x24, mem(sp, 48)); build.ldp(x21, x22, mem(sp, 32)); build.ldp(x19, x20, mem(sp, 16)); @@ -264,7 +265,7 @@ static EntryLocations buildEntryFunction(AssemblyBuilderA64& build, UnwindBuilde // Our entry function is special, it spans the whole remaining code area unwind.startFunction(); - unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24}); + unwind.prologueA64(prologueSize, kStackSize, {x29, x30, x19, x20, x21, x22, x23, x24, x25}); unwind.finishFunction(build.getLabelOffset(locations.start), kFullBlockFuncton); return locations; @@ -305,6 +306,11 @@ bool initHeaderFunctions(NativeState& data) void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) { + if (build.logText) + build.logAppend("; updatePcAndContinueInVm\n"); + build.setLabel(helpers.updatePcAndContinueInVm); + emitUpdatePcForExit(build); + if (build.logText) build.logAppend("; exitContinueVmClearNativeFlag\n"); build.setLabel(helpers.exitContinueVmClearNativeFlag); @@ -320,11 +326,6 @@ void assembleHelpers(AssemblyBuilderA64& build, ModuleHelpers& helpers) build.setLabel(helpers.exitNoContinueVm); emitExit(build, /* continueInVm */ false); - if (build.logText) - build.logAppend("; updatePcAndContinueInVm\n"); - build.setLabel(helpers.updatePcAndContinueInVm); - emitUpdatePcAndContinueInVm(build); - if (build.logText) build.logAppend("; reentry\n"); build.setLabel(helpers.reentry); diff --git a/CodeGen/src/CodeGenAssembly.cpp b/CodeGen/src/CodeGenAssembly.cpp index fed5ddd3..4fd24cd9 100644 --- a/CodeGen/src/CodeGenAssembly.cpp +++ b/CodeGen/src/CodeGenAssembly.cpp @@ -43,7 +43,7 @@ static void logFunctionHeader(AssemblyBuilder& build, Proto* proto) } template -static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options) +static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, AssemblyOptions options, LoweringStats* stats) { std::vector protos; gatherFunctions(protos, clvalue(func)->l.p); @@ -66,7 +66,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A if (options.includeAssembly || options.includeIr) logFunctionHeader(build, p); - if (!lowerFunction(ir, build, helpers, p, options)) + if (!lowerFunction(ir, build, helpers, p, options, stats)) { if (build.logText) build.logAppend("; skipping (can't lower)\n"); @@ -90,7 +90,7 @@ static std::string getAssemblyImpl(AssemblyBuilder& build, const TValue* func, A unsigned int getCpuFeaturesA64(); #endif -std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) +std::string getAssembly(lua_State* L, int idx, AssemblyOptions options, LoweringStats* stats) { LUAU_ASSERT(lua_isLfunction(L, idx)); const TValue* func = luaA_toobject(L, idx); @@ -106,35 +106,35 @@ std::string getAssembly(lua_State* L, int idx, AssemblyOptions options) X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly); #endif - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::A64: { A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ A64::Feature_JSCVT); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::A64_NoFeatures: { A64::AssemblyBuilderA64 build(/* logText= */ options.includeAssembly, /* features= */ 0); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::X64_Windows: { X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::Windows); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } case AssemblyOptions::X64_SystemV: { X64::AssemblyBuilderX64 build(/* logText= */ options.includeAssembly, X64::ABIX64::SystemV); - return getAssemblyImpl(build, func, options); + return getAssemblyImpl(build, func, options, stats); } default: diff --git a/CodeGen/src/CodeGenLower.h b/CodeGen/src/CodeGenLower.h index a7352bce..9b729a92 100644 --- a/CodeGen/src/CodeGenLower.h +++ b/CodeGen/src/CodeGenLower.h @@ -44,38 +44,10 @@ inline void gatherFunctions(std::vector& results, Proto* proto) gatherFunctions(results, proto->p[i]); } -inline IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i) -{ - for (size_t j = i + 1; j < sortedBlocks.size(); ++j) - { - IrBlock& block = function.blocks[sortedBlocks[j]]; - if (block.kind != IrBlockKind::Dead) - return block; - } - - return dummy; -} - template inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& function, int bytecodeid, AssemblyOptions options) { - // While we will need a better block ordering in the future, right now we want to mostly preserve build order with fallbacks outlined - std::vector sortedBlocks; - sortedBlocks.reserve(function.blocks.size()); - for (uint32_t i = 0; i < function.blocks.size(); i++) - sortedBlocks.push_back(i); - - std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { - const IrBlock& a = function.blocks[idxA]; - const IrBlock& b = function.blocks[idxB]; - - // Place fallback blocks at the end - if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) - return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); - - // Try to order by instruction order - return a.sortkey < b.sortkey; - }); + std::vector sortedBlocks = getSortedBlockOrder(function); // For each IR instruction that begins a bytecode instruction, which bytecode instruction is it? std::vector bcLocations(function.instructions.size() + 1, ~0u); @@ -100,6 +72,9 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& IrBlock dummy; dummy.start = ~0u; + // Make sure entry block is first + LUAU_ASSERT(sortedBlocks[0] == 0); + for (size_t i = 0; i < sortedBlocks.size(); ++i) { uint32_t blockIndex = sortedBlocks[i]; @@ -130,8 +105,18 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& build.setLabel(block.label); + if (blockIndex == function.entryBlock) + { + function.entryLocation = build.getLabelOffset(block.label); + } + IrBlock& nextBlock = getNextBlock(function, sortedBlocks, dummy, i); + // Optimizations often propagate information between blocks + // To make sure the register and spill state is correct when blocks are lowered, we check that sorted block order matches the expected one + if (block.expectedNextBlock != ~0u) + LUAU_ASSERT(function.getBlockIndex(nextBlock) == block.expectedNextBlock); + for (uint32_t index = block.start; index <= block.finish; index++) { LUAU_ASSERT(index < function.instructions.size()); @@ -189,7 +174,7 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& } } - lowering.finishBlock(); + lowering.finishBlock(block, nextBlock); if (options.includeIr) build.logAppend("#\n"); @@ -214,24 +199,26 @@ inline bool lowerImpl(AssemblyBuilder& build, IrLowering& lowering, IrFunction& return true; } -inline bool lowerIr(X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +inline bool lowerIr( + X64::AssemblyBuilderX64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { optimizeMemoryOperandsX64(ir.function); - X64::IrLoweringX64 lowering(build, helpers, ir.function); + X64::IrLoweringX64 lowering(build, helpers, ir.function, stats); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } -inline bool lowerIr(A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +inline bool lowerIr( + A64::AssemblyBuilderA64& build, IrBuilder& ir, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { - A64::IrLoweringA64 lowering(build, helpers, ir.function); + A64::IrLoweringA64 lowering(build, helpers, ir.function, stats); return lowerImpl(build, lowering, ir.function, proto->bytecodeid, options); } template -inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options) +inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& helpers, Proto* proto, AssemblyOptions options, LoweringStats* stats) { killUnusedBlocks(ir.function); @@ -247,7 +234,7 @@ inline bool lowerFunction(IrBuilder& ir, AssemblyBuilder& build, ModuleHelpers& createLinearBlocks(ir, useValueNumbering); } - return lowerIr(build, ir, helpers, proto, options); + return lowerIr(build, ir, helpers, proto, options, stats); } } // namespace CodeGen diff --git a/CodeGen/src/CodeGenX64.cpp b/CodeGen/src/CodeGenX64.cpp index 1e62a4d4..a8cf2e73 100644 --- a/CodeGen/src/CodeGenX64.cpp +++ b/CodeGen/src/CodeGenX64.cpp @@ -16,10 +16,24 @@ * | rdx home space | (unused) * | rcx home space | (unused) * | return address | - * | ... saved non-volatile registers ... <-- rsp + kStackSize + kLocalsSize - * | unused | for 16 byte alignment of the stack + * | ... saved non-volatile registers ... <-- rsp + kStackSizeFull + * | alignment | + * | xmm9 non-vol | + * | xmm9 cont. | + * | xmm8 non-vol | + * | xmm8 cont. | + * | xmm7 non-vol | + * | xmm7 cont. | + * | xmm6 non-vol | + * | xmm6 cont. | + * | spill slot 5 | + * | spill slot 4 | + * | spill slot 3 | + * | spill slot 2 | + * | spill slot 1 | <-- rsp + kStackOffsetToSpillSlots + * | sTemporarySlot | * | sCode | - * | sClosure | <-- rsp + kStackSize + * | sClosure | <-- rsp + kStackOffsetToLocals * | argument 6 | <-- rsp + 40 * | argument 5 | <-- rsp + 32 * | r9 home space | @@ -81,24 +95,43 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde build.push(rdi); build.push(rsi); - // On Windows, rbp is available as a general-purpose non-volatile register; we currently don't use it, but we need to push an even number - // of registers for stack alignment... + // On Windows, rbp is available as a general-purpose non-volatile register and this might be freed up build.push(rbp); - - // TODO: once we start using non-volatile SIMD registers on Windows, we will save those here } - // Allocate stack space (reg home area + local data) - build.sub(rsp, kStackSize + kLocalsSize); + // Allocate stack space + uint8_t usableXmmRegCount = getXmmRegisterCount(build.abi); + unsigned xmmStorageSize = getNonVolXmmStorageSize(build.abi, usableXmmRegCount); + unsigned fullStackSize = getFullStackSize(build.abi, usableXmmRegCount); + + build.sub(rsp, fullStackSize); + + OperandX64 xmmStorageOffset = rsp + (fullStackSize - (kStackAlign + xmmStorageSize)); + + // On Windows, we have to save non-volatile xmm registers + std::vector savedXmmRegs; + + if (build.abi == ABIX64::Windows) + { + if (usableXmmRegCount > kWindowsFirstNonVolXmmReg) + savedXmmRegs.reserve(usableXmmRegCount - kWindowsFirstNonVolXmmReg); + + for (uint8_t i = kWindowsFirstNonVolXmmReg, offset = 0; i < usableXmmRegCount; i++, offset += 16) + { + RegisterX64 xmmReg = RegisterX64{SizeX64::xmmword, i}; + build.vmovaps(xmmword[xmmStorageOffset + offset], xmmReg); + savedXmmRegs.push_back(xmmReg); + } + } locations.prologueEnd = build.setLabel(); uint32_t prologueSize = build.getLabelOffset(locations.prologueEnd) - build.getLabelOffset(locations.start); if (build.abi == ABIX64::SystemV) - unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}); + unwind.prologueX64(prologueSize, fullStackSize, /* setupFrame= */ true, {rbx, r12, r13, r14, r15}, {}); else if (build.abi == ABIX64::Windows) - unwind.prologueX64(prologueSize, kStackSize + kLocalsSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}); + unwind.prologueX64(prologueSize, fullStackSize, /* setupFrame= */ false, {rbx, r12, r13, r14, r15, rdi, rsi, rbp}, savedXmmRegs); // Setup native execution environment build.mov(rState, rArg1); @@ -118,8 +151,15 @@ static EntryLocations buildEntryFunction(AssemblyBuilderX64& build, UnwindBuilde // Even though we jumped away, we will return here in the end locations.epilogueStart = build.setLabel(); - // Cleanup and exit - build.add(rsp, kStackSize + kLocalsSize); + // Epilogue and exit + if (build.abi == ABIX64::Windows) + { + // xmm registers are restored before the official epilogue that has to start with 'add rsp/lea rsp' + for (uint8_t i = kWindowsFirstNonVolXmmReg, offset = 0; i < usableXmmRegCount; i++, offset += 16) + build.vmovaps(RegisterX64{SizeX64::xmmword, i}, xmmword[xmmStorageOffset + offset]); + } + + build.add(rsp, fullStackSize); if (build.abi == ABIX64::Windows) { @@ -180,6 +220,11 @@ bool initHeaderFunctions(NativeState& data) void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) { + if (build.logText) + build.logAppend("; updatePcAndContinueInVm\n"); + build.setLabel(helpers.updatePcAndContinueInVm); + emitUpdatePcForExit(build); + if (build.logText) build.logAppend("; exitContinueVmClearNativeFlag\n"); build.setLabel(helpers.exitContinueVmClearNativeFlag); @@ -195,11 +240,6 @@ void assembleHelpers(X64::AssemblyBuilderX64& build, ModuleHelpers& helpers) build.setLabel(helpers.exitNoContinueVm); emitExit(build, /* continueInVm */ false); - if (build.logText) - build.logAppend("; updatePcAndContinueInVm\n"); - build.setLabel(helpers.updatePcAndContinueInVm); - emitUpdatePcAndContinueInVm(build); - if (build.logText) build.logAppend("; continueCallInVm\n"); build.setLabel(helpers.continueCallInVm); diff --git a/CodeGen/src/EmitBuiltinsX64.cpp b/CodeGen/src/EmitBuiltinsX64.cpp index efc480e0..87e4e795 100644 --- a/CodeGen/src/EmitBuiltinsX64.cpp +++ b/CodeGen/src/EmitBuiltinsX64.cpp @@ -1,8 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "EmitBuiltinsX64.h" -#include "Luau/AssemblyBuilderX64.h" #include "Luau/Bytecode.h" + +#include "Luau/AssemblyBuilderX64.h" #include "Luau/IrCallWrapperX64.h" #include "Luau/IrRegAllocX64.h" diff --git a/CodeGen/src/EmitCommon.h b/CodeGen/src/EmitCommon.h index 214cfd6d..8ac746ee 100644 --- a/CodeGen/src/EmitCommon.h +++ b/CodeGen/src/EmitCommon.h @@ -25,7 +25,7 @@ struct ModuleHelpers Label exitContinueVm; Label exitNoContinueVm; Label exitContinueVmClearNativeFlag; - Label updatePcAndContinueInVm; + Label updatePcAndContinueInVm; // no reentry Label return_; Label interrupt; diff --git a/CodeGen/src/EmitCommonA64.h b/CodeGen/src/EmitCommonA64.h index 9e89b1c0..894570d9 100644 --- a/CodeGen/src/EmitCommonA64.h +++ b/CodeGen/src/EmitCommonA64.h @@ -31,23 +31,24 @@ namespace A64 // 1. Constant registers (only loaded during codegen entry) constexpr RegisterA64 rState = x19; // lua_State* L constexpr RegisterA64 rNativeContext = x20; // NativeContext* context +constexpr RegisterA64 rGlobalState = x21; // global_State* L->global // 2. Frame registers (reloaded when call frame changes; rBase is also reloaded after all calls that may reallocate stack) -constexpr RegisterA64 rConstants = x21; // TValue* k -constexpr RegisterA64 rClosure = x22; // Closure* cl -constexpr RegisterA64 rCode = x23; // Instruction* code -constexpr RegisterA64 rBase = x24; // StkId base +constexpr RegisterA64 rConstants = x22; // TValue* k +constexpr RegisterA64 rClosure = x23; // Closure* cl +constexpr RegisterA64 rCode = x24; // Instruction* code +constexpr RegisterA64 rBase = x25; // StkId base // Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point // See CodeGenA64.cpp for layout -constexpr unsigned kStashSlots = 8; // stashed non-volatile registers +constexpr unsigned kStashSlots = 9; // stashed non-volatile registers +constexpr unsigned kTempSlots = 1; // 8 bytes of temporary space, such luxury! constexpr unsigned kSpillSlots = 22; // slots for spilling temporary registers -constexpr unsigned kTempSlots = 2; // 16 bytes of temporary space, such luxury! -constexpr unsigned kStackSize = (kStashSlots + kSpillSlots + kTempSlots) * 8; +constexpr unsigned kStackSize = (kStashSlots + kTempSlots + kSpillSlots) * 8; -constexpr AddressA64 sSpillArea = mem(sp, kStashSlots * 8); -constexpr AddressA64 sTemporary = mem(sp, (kStashSlots + kSpillSlots) * 8); +constexpr AddressA64 sSpillArea = mem(sp, (kStashSlots + kTempSlots) * 8); +constexpr AddressA64 sTemporary = mem(sp, kStashSlots * 8); inline void emitUpdateBase(AssemblyBuilderA64& build) { diff --git a/CodeGen/src/EmitCommonX64.cpp b/CodeGen/src/EmitCommonX64.cpp index e6fae4cc..97749fbe 100644 --- a/CodeGen/src/EmitCommonX64.cpp +++ b/CodeGen/src/EmitCommonX64.cpp @@ -12,6 +12,8 @@ #include "lgc.h" #include "lstate.h" +#include + namespace Luau { namespace CodeGen @@ -22,10 +24,15 @@ namespace X64 void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label) { // Refresher on comi/ucomi EFLAGS: + // all zero: greater // CF only: less // ZF only: equal // PF+CF+ZF: unordered (NaN) + // To avoid the lack of conditional jumps that check for "greater" conditions in IEEE 754 compliant way, we use "less" forms to emulate these + if (cond == IrCondition::Greater || cond == IrCondition::GreaterEqual || cond == IrCondition::NotGreater || cond == IrCondition::NotGreaterEqual) + std::swap(lhs, rhs); + if (rhs.cat == CategoryX64::reg) { build.vucomisd(rhs, lhs); @@ -41,18 +48,22 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, switch (cond) { case IrCondition::NotLessEqual: + case IrCondition::NotGreaterEqual: // (b < a) is the same as !(a <= b). jnae checks CF=1 which means < or NaN build.jcc(ConditionX64::NotAboveEqual, label); break; case IrCondition::LessEqual: + case IrCondition::GreaterEqual: // (b >= a) is the same as (a <= b). jae checks CF=0 which means >= and not NaN build.jcc(ConditionX64::AboveEqual, label); break; case IrCondition::NotLess: + case IrCondition::NotGreater: // (b <= a) is the same as !(a < b). jna checks CF=1 or ZF=1 which means <= or NaN build.jcc(ConditionX64::NotAbove, label); break; case IrCondition::Less: + case IrCondition::Greater: // (b > a) is the same as (a < b). ja checks CF=0 and ZF=0 which means > and not NaN build.jcc(ConditionX64::Above, label); break; @@ -66,6 +77,44 @@ void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, } } +ConditionX64 getConditionInt(IrCondition cond) +{ + switch (cond) + { + case IrCondition::Equal: + return ConditionX64::Equal; + case IrCondition::NotEqual: + return ConditionX64::NotEqual; + case IrCondition::Less: + return ConditionX64::Less; + case IrCondition::NotLess: + return ConditionX64::NotLess; + case IrCondition::LessEqual: + return ConditionX64::LessEqual; + case IrCondition::NotLessEqual: + return ConditionX64::NotLessEqual; + case IrCondition::Greater: + return ConditionX64::Greater; + case IrCondition::NotGreater: + return ConditionX64::NotGreater; + case IrCondition::GreaterEqual: + return ConditionX64::GreaterEqual; + case IrCondition::NotGreaterEqual: + return ConditionX64::NotGreaterEqual; + case IrCondition::UnsignedLess: + return ConditionX64::Below; + case IrCondition::UnsignedLessEqual: + return ConditionX64::BelowEqual; + case IrCondition::UnsignedGreater: + return ConditionX64::Above; + case IrCondition::UnsignedGreaterEqual: + return ConditionX64::AboveEqual; + default: + LUAU_ASSERT(!"Unsupported condition"); + return ConditionX64::Zero; + } +} + void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos) { LUAU_ASSERT(tmp != node); @@ -123,16 +172,6 @@ void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, in emitUpdateBase(build); } -void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init) -{ - IrCallWrapperX64 callWrap(regs, build); - callWrap.addArgument(SizeX64::qword, rState); - callWrap.addArgument(SizeX64::qword, luauRegAddress(limit)); - callWrap.addArgument(SizeX64::qword, luauRegAddress(step)); - callWrap.addArgument(SizeX64::qword, luauRegAddress(init)); - callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaV_prepareFORN)]); -} - void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra) { IrCallWrapperX64 callWrap(regs, build); @@ -157,13 +196,14 @@ void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, Operan emitUpdateBase(build); } -void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip) +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, IrOp ra, int ratag, Label& skip) { // Barrier should've been optimized away if we know that it's not collectable, checking for correctness if (ratag == -1 || !isGCO(ratag)) { // iscollectable(ra) - build.cmp(luauRegTag(ra), LUA_TSTRING); + OperandX64 tag = (ra.kind == IrOpKind::VmReg) ? luauRegTag(vmRegOp(ra)) : luauConstantTag(vmConstOp(ra)); + build.cmp(tag, LUA_TSTRING); build.jcc(ConditionX64::Less, skip); } @@ -172,12 +212,14 @@ void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, Re build.jcc(ConditionX64::Zero, skip); // iswhite(gcvalue(ra)) - build.mov(tmp, luauRegValue(ra)); + OperandX64 value = (ra.kind == IrOpKind::VmReg) ? luauRegValue(vmRegOp(ra)) : luauConstantValue(vmConstOp(ra)); + build.mov(tmp, value); build.test(byte[tmp + offsetof(GCheader, marked)], bit2mask(WHITE0BIT, WHITE1BIT)); build.jcc(ConditionX64::Zero, skip); } -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, int ratag) + +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, IrOp ra, int ratag) { Label skip; @@ -328,14 +370,12 @@ void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, in emitUpdateBase(build); } -void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build) +void emitUpdatePcForExit(AssemblyBuilderX64& build) { // edx = pcpos * sizeof(Instruction) build.add(rdx, sCode); build.mov(rax, qword[rState + offsetof(lua_State, ci)]); build.mov(qword[rax + offsetof(CallInfo, savedpc)], rdx); - - emitExit(build, /* continueInVm */ true); } void emitContinueCallInVm(AssemblyBuilderX64& build) diff --git a/CodeGen/src/EmitCommonX64.h b/CodeGen/src/EmitCommonX64.h index d8c68da4..3d9a59ff 100644 --- a/CodeGen/src/EmitCommonX64.h +++ b/CodeGen/src/EmitCommonX64.h @@ -42,16 +42,55 @@ constexpr RegisterX64 rBase = r14; // StkId base constexpr RegisterX64 rNativeContext = r13; // NativeContext* context constexpr RegisterX64 rConstants = r12; // TValue* k -// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point -// See CodeGenX64.cpp for layout -constexpr unsigned kStackSize = 32 + 16; // 4 home locations for registers, 16 bytes for additional function call arguments -constexpr unsigned kSpillSlots = 4; // locations for register allocator to spill data into -constexpr unsigned kLocalsSize = 24 + 8 * kSpillSlots; // 3 extra slots for our custom locals (also aligns the stack to 16 byte boundary) +constexpr unsigned kExtraLocals = 3; // Number of 8 byte slots available for specialized local variables specified below +constexpr unsigned kSpillSlots = 5; // Number of 8 byte slots available for register allocator to spill data into +static_assert((kExtraLocals + kSpillSlots) * 8 % 16 == 0, "locals have to preserve 16 byte alignment"); -constexpr OperandX64 sClosure = qword[rsp + kStackSize + 0]; // Closure* cl -constexpr OperandX64 sCode = qword[rsp + kStackSize + 8]; // Instruction* code -constexpr OperandX64 sTemporarySlot = addr[rsp + kStackSize + 16]; -constexpr OperandX64 sSpillArea = addr[rsp + kStackSize + 24]; +constexpr uint8_t kWindowsFirstNonVolXmmReg = 6; + +constexpr uint8_t kWindowsUsableXmmRegs = 10; // Some xmm regs are non-volatile, we have to balance how many we want to use/preserve +constexpr uint8_t kSystemVUsableXmmRegs = 16; // All xmm regs are volatile + +inline uint8_t getXmmRegisterCount(ABIX64 abi) +{ + return abi == ABIX64::SystemV ? kSystemVUsableXmmRegs : kWindowsUsableXmmRegs; +} + +// Native code is as stackless as the interpreter, so we can place some data on the stack once and have it accessible at any point +// Stack is separated into sections for different data. See CodeGenX64.cpp for layout overview +constexpr unsigned kStackAlign = 8; // Bytes we need to align the stack for non-vol xmm register storage +constexpr unsigned kStackLocalStorage = 8 * kExtraLocals; +constexpr unsigned kStackSpillStorage = 8 * kSpillSlots; +constexpr unsigned kStackExtraArgumentStorage = 2 * 8; // Bytes for 5th and 6th function call arguments used under Windows ABI +constexpr unsigned kStackRegHomeStorage = 4 * 8; // Register 'home' locations that can be used by callees under Windows ABI + +inline unsigned getNonVolXmmStorageSize(ABIX64 abi, uint8_t xmmRegCount) +{ + if (abi == ABIX64::SystemV) + return 0; + + // First 6 are volatile + if (xmmRegCount <= kWindowsFirstNonVolXmmReg) + return 0; + + LUAU_ASSERT(xmmRegCount <= 16); + return (xmmRegCount - kWindowsFirstNonVolXmmReg) * 16; +} + +// Useful offsets to specific parts +constexpr unsigned kStackOffsetToLocals = kStackExtraArgumentStorage + kStackRegHomeStorage; +constexpr unsigned kStackOffsetToSpillSlots = kStackOffsetToLocals + kStackLocalStorage; + +inline unsigned getFullStackSize(ABIX64 abi, uint8_t xmmRegCount) +{ + return kStackOffsetToSpillSlots + kStackSpillStorage + getNonVolXmmStorageSize(abi, xmmRegCount) + kStackAlign; +} + +constexpr OperandX64 sClosure = qword[rsp + kStackOffsetToLocals + 0]; // Closure* cl +constexpr OperandX64 sCode = qword[rsp + kStackOffsetToLocals + 8]; // Instruction* code +constexpr OperandX64 sTemporarySlot = addr[rsp + kStackOffsetToLocals + 16]; + +constexpr OperandX64 sSpillArea = addr[rsp + kStackOffsetToSpillSlots]; inline OperandX64 luauReg(int ri) { @@ -114,11 +153,6 @@ inline OperandX64 luauNodeKeyTag(RegisterX64 node) return dword[node + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]; } -inline OperandX64 luauNodeValue(RegisterX64 node) -{ - return xmmword[node + offsetof(LuaNode, val)]; -} - inline void setLuauReg(AssemblyBuilderX64& build, RegisterX64 tmp, int ri, OperandX64 op) { LUAU_ASSERT(op.cat == CategoryX64::mem); @@ -161,16 +195,17 @@ inline void jumpIfTruthy(AssemblyBuilderX64& build, int ri, Label& target, Label void jumpOnNumberCmp(AssemblyBuilderX64& build, RegisterX64 tmp, OperandX64 lhs, OperandX64 rhs, IrCondition cond, Label& label); +ConditionX64 getConditionInt(IrCondition cond); + void getTableNodeAtCachedSlot(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 node, RegisterX64 table, int pcpos); void convertNumberToIndexOrJump(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 numd, RegisterX64 numi, Label& label); void callArithHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, OperandX64 c, TMS tm); void callLengthHelper(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb); -void callPrepareForN(IrRegAllocX64& regs, AssemblyBuilderX64& build, int limit, int step, int init); void callGetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); void callSetTable(IrRegAllocX64& regs, AssemblyBuilderX64& build, int rb, OperandX64 c, int ra); -void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, int ra, int ratag, Label& skip); -void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, int ra, int ratag); +void checkObjectBarrierConditions(AssemblyBuilderX64& build, RegisterX64 tmp, RegisterX64 object, IrOp ra, int ratag, Label& skip); +void callBarrierObject(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 object, IrOp objectOp, IrOp ra, int ratag); void callBarrierTableFast(IrRegAllocX64& regs, AssemblyBuilderX64& build, RegisterX64 table, IrOp tableOp); void callStepGc(IrRegAllocX64& regs, AssemblyBuilderX64& build); @@ -180,7 +215,7 @@ void emitUpdateBase(AssemblyBuilderX64& build); void emitInterrupt(AssemblyBuilderX64& build); void emitFallback(IrRegAllocX64& regs, AssemblyBuilderX64& build, int offset, int pcpos); -void emitUpdatePcAndContinueInVm(AssemblyBuilderX64& build); +void emitUpdatePcForExit(AssemblyBuilderX64& build); void emitContinueCallInVm(AssemblyBuilderX64& build); void emitReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers); diff --git a/CodeGen/src/EmitInstructionX64.cpp b/CodeGen/src/EmitInstructionX64.cpp index ea511958..bccdc8f0 100644 --- a/CodeGen/src/EmitInstructionX64.cpp +++ b/CodeGen/src/EmitInstructionX64.cpp @@ -251,7 +251,7 @@ void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, i } } -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index) +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index, int knownSize) { // TODO: This should use IrCallWrapperX64 RegisterX64 rArg1 = (build.abi == ABIX64::Windows) ? rcx : rdi; @@ -285,25 +285,28 @@ void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int build.add(last, index - 1); } - Label skipResize; - RegisterX64 table = regs.takeReg(rax, kInvalidInstIdx); build.mov(table, luauRegValue(ra)); - // Resize if h->sizearray < last - build.cmp(dword[table + offsetof(Table, sizearray)], last); - build.jcc(ConditionX64::NotBelow, skipResize); + if (count == LUA_MULTRET || knownSize < 0 || knownSize < int(index + count - 1)) + { + Label skipResize; - // Argument setup reordered to avoid conflicts - LUAU_ASSERT(rArg3 != table); - build.mov(dwordReg(rArg3), last); - build.mov(rArg2, table); - build.mov(rArg1, rState); - build.call(qword[rNativeContext + offsetof(NativeContext, luaH_resizearray)]); - build.mov(table, luauRegValue(ra)); // Reload cloberred register value + // Resize if h->sizearray < last + build.cmp(dword[table + offsetof(Table, sizearray)], last); + build.jcc(ConditionX64::NotBelow, skipResize); - build.setLabel(skipResize); + // Argument setup reordered to avoid conflicts + LUAU_ASSERT(rArg3 != table); + build.mov(dwordReg(rArg3), last); + build.mov(rArg2, table); + build.mov(rArg1, rState); + build.call(qword[rNativeContext + offsetof(NativeContext, luaH_resizearray)]); + build.mov(table, luauRegValue(ra)); // Reload clobbered register value + + build.setLabel(skipResize); + } RegisterX64 arrayDst = rdx; RegisterX64 offset = rcx; diff --git a/CodeGen/src/EmitInstructionX64.h b/CodeGen/src/EmitInstructionX64.h index b248b7e8..59fd8e41 100644 --- a/CodeGen/src/EmitInstructionX64.h +++ b/CodeGen/src/EmitInstructionX64.h @@ -19,7 +19,7 @@ struct IrRegAllocX64; void emitInstCall(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int nparams, int nresults); void emitInstReturn(AssemblyBuilderX64& build, ModuleHelpers& helpers, int ra, int actualResults, bool functionVariadic); -void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index); +void emitInstSetList(IrRegAllocX64& regs, AssemblyBuilderX64& build, int ra, int rb, int count, uint32_t index, int knownSize); void emitInstForGLoop(AssemblyBuilderX64& build, int ra, int aux, Label& loopRepeat); } // namespace X64 diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index 23f2dd21..b14f1470 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -186,75 +186,12 @@ void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, ui } } -static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& block, RegisterSet& defRs, std::bitset<256>& capturedRegs) +template +static void visitVmRegDefsUses(T& visitor, IrFunction& function, const IrBlock& block) { - RegisterSet inRs; - - auto def = [&](IrOp op, int offset = 0) { - defRs.regs.set(vmRegOp(op) + offset, true); - }; - - auto use = [&](IrOp op, int offset = 0) { - if (!defRs.regs.test(vmRegOp(op) + offset)) - inRs.regs.set(vmRegOp(op) + offset, true); - }; - - auto maybeDef = [&](IrOp op) { - if (op.kind == IrOpKind::VmReg) - defRs.regs.set(vmRegOp(op), true); - }; - - auto maybeUse = [&](IrOp op) { - if (op.kind == IrOpKind::VmReg) - { - if (!defRs.regs.test(vmRegOp(op))) - inRs.regs.set(vmRegOp(op), true); - } - }; - - auto defVarargs = [&](uint8_t varargStart) { - defRs.varargSeq = true; - defRs.varargStart = varargStart; - }; - - auto useVarargs = [&](uint8_t varargStart) { - requireVariadicSequence(inRs, defRs, varargStart); - - // Variadic sequence has been consumed - defRs.varargSeq = false; - defRs.varargStart = 0; - }; - - auto defRange = [&](int start, int count) { - if (count == -1) - { - defVarargs(start); - } - else - { - for (int i = start; i < start + count; i++) - defRs.regs.set(i, true); - } - }; - - auto useRange = [&](int start, int count) { - if (count == -1) - { - useVarargs(start); - } - else - { - for (int i = start; i < start + count; i++) - { - if (!defRs.regs.test(i)) - inRs.regs.set(i, true); - } - } - }; - for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) { - const IrInst& inst = function.instructions[instIdx]; + IrInst& inst = function.instructions[instIdx]; // For correct analysis, all instruction uses must be handled before handling the definitions switch (inst.cmd) @@ -264,7 +201,7 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::LOAD_DOUBLE: case IrCmd::LOAD_INT: case IrCmd::LOAD_TVALUE: - maybeUse(inst.a); // Argument can also be a VmConst + visitor.maybeUse(inst.a); // Argument can also be a VmConst break; case IrCmd::STORE_TAG: case IrCmd::STORE_POINTER: @@ -272,63 +209,55 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& case IrCmd::STORE_INT: case IrCmd::STORE_VECTOR: case IrCmd::STORE_TVALUE: - maybeDef(inst.a); // Argument can also be a pointer value + case IrCmd::STORE_SPLIT_TVALUE: + visitor.maybeDef(inst.a); // Argument can also be a pointer value break; case IrCmd::CMP_ANY: - use(inst.a); - use(inst.b); + visitor.use(inst.a); + visitor.use(inst.b); break; case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: - use(inst.a); + visitor.use(inst.a); break; // A <- B, C case IrCmd::DO_ARITH: case IrCmd::GET_TABLE: - use(inst.b); - maybeUse(inst.c); // Argument can also be a VmConst + visitor.use(inst.b); + visitor.maybeUse(inst.c); // Argument can also be a VmConst - def(inst.a); + visitor.def(inst.a); break; case IrCmd::SET_TABLE: - use(inst.a); - use(inst.b); - maybeUse(inst.c); // Argument can also be a VmConst + visitor.use(inst.a); + visitor.use(inst.b); + visitor.maybeUse(inst.c); // Argument can also be a VmConst break; // A <- B case IrCmd::DO_LEN: - use(inst.b); + visitor.use(inst.b); - def(inst.a); + visitor.def(inst.a); break; case IrCmd::GET_IMPORT: - def(inst.a); + visitor.def(inst.a); break; case IrCmd::CONCAT: - useRange(vmRegOp(inst.a), function.uintOp(inst.b)); + visitor.useRange(vmRegOp(inst.a), function.uintOp(inst.b)); - defRange(vmRegOp(inst.a), function.uintOp(inst.b)); + visitor.defRange(vmRegOp(inst.a), function.uintOp(inst.b)); break; case IrCmd::GET_UPVALUE: - def(inst.a); + visitor.def(inst.a); break; case IrCmd::SET_UPVALUE: - use(inst.b); - break; - case IrCmd::PREPARE_FORN: - use(inst.a); - use(inst.b); - use(inst.c); - - def(inst.a); - def(inst.b); - def(inst.c); + visitor.use(inst.b); break; case IrCmd::INTERRUPT: break; case IrCmd::BARRIER_OBJ: case IrCmd::BARRIER_TABLE_FORWARD: - use(inst.b); + visitor.maybeUse(inst.b); break; case IrCmd::CLOSE_UPVALS: // Closing an upvalue should be counted as a register use (it copies the fresh register value) @@ -336,23 +265,23 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& // Because we don't plan to optimize captured registers atm, we skip full dataflow analysis for them right now break; case IrCmd::CAPTURE: - maybeUse(inst.a); + visitor.maybeUse(inst.a); if (function.uintOp(inst.b) == 1) - capturedRegs.set(vmRegOp(inst.a), true); + visitor.capture(vmRegOp(inst.a)); break; case IrCmd::SETLIST: - use(inst.b); - useRange(vmRegOp(inst.c), function.intOp(inst.d)); + visitor.use(inst.b); + visitor.useRange(vmRegOp(inst.c), function.intOp(inst.d)); break; case IrCmd::CALL: - use(inst.a); - useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b)); + visitor.use(inst.a); + visitor.useRange(vmRegOp(inst.a) + 1, function.intOp(inst.b)); - defRange(vmRegOp(inst.a), function.intOp(inst.c)); + visitor.defRange(vmRegOp(inst.a), function.intOp(inst.c)); break; case IrCmd::RETURN: - useRange(vmRegOp(inst.a), function.intOp(inst.b)); + visitor.useRange(vmRegOp(inst.a), function.intOp(inst.b)); break; // TODO: FASTCALL is more restrictive than INVOKE_FASTCALL; we should either determine the exact semantics, or rework it @@ -364,89 +293,89 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& { LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && vmRegOp(inst.d) == vmRegOp(inst.c) + 1); - useRange(vmRegOp(inst.c), count); + visitor.useRange(vmRegOp(inst.c), count); } else { if (count >= 1) - use(inst.c); + visitor.use(inst.c); if (count >= 2) - maybeUse(inst.d); // Argument can also be a VmConst + visitor.maybeUse(inst.d); // Argument can also be a VmConst } } else { - useVarargs(vmRegOp(inst.c)); + visitor.useVarargs(vmRegOp(inst.c)); } // Multiple return sequences (count == -1) are defined by ADJUST_STACK_TO_REG if (int count = function.intOp(inst.f); count != -1) - defRange(vmRegOp(inst.b), count); + visitor.defRange(vmRegOp(inst.b), count); break; case IrCmd::FORGLOOP: // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG - use(inst.a, 1); - use(inst.a, 2); + visitor.use(inst.a, 1); + visitor.use(inst.a, 2); - def(inst.a, 2); - defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b)); + visitor.def(inst.a, 2); + visitor.defRange(vmRegOp(inst.a) + 3, function.intOp(inst.b)); break; case IrCmd::FORGLOOP_FALLBACK: - useRange(vmRegOp(inst.a), 3); + visitor.useRange(vmRegOp(inst.a), 3); - def(inst.a, 2); - defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit + visitor.def(inst.a, 2); + visitor.defRange(vmRegOp(inst.a) + 3, uint8_t(function.intOp(inst.b))); // ignore most significant bit break; case IrCmd::FORGPREP_XNEXT_FALLBACK: - use(inst.b); + visitor.use(inst.b); break; case IrCmd::FALLBACK_GETGLOBAL: - def(inst.b); + visitor.def(inst.b); break; case IrCmd::FALLBACK_SETGLOBAL: - use(inst.b); + visitor.use(inst.b); break; case IrCmd::FALLBACK_GETTABLEKS: - use(inst.c); + visitor.use(inst.c); - def(inst.b); + visitor.def(inst.b); break; case IrCmd::FALLBACK_SETTABLEKS: - use(inst.b); - use(inst.c); + visitor.use(inst.b); + visitor.use(inst.c); break; case IrCmd::FALLBACK_NAMECALL: - use(inst.c); + visitor.use(inst.c); - defRange(vmRegOp(inst.b), 2); + visitor.defRange(vmRegOp(inst.b), 2); break; case IrCmd::FALLBACK_PREPVARARGS: // No effect on explicitly referenced registers break; case IrCmd::FALLBACK_GETVARARGS: - defRange(vmRegOp(inst.b), function.intOp(inst.c)); + visitor.defRange(vmRegOp(inst.b), function.intOp(inst.c)); break; case IrCmd::FALLBACK_DUPCLOSURE: - def(inst.b); + visitor.def(inst.b); break; case IrCmd::FALLBACK_FORGPREP: - use(inst.b); + visitor.use(inst.b); - defRange(vmRegOp(inst.b), 3); + visitor.defRange(vmRegOp(inst.b), 3); break; case IrCmd::ADJUST_STACK_TO_REG: - defRange(vmRegOp(inst.a), -1); + visitor.defRange(vmRegOp(inst.a), -1); break; case IrCmd::ADJUST_STACK_TO_TOP: // While this can be considered to be a vararg consumer, it is already handled in fastcall instructions break; case IrCmd::GET_TYPEOF: - use(inst.a); + visitor.use(inst.a); break; case IrCmd::FINDUPVAL: - use(inst.a); + visitor.use(inst.a); break; default: @@ -460,8 +389,102 @@ static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& break; } } +} - return inRs; +struct BlockVmRegLiveInComputation +{ + BlockVmRegLiveInComputation(RegisterSet& defRs, std::bitset<256>& capturedRegs) + : defRs(defRs) + , capturedRegs(capturedRegs) + { + } + + RegisterSet& defRs; + std::bitset<256>& capturedRegs; + + RegisterSet inRs; + + void def(IrOp op, int offset = 0) + { + defRs.regs.set(vmRegOp(op) + offset, true); + } + + void use(IrOp op, int offset = 0) + { + if (!defRs.regs.test(vmRegOp(op) + offset)) + inRs.regs.set(vmRegOp(op) + offset, true); + } + + void maybeDef(IrOp op) + { + if (op.kind == IrOpKind::VmReg) + defRs.regs.set(vmRegOp(op), true); + } + + void maybeUse(IrOp op) + { + if (op.kind == IrOpKind::VmReg) + { + if (!defRs.regs.test(vmRegOp(op))) + inRs.regs.set(vmRegOp(op), true); + } + } + + void defVarargs(uint8_t varargStart) + { + defRs.varargSeq = true; + defRs.varargStart = varargStart; + } + + void useVarargs(uint8_t varargStart) + { + requireVariadicSequence(inRs, defRs, varargStart); + + // Variadic sequence has been consumed + defRs.varargSeq = false; + defRs.varargStart = 0; + } + + void defRange(int start, int count) + { + if (count == -1) + { + defVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + defRs.regs.set(i, true); + } + } + + void useRange(int start, int count) + { + if (count == -1) + { + useVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + { + if (!defRs.regs.test(i)) + inRs.regs.set(i, true); + } + } + } + + void capture(int reg) + { + capturedRegs.set(reg, true); + } +}; + +static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& block, RegisterSet& defRs, std::bitset<256>& capturedRegs) +{ + BlockVmRegLiveInComputation visitor(defRs, capturedRegs); + visitVmRegDefsUses(visitor, function, block); + return visitor.inRs; } // The algorithm used here is commonly known as backwards data-flow analysis. @@ -866,6 +889,79 @@ void computeCfgDominanceTreeChildren(IrFunction& function) computeBlockOrdering(function, info.domOrdering, /* preOrder */ nullptr, /* postOrder */ nullptr); } +// This algorithm is based on 'A Linear Time Algorithm for Placing Phi-Nodes' [Vugranam C.Sreedhar] +// It uses the optimized form from LLVM that relies an implicit DJ-graph (join edges are edges of the CFG that are not part of the dominance tree) +void computeIteratedDominanceFrontierForDefs( + IdfContext& ctx, const IrFunction& function, const std::vector& defBlocks, const std::vector& liveInBlocks) +{ + LUAU_ASSERT(!function.cfg.domOrdering.empty()); + + LUAU_ASSERT(ctx.queue.empty()); + LUAU_ASSERT(ctx.worklist.empty()); + + ctx.idf.clear(); + + ctx.visits.clear(); + ctx.visits.resize(function.blocks.size()); + + for (uint32_t defBlock : defBlocks) + { + const BlockOrdering& ordering = function.cfg.domOrdering[defBlock]; + ctx.queue.push({defBlock, ordering}); + } + + while (!ctx.queue.empty()) + { + IdfContext::BlockAndOrdering root = ctx.queue.top(); + ctx.queue.pop(); + + LUAU_ASSERT(ctx.worklist.empty()); + ctx.worklist.push_back(root.blockIdx); + ctx.visits[root.blockIdx].seenInWorklist = true; + + while (!ctx.worklist.empty()) + { + uint32_t blockIdx = ctx.worklist.back(); + ctx.worklist.pop_back(); + + // Check if successor node is the node where dominance of the current root ends, making it a part of dominance frontier set + for (uint32_t succIdx : successors(function.cfg, blockIdx)) + { + const BlockOrdering& succOrdering = function.cfg.domOrdering[succIdx]; + + // Nodes in the DF of root always have a level that is less than or equal to the level of root + if (succOrdering.depth > root.ordering.depth) + continue; + + if (ctx.visits[succIdx].seenInQueue) + continue; + + ctx.visits[succIdx].seenInQueue = true; + + // Skip successor block if it doesn't have our variable as a live in there + if (std::find(liveInBlocks.begin(), liveInBlocks.end(), succIdx) == liveInBlocks.end()) + continue; + + ctx.idf.push_back(succIdx); + + // If block doesn't have its own definition of the variable, add it to the queue + if (std::find(defBlocks.begin(), defBlocks.end(), succIdx) == defBlocks.end()) + ctx.queue.push({succIdx, succOrdering}); + } + + // Add dominance tree children that haven't been processed yet to the worklist + for (uint32_t domChildIdx : domChildren(function.cfg, blockIdx)) + { + if (ctx.visits[domChildIdx].seenInWorklist) + continue; + + ctx.visits[domChildIdx].seenInWorklist = true; + ctx.worklist.push_back(domChildIdx); + } + } + } +} + void computeCfgInfo(IrFunction& function) { computeCfgBlockEdges(function); diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 52d0a0b5..1aca2993 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -1,7 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrBuilder.h" -#include "Luau/IrAnalysis.h" +#include "Luau/Bytecode.h" +#include "Luau/BytecodeUtils.h" +#include "Luau/IrData.h" #include "Luau/IrUtils.h" #include "IrTranslation.h" @@ -22,10 +24,14 @@ IrBuilder::IrBuilder() { } +static bool hasTypedParameters(Proto* proto) +{ + return proto->typeinfo && proto->numparams != 0; +} + static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto) { - if (!proto->typeinfo || proto->numparams == 0) - return; + LUAU_ASSERT(hasTypedParameters(proto)); for (int i = 0; i < proto->numparams; ++i) { @@ -53,31 +59,31 @@ static void buildArgumentTypeChecks(IrBuilder& build, Proto* proto) switch (tag) { case LBC_TYPE_NIL: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNIL), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_BOOLEAN: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TBOOLEAN), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_NUMBER: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TNUMBER), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_STRING: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TSTRING), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_TABLE: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTABLE), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_FUNCTION: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TFUNCTION), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_THREAD: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TTHREAD), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_USERDATA: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TUSERDATA), build.vmExit(kVmExitEntryGuardPc)); break; case LBC_TYPE_VECTOR: - build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.undef(), build.constInt(1)); + build.inst(IrCmd::CHECK_TAG, load, build.constTag(LUA_TVECTOR), build.vmExit(kVmExitEntryGuardPc)); break; } @@ -103,11 +109,28 @@ void IrBuilder::buildFunctionIr(Proto* proto) function.proto = proto; function.variadic = proto->is_vararg != 0; + // Reserve entry block + bool generateTypeChecks = hasTypedParameters(proto); + IrOp entry = generateTypeChecks ? block(IrBlockKind::Internal) : IrOp{}; + // Rebuild original control flow blocks rebuildBytecodeBasicBlocks(proto); function.bcMapping.resize(proto->sizecode, {~0u, ~0u}); + if (generateTypeChecks) + { + beginBlock(entry); + buildArgumentTypeChecks(*this, proto); + inst(IrCmd::JUMP, blockAtInst(0)); + } + else + { + entry = blockAtInst(0); + } + + function.entryBlock = entry.index; + // Translate all instructions to IR inside blocks for (int i = 0; i < proto->sizecode;) { @@ -123,12 +146,15 @@ void IrBuilder::buildFunctionIr(Proto* proto) if (instIndexToBlock[i] != kNoAssociatedBlockIndex) beginBlock(blockAtInst(i)); - if (i == 0) - buildArgumentTypeChecks(*this, proto); - // We skip dead bytecode instructions when they appear after block was already terminated if (!inTerminatedBlock) { + if (interruptRequested) + { + interruptRequested = false; + inst(IrCmd::INTERRUPT, constUint(i)); + } + translateInst(op, pc, i); if (fastcallSkipTarget != -1) @@ -313,6 +339,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) case LOP_DIV: translateInstBinary(*this, pc, i, TM_DIV); break; + case LOP_IDIV: + translateInstBinary(*this, pc, i, TM_IDIV); + break; case LOP_MOD: translateInstBinary(*this, pc, i, TM_MOD); break; @@ -331,6 +360,9 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) case LOP_DIVK: translateInstBinaryK(*this, pc, i, TM_DIV); break; + case LOP_IDIVK: + translateInstBinaryK(*this, pc, i, TM_IDIV); + break; case LOP_MODK: translateInstBinaryK(*this, pc, i, TM_MOD); break; @@ -353,7 +385,8 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); + inst(IrCmd::SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1]), + undef()); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); diff --git a/CodeGen/src/IrCallWrapperX64.cpp b/CodeGen/src/IrCallWrapperX64.cpp index 816e0184..15fabf09 100644 --- a/CodeGen/src/IrCallWrapperX64.cpp +++ b/CodeGen/src/IrCallWrapperX64.cpp @@ -13,7 +13,7 @@ namespace CodeGen namespace X64 { -static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + 32], addr[rsp + 40]}; +static const std::array kWindowsGprOrder = {rcx, rdx, r8, r9, addr[rsp + kStackRegHomeStorage], addr[rsp + kStackRegHomeStorage + 8]}; static const std::array kSystemvGprOrder = {rdi, rsi, rdx, rcx, r8, r9}; static const std::array kXmmOrder = {xmm0, xmm1, xmm2, xmm3}; // Common order for first 4 fp arguments on Windows/SystemV diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index 12b75bcc..21b2f702 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -89,8 +89,6 @@ const char* getCmdName(IrCmd cmd) return "LOAD_INT"; case IrCmd::LOAD_TVALUE: return "LOAD_TVALUE"; - case IrCmd::LOAD_NODE_VALUE_TV: - return "LOAD_NODE_VALUE_TV"; case IrCmd::LOAD_ENV: return "LOAD_ENV"; case IrCmd::GET_ARR_ADDR: @@ -113,8 +111,8 @@ const char* getCmdName(IrCmd cmd) return "STORE_VECTOR"; case IrCmd::STORE_TVALUE: return "STORE_TVALUE"; - case IrCmd::STORE_NODE_VALUE_TV: - return "STORE_NODE_VALUE_TV"; + case IrCmd::STORE_SPLIT_TVALUE: + return "STORE_SPLIT_TVALUE"; case IrCmd::ADD_INT: return "ADD_INT"; case IrCmd::SUB_INT: @@ -127,6 +125,8 @@ const char* getCmdName(IrCmd cmd) return "MUL_NUM"; case IrCmd::DIV_NUM: return "DIV_NUM"; + case IrCmd::IDIV_NUM: + return "IDIV_NUM"; case IrCmd::MOD_NUM: return "MOD_NUM"; case IrCmd::MIN_NUM: @@ -157,12 +157,8 @@ const char* getCmdName(IrCmd cmd) return "JUMP_IF_FALSY"; case IrCmd::JUMP_EQ_TAG: return "JUMP_EQ_TAG"; - case IrCmd::JUMP_EQ_INT: - return "JUMP_EQ_INT"; - case IrCmd::JUMP_LT_INT: - return "JUMP_LT_INT"; - case IrCmd::JUMP_GE_UINT: - return "JUMP_GE_UINT"; + case IrCmd::JUMP_CMP_INT: + return "JUMP_CMP_INT"; case IrCmd::JUMP_EQ_POINTER: return "JUMP_EQ_POINTER"; case IrCmd::JUMP_CMP_NUM: @@ -171,6 +167,8 @@ const char* getCmdName(IrCmd cmd) return "JUMP_SLOT_MATCH"; case IrCmd::TABLE_LEN: return "TABLE_LEN"; + case IrCmd::TABLE_SETNUM: + return "TABLE_SETNUM"; case IrCmd::STRING_LEN: return "STRING_LEN"; case IrCmd::NEW_TABLE: @@ -215,8 +213,6 @@ const char* getCmdName(IrCmd cmd) return "GET_UPVALUE"; case IrCmd::SET_UPVALUE: return "SET_UPVALUE"; - case IrCmd::PREPARE_FORN: - return "PREPARE_FORN"; case IrCmd::CHECK_TAG: return "CHECK_TAG"; case IrCmd::CHECK_TRUTHY: @@ -233,6 +229,8 @@ const char* getCmdName(IrCmd cmd) return "CHECK_SLOT_MATCH"; case IrCmd::CHECK_NODE_NO_NEXT: return "CHECK_NODE_NO_NEXT"; + case IrCmd::CHECK_NODE_VALUE: + return "CHECK_NODE_VALUE"; case IrCmd::INTERRUPT: return "INTERRUPT"; case IrCmd::CHECK_GC: @@ -402,7 +400,7 @@ void toString(IrToStringContext& ctx, IrOp op) append(ctx.result, "U%d", vmUpvalueOp(op)); break; case IrOpKind::VmExit: - append(ctx.result, "exit(%d)", op.index); + append(ctx.result, "exit(%d)", vmExitOp(op)); break; } } diff --git a/CodeGen/src/IrLoweringA64.cpp b/CodeGen/src/IrLoweringA64.cpp index 16796a28..6e51b0d5 100644 --- a/CodeGen/src/IrLoweringA64.cpp +++ b/CodeGen/src/IrLoweringA64.cpp @@ -1,10 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "IrLoweringA64.h" -#include "Luau/CodeGen.h" #include "Luau/DenseHash.h" -#include "Luau/IrAnalysis.h" -#include "Luau/IrDump.h" +#include "Luau/IrData.h" #include "Luau/IrUtils.h" #include "EmitCommonA64.h" @@ -60,28 +58,56 @@ inline ConditionA64 getConditionFP(IrCondition cond) } } -static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, int ra, int ratag, Label& skip) +inline ConditionA64 getConditionInt(IrCondition cond) { - RegisterA64 tempw = castReg(KindA64::w, temp); - - // Barrier should've been optimized away if we know that it's not collectable, checking for correctness - if (ratag == -1 || !isGCO(ratag)) + switch (cond) { - // iscollectable(ra) - build.ldr(tempw, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, tt))); - build.cmp(tempw, LUA_TSTRING); - build.b(ConditionA64::Less, skip); + case IrCondition::Equal: + return ConditionA64::Equal; + + case IrCondition::NotEqual: + return ConditionA64::NotEqual; + + case IrCondition::Less: + return ConditionA64::Minus; + + case IrCondition::NotLess: + return ConditionA64::Plus; + + case IrCondition::LessEqual: + return ConditionA64::LessEqual; + + case IrCondition::NotLessEqual: + return ConditionA64::Greater; + + case IrCondition::Greater: + return ConditionA64::Greater; + + case IrCondition::NotGreater: + return ConditionA64::LessEqual; + + case IrCondition::GreaterEqual: + return ConditionA64::GreaterEqual; + + case IrCondition::NotGreaterEqual: + return ConditionA64::Less; + + case IrCondition::UnsignedLess: + return ConditionA64::CarryClear; + + case IrCondition::UnsignedLessEqual: + return ConditionA64::UnsignedLessEqual; + + case IrCondition::UnsignedGreater: + return ConditionA64::UnsignedGreater; + + case IrCondition::UnsignedGreaterEqual: + return ConditionA64::CarrySet; + + default: + LUAU_ASSERT(!"Unexpected condition code"); + return ConditionA64::Always; } - - // isblack(obj2gco(o)) - build.ldrb(tempw, mem(object, offsetof(GCheader, marked))); - build.tbz(tempw, BLACKBIT, skip); - - // iswhite(gcvalue(ra)) - build.ldr(temp, mem(rBase, ra * sizeof(TValue) + offsetof(TValue, value))); - build.ldrb(tempw, mem(temp, offsetof(GCheader, marked))); - build.tst(tempw, bit2mask(WHITE0BIT, WHITE1BIT)); - build.b(ConditionA64::Equal, skip); // Equal = Zero after tst } static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA64 src, size_t offset) @@ -100,6 +126,47 @@ static void emitAddOffset(AssemblyBuilderA64& build, RegisterA64 dst, RegisterA6 } } +static void checkObjectBarrierConditions(AssemblyBuilderA64& build, RegisterA64 object, RegisterA64 temp, IrOp ra, int ratag, Label& skip) +{ + RegisterA64 tempw = castReg(KindA64::w, temp); + AddressA64 addr = temp; + + // iscollectable(ra) + if (ratag == -1 || !isGCO(ratag)) + { + if (ra.kind == IrOpKind::VmReg) + { + addr = mem(rBase, vmRegOp(ra) * sizeof(TValue) + offsetof(TValue, tt)); + } + else if (ra.kind == IrOpKind::VmConst) + { + emitAddOffset(build, temp, rConstants, vmConstOp(ra) * sizeof(TValue) + offsetof(TValue, tt)); + } + + build.ldr(tempw, addr); + build.cmp(tempw, LUA_TSTRING); + build.b(ConditionA64::Less, skip); + } + + // isblack(obj2gco(o)) + build.ldrb(tempw, mem(object, offsetof(GCheader, marked))); + build.tbz(tempw, BLACKBIT, skip); + + // iswhite(gcvalue(ra)) + if (ra.kind == IrOpKind::VmReg) + { + addr = mem(rBase, vmRegOp(ra) * sizeof(TValue) + offsetof(TValue, value)); + } + else if (ra.kind == IrOpKind::VmConst) + { + emitAddOffset(build, temp, rConstants, vmConstOp(ra) * sizeof(TValue) + offsetof(TValue, value)); + } + build.ldr(temp, addr); + build.ldrb(tempw, mem(temp, offsetof(GCheader, marked))); + build.tst(tempw, bit2mask(WHITE0BIT, WHITE1BIT)); + build.b(ConditionA64::Equal, skip); // Equal = Zero after tst +} + static void emitAbort(AssemblyBuilderA64& build, Label& abort) { Label skip; @@ -124,6 +191,7 @@ static void emitFallback(AssemblyBuilderA64& build, int offset, int pcpos) static void emitInvokeLibm1P(AssemblyBuilderA64& build, size_t func, int arg) { + LUAU_ASSERT(kTempSlots >= 1); build.ldr(d0, mem(rBase, arg * sizeof(TValue) + offsetof(TValue, value.n))); build.add(x0, sp, sTemporary.data); // sp-relative offset build.ldr(x1, mem(rNativeContext, uint32_t(func))); @@ -172,11 +240,12 @@ static bool emitBuiltin( } } -IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function) +IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats) : build(build) , helpers(helpers) , function(function) - , regs(function, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) + , stats(stats) + , regs(function, stats, {{x0, x15}, {x16, x17}, {q0, q7}, {q16, q31}}) , valueTracker(function) , exitHandlerMap(~0u) { @@ -189,7 +258,7 @@ IrLoweringA64::IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, }); } -void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) +void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { valueTracker.beforeInstLowering(inst); @@ -226,16 +295,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::LOAD_TVALUE: { inst.regA64 = regs.allocReg(KindA64::q, index); - AddressA64 addr = tempAddr(inst.a, 0); + + int addrOffset = inst.b.kind != IrOpKind::None ? intOp(inst.b) : 0; + AddressA64 addr = tempAddr(inst.a, addrOffset); build.ldr(inst.regA64, addr); break; } - case IrCmd::LOAD_NODE_VALUE_TV: - { - inst.regA64 = regs.allocReg(KindA64::q, index); - build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(LuaNode, val))); - break; - } case IrCmd::LOAD_ENV: inst.regA64 = regs.allocReg(KindA64::x, index); build.ldr(inst.regA64, mem(rClosure, offsetof(Closure, env))); @@ -247,7 +312,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) if (inst.b.kind == IrOpKind::Inst) { - build.add(inst.regA64, inst.regA64, zextReg(regOp(inst.b)), kTValueSizeLog2); + build.add(inst.regA64, inst.regA64, regOp(inst.b), kTValueSizeLog2); } else if (inst.b.kind == IrOpKind::Constant) { @@ -276,6 +341,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp1w = castReg(KindA64::w, temp1); RegisterA64 temp2 = regs.allocTemp(KindA64::w); + RegisterA64 temp2x = castReg(KindA64::x, temp2); // note: since the stride of the load is the same as the destination register size, we can range check the array index, not the byte offset if (uintOp(inst.b) <= AddressA64::kMaxOffset) @@ -293,7 +359,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // note: this may clobber inst.a, so it's important that we don't use it after this build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); - build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); + build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero) break; } case IrCmd::GET_HASH_NODE_ADDR: @@ -301,6 +367,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.allocReuse(KindA64::x, index, {inst.a}); RegisterA64 temp1 = regs.allocTemp(KindA64::w); RegisterA64 temp2 = regs.allocTemp(KindA64::w); + RegisterA64 temp2x = castReg(KindA64::x, temp2); // hash & ((1 << lsizenode) - 1) == hash & ~(-1 << lsizenode) build.mov(temp1, -1); @@ -311,7 +378,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // note: this may clobber inst.a, so it's important that we don't use it after this build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(Table, node))); - build.add(inst.regA64, inst.regA64, zextReg(temp2), kLuaNodeSizeLog2); + build.add(inst.regA64, inst.regA64, temp2x, kLuaNodeSizeLog2); // "zero extend" temp2 to get a larger shift (top 32 bits are zero) break; } case IrCmd::GET_CLOSURE_UPVAL_ADDR: @@ -324,10 +391,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::STORE_TAG: { - RegisterA64 temp = regs.allocTemp(KindA64::w); AddressA64 addr = tempAddr(inst.a, offsetof(TValue, tt)); - build.mov(temp, tagOp(inst.b)); - build.str(temp, addr); + if (tagOp(inst.b) == 0) + { + build.str(wzr, addr); + } + else + { + RegisterA64 temp = regs.allocTemp(KindA64::w); + build.mov(temp, tagOp(inst.b)); + build.str(temp, addr); + } break; } case IrCmd::STORE_POINTER: @@ -345,9 +419,16 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::STORE_INT: { - RegisterA64 temp = tempInt(inst.b); AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value)); - build.str(temp, addr); + if (inst.b.kind == IrOpKind::Constant && intOp(inst.b) == 0) + { + build.str(wzr, addr); + } + else + { + RegisterA64 temp = tempInt(inst.b); + build.str(temp, addr); + } break; } case IrCmd::STORE_VECTOR: @@ -370,13 +451,48 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::STORE_TVALUE: { - AddressA64 addr = tempAddr(inst.a, 0); + int addrOffset = inst.c.kind != IrOpKind::None ? intOp(inst.c) : 0; + AddressA64 addr = tempAddr(inst.a, addrOffset); build.str(regOp(inst.b), addr); break; } - case IrCmd::STORE_NODE_VALUE_TV: - build.str(regOp(inst.b), mem(regOp(inst.a), offsetof(LuaNode, val))); + case IrCmd::STORE_SPLIT_TVALUE: + { + int addrOffset = inst.d.kind != IrOpKind::None ? intOp(inst.d) : 0; + + RegisterA64 tempt = regs.allocTemp(KindA64::w); + AddressA64 addrt = tempAddr(inst.a, offsetof(TValue, tt) + addrOffset); + build.mov(tempt, tagOp(inst.b)); + build.str(tempt, addrt); + + AddressA64 addr = tempAddr(inst.a, offsetof(TValue, value) + addrOffset); + + if (tagOp(inst.b) == LUA_TBOOLEAN) + { + if (inst.c.kind == IrOpKind::Constant) + { + // note: we reuse tag temp register as value for true booleans, and use built-in zero register for false values + LUAU_ASSERT(LUA_TBOOLEAN == 1); + build.str(intOp(inst.c) ? tempt : wzr, addr); + } + else + build.str(regOp(inst.c), addr); + } + else if (tagOp(inst.b) == LUA_TNUMBER) + { + RegisterA64 temp = tempDouble(inst.c); + build.str(temp, addr); + } + else if (isGCO(tagOp(inst.b))) + { + build.str(regOp(inst.c), addr); + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + } break; + } case IrCmd::ADD_INT: inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b}); if (inst.b.kind == IrOpKind::Constant && unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate) @@ -433,6 +549,15 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.fdiv(inst.regA64, temp1, temp2); break; } + case IrCmd::IDIV_NUM: + { + inst.regA64 = regs.allocReuse(KindA64::d, index, {inst.a, inst.b}); + RegisterA64 temp1 = tempDouble(inst.a); + RegisterA64 temp2 = tempDouble(inst.b); + build.fdiv(inst.regA64, temp1, temp2); + build.frintm(inst.regA64, inst.regA64); + break; + } case IrCmd::MOD_NUM: { inst.regA64 = regs.allocReg(KindA64::d, index); // can't allocReuse because both A and B are used twice @@ -560,13 +685,11 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) emitUpdateBase(build); - // since w0 came from a call, we need to move it so that we don't violate zextReg safety contract - inst.regA64 = regs.allocReg(KindA64::w, index); - build.mov(inst.regA64, w0); + inst.regA64 = regs.takeReg(w0, index); break; } case IrCmd::JUMP: - if (inst.a.kind == IrOpKind::VmExit) + if (inst.a.kind == IrOpKind::Undef || inst.a.kind == IrOpKind::VmExit) { Label fresh; build.b(getTargetLabel(inst.a, fresh)); @@ -644,31 +767,25 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } - case IrCmd::JUMP_EQ_INT: - if (intOp(inst.b) == 0) + case IrCmd::JUMP_CMP_INT: + { + IrCondition cond = conditionOp(inst.c); + + if (cond == IrCondition::Equal && intOp(inst.b) == 0) { - build.cbz(regOp(inst.a), labelOp(inst.c)); + build.cbz(regOp(inst.a), labelOp(inst.d)); + } + else if (cond == IrCondition::NotEqual && intOp(inst.b) == 0) + { + build.cbnz(regOp(inst.a), labelOp(inst.d)); } else { LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); - build.b(ConditionA64::Equal, labelOp(inst.c)); + build.b(getConditionInt(cond), labelOp(inst.d)); } - jumpOrFallthrough(blockOp(inst.d), next); - break; - case IrCmd::JUMP_LT_INT: - LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(intOp(inst.b))); - build.b(ConditionA64::Less, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); - break; - case IrCmd::JUMP_GE_UINT: - { - LUAU_ASSERT(unsigned(intOp(inst.b)) <= AssemblyBuilderA64::kMaxImmediate); - build.cmp(regOp(inst.a), uint16_t(unsigned(intOp(inst.b)))); - build.b(ConditionA64::CarrySet, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + jumpOrFallthrough(blockOp(inst.e), next); break; } case IrCmd::JUMP_EQ_POINTER: @@ -706,16 +823,42 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(x0, reg); build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaH_getn))); build.blr(x1); - inst.regA64 = regs.allocReg(KindA64::d, index); - build.scvtf(inst.regA64, x0); + + inst.regA64 = regs.takeReg(w0, index); break; } case IrCmd::STRING_LEN: { - RegisterA64 reg = regOp(inst.a); inst.regA64 = regs.allocReg(KindA64::w, index); - build.ldr(inst.regA64, mem(reg, offsetof(TString, len))); + build.ldr(inst.regA64, mem(regOp(inst.a), offsetof(TString, len))); + break; + } + case IrCmd::TABLE_SETNUM: + { + // note: we need to call regOp before spill so that we don't do redundant reloads + RegisterA64 table = regOp(inst.a); + RegisterA64 key = regOp(inst.b); + RegisterA64 temp = regs.allocTemp(KindA64::w); + + regs.spill(build, index, {table, key}); + + if (w1 != key) + { + build.mov(x1, table); + build.mov(w2, key); + } + else + { + build.mov(temp, w1); + build.mov(x1, table); + build.mov(w2, temp); + } + + build.mov(x0, rState); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaH_setnum))); + build.blr(x3); + inst.regA64 = regs.takeReg(x0, index); break; } case IrCmd::NEW_TABLE: @@ -776,8 +919,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.spill(build, index, {temp1}); build.mov(x0, temp1); build.mov(w1, intOp(inst.b)); - build.ldr(x2, mem(rState, offsetof(lua_State, global))); - build.ldr(x2, mem(x2, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*))); + build.ldr(x2, mem(rGlobalState, offsetof(global_State, tmname) + intOp(inst.b) * sizeof(TString*))); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaT_gettm))); build.blr(x3); @@ -812,8 +954,6 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) inst.regA64 = regs.allocReg(KindA64::w, index); RegisterA64 temp = tempDouble(inst.a); build.fcvtzs(castReg(KindA64::x, inst.regA64), temp); - // truncation needs to clear high bits to preserve zextReg safety contract - build.mov(inst.regA64, inst.regA64); break; } case IrCmd::ADJUST_STACK_TO_REG: @@ -828,7 +968,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else if (inst.b.kind == IrOpKind::Inst) { build.add(temp, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - build.add(temp, temp, zextReg(regOp(inst.b)), kTValueSizeLog2); + build.add(temp, temp, regOp(inst.b), kTValueSizeLog2); build.str(temp, mem(rState, offsetof(lua_State, top))); } else @@ -877,9 +1017,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldr(x6, mem(rNativeContext, offsetof(NativeContext, luauF_table) + uintOp(inst.a) * sizeof(luau_FastFunction))); build.blr(x6); - // since w0 came from a call, we need to move it so that we don't violate zextReg safety contract - inst.regA64 = regs.allocReg(KindA64::w, index); - build.mov(inst.regA64, w0); + inst.regA64 = regs.takeReg(w0, index); break; } case IrCmd::CHECK_FASTCALL_RES: @@ -974,8 +1112,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) case IrCmd::CONCAT: regs.spill(build, index); build.mov(x0, rState); - build.mov(x1, uintOp(inst.b)); - build.mov(x2, vmRegOp(inst.a) + uintOp(inst.b) - 1); + build.mov(w1, uintOp(inst.b)); + build.mov(w2, vmRegOp(inst.a) + uintOp(inst.b) - 1); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaV_concat))); build.blr(x3); @@ -1018,38 +1156,30 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.ldr(temp3, mem(rBase, vmRegOp(inst.b) * sizeof(TValue))); build.str(temp3, temp2); - Label skip; - checkObjectBarrierConditions(build, temp1, temp2, vmRegOp(inst.b), /* ratag */ -1, skip); + if (inst.c.kind == IrOpKind::Undef || isGCO(tagOp(inst.c))) + { + Label skip; + checkObjectBarrierConditions(build, temp1, temp2, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); - size_t spills = regs.spill(build, index, {temp1}); + size_t spills = regs.spill(build, index, {temp1}); - build.mov(x1, temp1); - build.mov(x0, rState); - build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); - build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); - build.blr(x3); + build.mov(x1, temp1); + build.mov(x0, rState); + build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barrierf))); + build.blr(x3); - regs.restore(build, spills); // need to restore before skip so that registers are in a consistent state + regs.restore(build, spills); // need to restore before skip so that registers are in a consistent state - // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack - build.setLabel(skip); + // note: no emitUpdateBase necessary because luaC_ barriers do not reallocate stack + build.setLabel(skip); + } break; } - case IrCmd::PREPARE_FORN: - regs.spill(build, index); - build.mov(x0, rState); - build.add(x1, rBase, uint16_t(vmRegOp(inst.a) * sizeof(TValue))); - build.add(x2, rBase, uint16_t(vmRegOp(inst.b) * sizeof(TValue))); - build.add(x3, rBase, uint16_t(vmRegOp(inst.c) * sizeof(TValue))); - build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, luaV_prepareFORN))); - build.blr(x4); - // note: no emitUpdateBase necessary because prepareFORN does not reallocate stack - break; case IrCmd::CHECK_TAG: { - bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d)); Label fresh; // used when guard aborts execution or jumps to a VM exit - Label& fail = continueInVm ? helpers.exitContinueVmClearNativeFlag : getTargetLabel(inst.c, fresh); + Label& fail = getTargetLabel(inst.c, fresh); // To support DebugLuauAbortingChecks, CHECK_TAG with VmReg has to be handled RegisterA64 tag = inst.a.kind == IrOpKind::VmReg ? regs.allocTemp(KindA64::w) : regOp(inst.a); @@ -1066,8 +1196,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.cmp(tag, tagOp(inst.b)); build.b(ConditionA64::NotEqual, fail); } - if (!continueInVm) - finalizeTargetLabel(inst.c, fresh); + + finalizeTargetLabel(inst.c, fresh); break; } case IrCmd::CHECK_TRUTHY: @@ -1210,14 +1340,24 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) finalizeTargetLabel(inst.b, fresh); break; } + case IrCmd::CHECK_NODE_VALUE: + { + Label fresh; // used when guard aborts execution or jumps to a VM exit + RegisterA64 temp = regs.allocTemp(KindA64::w); + + build.ldr(temp, mem(regOp(inst.a), offsetof(LuaNode, val.tt))); + LUAU_ASSERT(LUA_TNIL == 0); + build.cbz(temp, getTargetLabel(inst.b, fresh)); + finalizeTargetLabel(inst.b, fresh); + break; + } case IrCmd::INTERRUPT: { regs.spill(build, index); Label self; - build.ldr(x0, mem(rState, offsetof(lua_State, global))); - build.ldr(x0, mem(x0, offsetof(global_State, cb.interrupt))); + build.ldr(x0, mem(rGlobalState, offsetof(global_State, cb.interrupt))); build.cbnz(x0, self); Label next = build.setLabel(); @@ -1230,11 +1370,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp1 = regs.allocTemp(KindA64::x); RegisterA64 temp2 = regs.allocTemp(KindA64::x); + LUAU_ASSERT(offsetof(global_State, totalbytes) == offsetof(global_State, GCthreshold) + 8); Label skip; - build.ldr(temp1, mem(rState, offsetof(lua_State, global))); - // TODO: totalbytes and GCthreshold loads can be fused with ldp - build.ldr(temp2, mem(temp1, offsetof(global_State, totalbytes))); - build.ldr(temp1, mem(temp1, offsetof(global_State, GCthreshold))); + build.ldp(temp1, temp2, mem(rGlobalState, offsetof(global_State, GCthreshold))); build.cmp(temp1, temp2); build.b(ConditionA64::UnsignedGreater, skip); @@ -1242,8 +1380,8 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(x0, rState); build.mov(w1, 1); - build.ldr(x1, mem(rNativeContext, offsetof(NativeContext, luaC_step))); - build.blr(x1); + build.ldr(x2, mem(rNativeContext, offsetof(NativeContext, luaC_step))); + build.blr(x2); emitUpdateBase(build); @@ -1257,7 +1395,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp = regs.allocTemp(KindA64::x); Label skip; - checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); + checkObjectBarrierConditions(build, regOp(inst.a), temp, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads size_t spills = regs.spill(build, index, {reg}); @@ -1301,13 +1439,14 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) RegisterA64 temp = regs.allocTemp(KindA64::x); Label skip; - checkObjectBarrierConditions(build, regOp(inst.a), temp, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); + checkObjectBarrierConditions(build, regOp(inst.a), temp, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); RegisterA64 reg = regOp(inst.a); // note: we need to call regOp before spill so that we don't do redundant reloads + AddressA64 addr = tempAddr(inst.b, offsetof(TValue, value)); size_t spills = regs.spill(build, index, {reg}); build.mov(x1, reg); build.mov(x0, rState); - build.ldr(x2, mem(rBase, vmRegOp(inst.b) * sizeof(TValue) + offsetof(TValue, value))); + build.ldr(x2, addr); build.ldr(x3, mem(rNativeContext, offsetof(NativeContext, luaC_barriertable))); build.blr(x3); @@ -1453,9 +1592,9 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // clear extra variables since we might have more than two if (intOp(inst.b) > 2) { - build.mov(w0, LUA_TNIL); + LUAU_ASSERT(LUA_TNIL == 0); for (int i = 2; i < intOp(inst.b); ++i) - build.str(w0, mem(rBase, (vmRegOp(inst.a) + 3 + i) * sizeof(TValue) + offsetof(TValue, tt))); + build.str(wzr, mem(rBase, (vmRegOp(inst.a) + 3 + i) * sizeof(TValue) + offsetof(TValue, tt))); } // we use full iter fallback for now; in the future it could be worthwhile to accelerate array iteration here build.mov(x0, rState); @@ -1564,7 +1703,7 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { emitAddOffset(build, x1, rCode, uintOp(inst.a) * sizeof(Instruction)); build.mov(x2, rBase); - build.mov(x3, vmRegOp(inst.b)); + build.mov(w3, vmRegOp(inst.b)); build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSMultRet))); build.blr(x4); @@ -1573,10 +1712,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else { build.mov(x1, rBase); - build.mov(x2, vmRegOp(inst.b)); - build.mov(x3, intOp(inst.c)); + build.mov(w2, vmRegOp(inst.b)); + build.mov(w3, intOp(inst.c)); build.ldr(x4, mem(rNativeContext, offsetof(NativeContext, executeGETVARARGSConst))); build.blr(x4); + + // note: no emitUpdateBase necessary because executeGETVARARGSConst does not reallocate stack } break; case IrCmd::NEWCLOSURE: @@ -1793,13 +1934,12 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regA64 = regs.allocReg(KindA64::x, index); - build.ldr(inst.regA64, mem(rState, offsetof(lua_State, global))); LUAU_ASSERT(sizeof(TString*) == 8); if (inst.a.kind == IrOpKind::Inst) - build.add(inst.regA64, inst.regA64, zextReg(regOp(inst.a)), 3); + build.add(inst.regA64, rGlobalState, regOp(inst.a), 3); else if (inst.a.kind == IrOpKind::Constant) - build.add(inst.regA64, inst.regA64, uint16_t(tagOp(inst.a)) * 8); + build.add(inst.regA64, rGlobalState, uint16_t(tagOp(inst.a)) * 8); else LUAU_ASSERT(!"Unsupported instruction form"); @@ -1839,9 +1979,17 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.freeTempRegs(); } -void IrLoweringA64::finishBlock() +void IrLoweringA64::finishBlock(const IrBlock& curr, const IrBlock& next) { - regs.assertNoSpills(); + if (!regs.spills.empty()) + { + // If we have spills remaining, we have to immediately lower the successor block + for (uint32_t predIdx : predecessors(function.cfg, function.getBlockIndex(next))) + LUAU_ASSERT(predIdx == function.getBlockIndex(curr)); + + // And the next block cannot be a join block in cfg + LUAU_ASSERT(next.useCount == 1); + } } void IrLoweringA64::finishFunction() @@ -1862,10 +2010,22 @@ void IrLoweringA64::finishFunction() for (ExitHandler& handler : exitHandlers) { + LUAU_ASSERT(handler.pcpos != kVmExitEntryGuardPc); + build.setLabel(handler.self); + build.mov(x0, handler.pcpos * sizeof(Instruction)); build.b(helpers.updatePcAndContinueInVm); } + + if (stats) + { + if (error) + stats->loweringErrors++; + + if (regs.error) + stats->regAllocErrors++; + } } bool IrLoweringA64::hasError() const @@ -1873,12 +2033,12 @@ bool IrLoweringA64::hasError() const return error || regs.error; } -bool IrLoweringA64::isFallthroughBlock(IrBlock target, IrBlock next) +bool IrLoweringA64::isFallthroughBlock(const IrBlock& target, const IrBlock& next) { return target.start == next.start; } -void IrLoweringA64::jumpOrFallthrough(IrBlock& target, IrBlock& next) +void IrLoweringA64::jumpOrFallthrough(IrBlock& target, const IrBlock& next) { if (!isFallthroughBlock(target, next)) build.b(target.label); @@ -1891,7 +2051,11 @@ Label& IrLoweringA64::getTargetLabel(IrOp op, Label& fresh) if (op.kind == IrOpKind::VmExit) { - if (uint32_t* index = exitHandlerMap.find(op.index)) + // Special exit case that doesn't have to update pcpos + if (vmExitOp(op) == kVmExitEntryGuardPc) + return helpers.exitContinueVmClearNativeFlag; + + if (uint32_t* index = exitHandlerMap.find(vmExitOp(op))) return exitHandlers[*index].self; return fresh; @@ -1906,10 +2070,10 @@ void IrLoweringA64::finalizeTargetLabel(IrOp op, Label& fresh) { emitAbort(build, fresh); } - else if (op.kind == IrOpKind::VmExit && fresh.id != 0) + else if (op.kind == IrOpKind::VmExit && fresh.id != 0 && fresh.id != helpers.exitContinueVmClearNativeFlag.id) { - exitHandlerMap[op.index] = uint32_t(exitHandlers.size()); - exitHandlers.push_back({fresh, op.index}); + exitHandlerMap[vmExitOp(op)] = uint32_t(exitHandlers.size()); + exitHandlers.push_back({fresh, vmExitOp(op)}); } } diff --git a/CodeGen/src/IrLoweringA64.h b/CodeGen/src/IrLoweringA64.h index 72b0da2f..46f41021 100644 --- a/CodeGen/src/IrLoweringA64.h +++ b/CodeGen/src/IrLoweringA64.h @@ -17,22 +17,23 @@ namespace CodeGen struct ModuleHelpers; struct AssemblyOptions; +struct LoweringStats; namespace A64 { struct IrLoweringA64 { - IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function); + IrLoweringA64(AssemblyBuilderA64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats); - void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); - void finishBlock(); + void lowerInst(IrInst& inst, uint32_t index, const IrBlock& next); + void finishBlock(const IrBlock& curr, const IrBlock& next); void finishFunction(); bool hasError() const; - bool isFallthroughBlock(IrBlock target, IrBlock next); - void jumpOrFallthrough(IrBlock& target, IrBlock& next); + bool isFallthroughBlock(const IrBlock& target, const IrBlock& next); + void jumpOrFallthrough(IrBlock& target, const IrBlock& next); Label& getTargetLabel(IrOp op, Label& fresh); void finalizeTargetLabel(IrOp op, Label& fresh); @@ -74,6 +75,7 @@ struct IrLoweringA64 ModuleHelpers& helpers; IrFunction& function; + LoweringStats* stats = nullptr; IrRegAllocA64 regs; diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 6fe1e771..03ae4f0c 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -1,19 +1,19 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "IrLoweringX64.h" -#include "Luau/CodeGen.h" #include "Luau/DenseHash.h" -#include "Luau/IrAnalysis.h" -#include "Luau/IrCallWrapperX64.h" -#include "Luau/IrDump.h" +#include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "Luau/IrCallWrapperX64.h" + #include "EmitBuiltinsX64.h" #include "EmitCommonX64.h" #include "EmitInstructionX64.h" #include "NativeState.h" #include "lstate.h" +#include "lgc.h" namespace Luau { @@ -22,11 +22,12 @@ namespace CodeGen namespace X64 { -IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function) +IrLoweringX64::IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats) : build(build) , helpers(helpers) , function(function) - , regs(build, function) + , stats(stats) + , regs(build, function, stats) , valueTracker(function) , exitHandlerMap(~0u) { @@ -59,7 +60,7 @@ void IrLoweringX64::storeDoubleAsFloat(OperandX64 dst, IrOp src) build.vmovss(dst, tmp.reg); } -void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) +void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next) { regs.currInstIdx = index; @@ -111,22 +112,21 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(inst.regX64, luauRegValueInt(vmRegOp(inst.a))); break; case IrCmd::LOAD_TVALUE: + { inst.regX64 = regs.allocReg(SizeX64::xmmword, index); + int addrOffset = inst.b.kind != IrOpKind::None ? intOp(inst.b) : 0; + if (inst.a.kind == IrOpKind::VmReg) build.vmovups(inst.regX64, luauReg(vmRegOp(inst.a))); else if (inst.a.kind == IrOpKind::VmConst) build.vmovups(inst.regX64, luauConstant(vmConstOp(inst.a))); else if (inst.a.kind == IrOpKind::Inst) - build.vmovups(inst.regX64, xmmword[regOp(inst.a)]); + build.vmovups(inst.regX64, xmmword[regOp(inst.a) + addrOffset]); else LUAU_ASSERT(!"Unsupported instruction form"); break; - case IrCmd::LOAD_NODE_VALUE_TV: - inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - - build.vmovups(inst.regX64, luauNodeValue(regOp(inst.a))); - break; + } case IrCmd::LOAD_ENV: inst.regX64 = regs.allocReg(SizeX64::qword, index); @@ -252,16 +252,59 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) storeDoubleAsFloat(luauRegValueVector(vmRegOp(inst.a), 2), inst.d); break; case IrCmd::STORE_TVALUE: + { + int addrOffset = inst.c.kind != IrOpKind::None ? intOp(inst.c) : 0; + if (inst.a.kind == IrOpKind::VmReg) build.vmovups(luauReg(vmRegOp(inst.a)), regOp(inst.b)); else if (inst.a.kind == IrOpKind::Inst) - build.vmovups(xmmword[regOp(inst.a)], regOp(inst.b)); + build.vmovups(xmmword[regOp(inst.a) + addrOffset], regOp(inst.b)); else LUAU_ASSERT(!"Unsupported instruction form"); break; - case IrCmd::STORE_NODE_VALUE_TV: - build.vmovups(luauNodeValue(regOp(inst.a)), regOp(inst.b)); + } + case IrCmd::STORE_SPLIT_TVALUE: + { + int addrOffset = inst.d.kind != IrOpKind::None ? intOp(inst.d) : 0; + + OperandX64 tagLhs = inst.a.kind == IrOpKind::Inst ? dword[regOp(inst.a) + offsetof(TValue, tt) + addrOffset] : luauRegTag(vmRegOp(inst.a)); + build.mov(tagLhs, tagOp(inst.b)); + + if (tagOp(inst.b) == LUA_TBOOLEAN) + { + OperandX64 valueLhs = + inst.a.kind == IrOpKind::Inst ? dword[regOp(inst.a) + offsetof(TValue, value) + addrOffset] : luauRegValueInt(vmRegOp(inst.a)); + build.mov(valueLhs, inst.c.kind == IrOpKind::Constant ? OperandX64(intOp(inst.c)) : regOp(inst.c)); + } + else if (tagOp(inst.b) == LUA_TNUMBER) + { + OperandX64 valueLhs = + inst.a.kind == IrOpKind::Inst ? qword[regOp(inst.a) + offsetof(TValue, value) + addrOffset] : luauRegValue(vmRegOp(inst.a)); + + if (inst.c.kind == IrOpKind::Constant) + { + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + build.vmovsd(tmp.reg, build.f64(doubleOp(inst.c))); + build.vmovsd(valueLhs, tmp.reg); + } + else + { + build.vmovsd(valueLhs, regOp(inst.c)); + } + } + else if (isGCO(tagOp(inst.b))) + { + OperandX64 valueLhs = + inst.a.kind == IrOpKind::Inst ? qword[regOp(inst.a) + offsetof(TValue, value) + addrOffset] : luauRegValue(vmRegOp(inst.a)); + build.mov(valueLhs, regOp(inst.c)); + } + else + { + LUAU_ASSERT(!"Unsupported instruction form"); + } break; + } case IrCmd::ADD_INT: { inst.regX64 = regs.allocRegOrReuse(SizeX64::dword, index, {inst.a}); @@ -365,6 +408,22 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); } break; + case IrCmd::IDIV_NUM: + inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); + + if (inst.a.kind == IrOpKind::Constant) + { + ScopedRegX64 tmp{regs, SizeX64::xmmword}; + + build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); + build.vdivsd(inst.regX64, tmp.reg, memRegDoubleOp(inst.b)); + } + else + { + build.vdivsd(inst.regX64, regOp(inst.a), memRegDoubleOp(inst.b)); + } + build.vroundsd(inst.regX64, inst.regX64, inst.regX64, RoundingModeX64::RoundToNegativeInfinity); + break; case IrCmd::MOD_NUM: { inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b}); @@ -565,24 +624,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) break; } case IrCmd::JUMP: - if (inst.a.kind == IrOpKind::VmExit) - { - if (uint32_t* index = exitHandlerMap.find(inst.a.index)) - { - build.jmp(exitHandlers[*index].self); - } - else - { - Label self; - build.jmp(self); - exitHandlerMap[inst.a.index] = uint32_t(exitHandlers.size()); - exitHandlers.push_back({self, inst.a.index}); - } - } - else - { - jumpOrFallthrough(blockOp(inst.a), next); - } + jumpOrAbortOnUndef(inst.a, next); break; case IrCmd::JUMP_IF_TRUTHY: jumpIfTruthy(build, vmRegOp(inst.a), labelOp(inst.b), labelOp(inst.c)); @@ -597,7 +639,10 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) LUAU_ASSERT(inst.b.kind == IrOpKind::Inst || inst.b.kind == IrOpKind::Constant); OperandX64 opb = inst.b.kind == IrOpKind::Inst ? regOp(inst.b) : OperandX64(tagOp(inst.b)); - build.cmp(memRegTagOp(inst.a), opb); + if (inst.a.kind == IrOpKind::Constant) + build.cmp(opb, tagOp(inst.a)); + else + build.cmp(memRegTagOp(inst.a), opb); if (isFallthroughBlock(blockOp(inst.d), next)) { @@ -611,42 +656,36 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } break; } - case IrCmd::JUMP_EQ_INT: - if (intOp(inst.b) == 0) + case IrCmd::JUMP_CMP_INT: + { + IrCondition cond = conditionOp(inst.c); + + if ((cond == IrCondition::Equal || cond == IrCondition::NotEqual) && intOp(inst.b) == 0) { + bool invert = cond == IrCondition::NotEqual; + build.test(regOp(inst.a), regOp(inst.a)); - if (isFallthroughBlock(blockOp(inst.c), next)) + if (isFallthroughBlock(blockOp(inst.d), next)) { - build.jcc(ConditionX64::NotZero, labelOp(inst.d)); - jumpOrFallthrough(blockOp(inst.c), next); + build.jcc(invert ? ConditionX64::Zero : ConditionX64::NotZero, labelOp(inst.e)); + jumpOrFallthrough(blockOp(inst.d), next); } else { - build.jcc(ConditionX64::Zero, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + build.jcc(invert ? ConditionX64::NotZero : ConditionX64::Zero, labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); } } else { build.cmp(regOp(inst.a), intOp(inst.b)); - build.jcc(ConditionX64::Equal, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); + build.jcc(getConditionInt(cond), labelOp(inst.d)); + jumpOrFallthrough(blockOp(inst.e), next); } break; - case IrCmd::JUMP_LT_INT: - build.cmp(regOp(inst.a), intOp(inst.b)); - - build.jcc(ConditionX64::Less, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); - break; - case IrCmd::JUMP_GE_UINT: - build.cmp(regOp(inst.a), unsigned(intOp(inst.b))); - - build.jcc(ConditionX64::AboveEqual, labelOp(inst.c)); - jumpOrFallthrough(blockOp(inst.d), next); - break; + } case IrCmd::JUMP_EQ_POINTER: build.cmp(regOp(inst.a), regOp(inst.b)); @@ -659,7 +698,6 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) ScopedRegX64 tmp{regs, SizeX64::xmmword}; - // TODO: jumpOnNumberCmp should work on IrCondition directly jumpOnNumberCmp(build, tmp.reg, memRegDoubleOp(inst.a), memRegDoubleOp(inst.b), cond, labelOp(inst.d)); jumpOrFallthrough(blockOp(inst.e), next); break; @@ -669,9 +707,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) IrCallWrapperX64 callWrap(regs, build, index); callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_getn)]); - - inst.regX64 = regs.allocReg(SizeX64::xmmword, index); - build.vcvtsi2sd(inst.regX64, inst.regX64, eax); + inst.regX64 = regs.takeReg(eax, index); + break; + } + case IrCmd::TABLE_SETNUM: + { + IrCallWrapperX64 callWrap(regs, build, index); + callWrap.addArgument(SizeX64::qword, rState); + callWrap.addArgument(SizeX64::qword, regOp(inst.a), inst.a); + callWrap.addArgument(SizeX64::dword, regOp(inst.b), inst.b); + callWrap.call(qword[rNativeContext + offsetof(NativeContext, luaH_setnum)]); + inst.regX64 = regs.takeReg(rax, index); break; } case IrCmd::STRING_LEN: @@ -968,19 +1014,14 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) tmp1.free(); - callBarrierObject(regs, build, tmp2.release(), {}, vmRegOp(inst.b), /* ratag */ -1); + if (inst.c.kind == IrOpKind::Undef || isGCO(tagOp(inst.c))) + callBarrierObject(regs, build, tmp2.release(), {}, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c)); break; } - case IrCmd::PREPARE_FORN: - callPrepareForN(regs, build, vmRegOp(inst.a), vmRegOp(inst.b), vmRegOp(inst.c)); - break; case IrCmd::CHECK_TAG: - { - bool continueInVm = (inst.d.kind == IrOpKind::Constant && intOp(inst.d)); build.cmp(memRegTagOp(inst.a), tagOp(inst.b)); - jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.c, continueInVm); + jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.c, next); break; - } case IrCmd::CHECK_TRUTHY: { // Constant tags which don't require boolean value check should've been removed in constant folding @@ -992,7 +1033,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { // Fail to fallback on 'nil' (falsy) build.cmp(memRegTagOp(inst.a), LUA_TNIL); - jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.c); + jumpOrAbortOnUndef(ConditionX64::Equal, inst.c, next); // Skip value test if it's not a boolean (truthy) build.cmp(memRegTagOp(inst.a), LUA_TBOOLEAN); @@ -1001,7 +1042,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // fail to fallback on 'false' boolean value (falsy) build.cmp(memRegUintOp(inst.b), 0); - jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.c); + jumpOrAbortOnUndef(ConditionX64::Equal, inst.c, next); if (inst.a.kind != IrOpKind::Constant) build.setLabel(skip); @@ -1009,11 +1050,11 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) } case IrCmd::CHECK_READONLY: build.cmp(byte[regOp(inst.a) + offsetof(Table, readonly)], 0); - jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); + jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next); break; case IrCmd::CHECK_NO_METATABLE: build.cmp(qword[regOp(inst.a) + offsetof(Table, metatable)], 0); - jumpOrAbortOnUndef(ConditionX64::NotEqual, ConditionX64::Equal, inst.b); + jumpOrAbortOnUndef(ConditionX64::NotEqual, inst.b, next); break; case IrCmd::CHECK_SAFE_ENV: { @@ -1023,7 +1064,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(tmp.reg, qword[tmp.reg + offsetof(Closure, env)]); build.cmp(byte[tmp.reg + offsetof(Table, safeenv)], 0); - jumpOrAbortOnUndef(ConditionX64::Equal, ConditionX64::NotEqual, inst.a); + jumpOrAbortOnUndef(ConditionX64::Equal, inst.a, next); break; } case IrCmd::CHECK_ARRAY_SIZE: @@ -1034,7 +1075,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) else LUAU_ASSERT(!"Unsupported instruction form"); - jumpOrAbortOnUndef(ConditionX64::BelowEqual, ConditionX64::NotBelowEqual, inst.c); + jumpOrAbortOnUndef(ConditionX64::BelowEqual, inst.c, next); break; case IrCmd::JUMP_SLOT_MATCH: case IrCmd::CHECK_SLOT_MATCH: @@ -1080,7 +1121,13 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) build.mov(tmp.reg, dword[regOp(inst.a) + offsetof(LuaNode, key) + kOffsetOfTKeyTagNext]); build.shr(tmp.reg, kTKeyTagBits); - jumpOrAbortOnUndef(ConditionX64::NotZero, ConditionX64::Zero, inst.b); + jumpOrAbortOnUndef(ConditionX64::NotZero, inst.b, next); + break; + } + case IrCmd::CHECK_NODE_VALUE: + { + build.cmp(dword[regOp(inst.a) + offsetof(LuaNode, val) + offsetof(TValue, tt)], LUA_TNIL); + jumpOrAbortOnUndef(ConditionX64::Equal, inst.b, next); break; } case IrCmd::INTERRUPT: @@ -1109,7 +1156,7 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) callStepGc(regs, build); break; case IrCmd::BARRIER_OBJ: - callBarrierObject(regs, build, regOp(inst.a), inst.a, vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c)); + callBarrierObject(regs, build, regOp(inst.a), inst.a, inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c)); break; case IrCmd::BARRIER_TABLE_BACK: callBarrierTableFast(regs, build, regOp(inst.a), inst.a); @@ -1119,7 +1166,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) Label skip; ScopedRegX64 tmp{regs, SizeX64::qword}; - checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), vmRegOp(inst.b), inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); + + checkObjectBarrierConditions(build, tmp.reg, regOp(inst.a), inst.b, inst.c.kind == IrOpKind::Undef ? -1 : tagOp(inst.c), skip); { ScopedSpills spillGuard(regs); @@ -1182,7 +1230,8 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) // Fallbacks to non-IR instruction implementations case IrCmd::SETLIST: regs.assertAllFree(); - emitInstSetList(regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e)); + emitInstSetList( + regs, build, vmRegOp(inst.b), vmRegOp(inst.c), intOp(inst.d), uintOp(inst.e), inst.f.kind == IrOpKind::Undef ? -1 : int(uintOp(inst.f))); break; case IrCmd::CALL: regs.assertAllFree(); @@ -1560,9 +1609,17 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) regs.freeLastUseRegs(inst, index); } -void IrLoweringX64::finishBlock() +void IrLoweringX64::finishBlock(const IrBlock& curr, const IrBlock& next) { - regs.assertNoSpills(); + if (!regs.spills.empty()) + { + // If we have spills remaining, we have to immediately lower the successor block + for (uint32_t predIdx : predecessors(function.cfg, function.getBlockIndex(next))) + LUAU_ASSERT(predIdx == function.getBlockIndex(curr)); + + // And the next block cannot be a join block in cfg + LUAU_ASSERT(next.useCount == 1); + } } void IrLoweringX64::finishFunction() @@ -1583,10 +1640,22 @@ void IrLoweringX64::finishFunction() for (ExitHandler& handler : exitHandlers) { + LUAU_ASSERT(handler.pcpos != kVmExitEntryGuardPc); + build.setLabel(handler.self); + build.mov(edx, handler.pcpos * sizeof(Instruction)); build.jmp(helpers.updatePcAndContinueInVm); } + + if (stats) + { + if (regs.maxUsedSlot > kSpillSlots) + stats->regAllocErrors++; + + if (regs.maxUsedSlot > stats->maxSpillSlotsUsed) + stats->maxSpillSlotsUsed = regs.maxUsedSlot; + } } bool IrLoweringX64::hasError() const @@ -1598,50 +1667,81 @@ bool IrLoweringX64::hasError() const return false; } -bool IrLoweringX64::isFallthroughBlock(IrBlock target, IrBlock next) +bool IrLoweringX64::isFallthroughBlock(const IrBlock& target, const IrBlock& next) { return target.start == next.start; } -void IrLoweringX64::jumpOrFallthrough(IrBlock& target, IrBlock& next) +Label& IrLoweringX64::getTargetLabel(IrOp op, Label& fresh) +{ + if (op.kind == IrOpKind::Undef) + return fresh; + + if (op.kind == IrOpKind::VmExit) + { + // Special exit case that doesn't have to update pcpos + if (vmExitOp(op) == kVmExitEntryGuardPc) + return helpers.exitContinueVmClearNativeFlag; + + if (uint32_t* index = exitHandlerMap.find(vmExitOp(op))) + return exitHandlers[*index].self; + + return fresh; + } + + return labelOp(op); +} + +void IrLoweringX64::finalizeTargetLabel(IrOp op, Label& fresh) +{ + if (op.kind == IrOpKind::VmExit && fresh.id != 0 && fresh.id != helpers.exitContinueVmClearNativeFlag.id) + { + exitHandlerMap[vmExitOp(op)] = uint32_t(exitHandlers.size()); + exitHandlers.push_back({fresh, vmExitOp(op)}); + } +} + +void IrLoweringX64::jumpOrFallthrough(IrBlock& target, const IrBlock& next) { if (!isFallthroughBlock(target, next)) build.jmp(target.label); } -void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm) +void IrLoweringX64::jumpOrAbortOnUndef(ConditionX64 cond, IrOp target, const IrBlock& next) { - if (targetOrUndef.kind == IrOpKind::Undef) - { - if (continueInVm) - { - build.jcc(cond, helpers.exitContinueVmClearNativeFlag); - return; - } + Label fresh; + Label& label = getTargetLabel(target, fresh); - Label skip; - build.jcc(condInverse, skip); - build.ud2(); - build.setLabel(skip); - } - else if (targetOrUndef.kind == IrOpKind::VmExit) + if (target.kind == IrOpKind::Undef) { - if (uint32_t* index = exitHandlerMap.find(targetOrUndef.index)) + if (cond == ConditionX64::Count) { - build.jcc(cond, exitHandlers[*index].self); + build.ud2(); // Unconditional jump to abort is just an abort } else { - Label self; - build.jcc(cond, self); - exitHandlerMap[targetOrUndef.index] = uint32_t(exitHandlers.size()); - exitHandlers.push_back({self, targetOrUndef.index}); + build.jcc(getReverseCondition(cond), label); + build.ud2(); + build.setLabel(label); } } + else if (cond == ConditionX64::Count) + { + // Unconditional jump can be skipped if it's a fallthrough + if (target.kind == IrOpKind::VmExit || !isFallthroughBlock(blockOp(target), next)) + build.jmp(label); + } else { - build.jcc(cond, labelOp(targetOrUndef)); + build.jcc(cond, label); } + + finalizeTargetLabel(target, fresh); +} + +void IrLoweringX64::jumpOrAbortOnUndef(IrOp target, const IrBlock& next) +{ + jumpOrAbortOnUndef(ConditionX64::Count, target, next); } OperandX64 IrLoweringX64::memRegDoubleOp(IrOp op) diff --git a/CodeGen/src/IrLoweringX64.h b/CodeGen/src/IrLoweringX64.h index a8dab3c9..920ad002 100644 --- a/CodeGen/src/IrLoweringX64.h +++ b/CodeGen/src/IrLoweringX64.h @@ -19,23 +19,29 @@ namespace CodeGen struct ModuleHelpers; struct AssemblyOptions; +struct LoweringStats; namespace X64 { struct IrLoweringX64 { - IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function); + IrLoweringX64(AssemblyBuilderX64& build, ModuleHelpers& helpers, IrFunction& function, LoweringStats* stats); - void lowerInst(IrInst& inst, uint32_t index, IrBlock& next); - void finishBlock(); + void lowerInst(IrInst& inst, uint32_t index, const IrBlock& next); + void finishBlock(const IrBlock& curr, const IrBlock& next); void finishFunction(); bool hasError() const; - bool isFallthroughBlock(IrBlock target, IrBlock next); - void jumpOrFallthrough(IrBlock& target, IrBlock& next); - void jumpOrAbortOnUndef(ConditionX64 cond, ConditionX64 condInverse, IrOp targetOrUndef, bool continueInVm = false); + bool isFallthroughBlock(const IrBlock& target, const IrBlock& next); + void jumpOrFallthrough(IrBlock& target, const IrBlock& next); + + Label& getTargetLabel(IrOp op, Label& fresh); + void finalizeTargetLabel(IrOp op, Label& fresh); + + void jumpOrAbortOnUndef(ConditionX64 cond, IrOp target, const IrBlock& next); + void jumpOrAbortOnUndef(IrOp target, const IrBlock& next); void storeDoubleAsFloat(OperandX64 dst, IrOp src); @@ -71,6 +77,7 @@ struct IrLoweringX64 ModuleHelpers& helpers; IrFunction& function; + LoweringStats* stats = nullptr; IrRegAllocX64 regs; diff --git a/CodeGen/src/IrRegAllocA64.cpp b/CodeGen/src/IrRegAllocA64.cpp index 02d7df98..d2b7ff14 100644 --- a/CodeGen/src/IrRegAllocA64.cpp +++ b/CodeGen/src/IrRegAllocA64.cpp @@ -2,6 +2,7 @@ #include "IrRegAllocA64.h" #include "Luau/AssemblyBuilderA64.h" +#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" #include "BitUtils.h" @@ -70,9 +71,9 @@ static int getReloadOffset(IrCmd cmd) LUAU_UNREACHABLE(); } -static AddressA64 getReloadAddress(const IrFunction& function, const IrInst& inst) +static AddressA64 getReloadAddress(const IrFunction& function, const IrInst& inst, bool limitToCurrentBlock) { - IrOp location = function.findRestoreOp(inst); + IrOp location = function.findRestoreOp(inst, limitToCurrentBlock); if (location.kind == IrOpKind::VmReg) return mem(rBase, vmRegOp(location) * sizeof(TValue) + getReloadOffset(inst.cmd)); @@ -99,7 +100,7 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF else { LUAU_ASSERT(!inst.spilled && inst.needsReload); - AddressA64 addr = getReloadAddress(function, function.instructions[s.inst]); + AddressA64 addr = getReloadAddress(function, function.instructions[s.inst], /*limitToCurrentBlock*/ false); LUAU_ASSERT(addr.base != xzr); build.ldr(reg, addr); } @@ -109,8 +110,9 @@ static void restoreInst(AssemblyBuilderA64& build, uint32_t& freeSpillSlots, IrF inst.regA64 = reg; } -IrRegAllocA64::IrRegAllocA64(IrFunction& function, std::initializer_list> regs) +IrRegAllocA64::IrRegAllocA64(IrFunction& function, LoweringStats* stats, std::initializer_list> regs) : function(function) + , stats(stats) { for (auto& p : regs) { @@ -321,7 +323,7 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init { // instead of spilling the register to never reload it, we assume the register is not needed anymore } - else if (getReloadAddress(function, def).base != xzr) + else if (getReloadAddress(function, def, /*limitToCurrentBlock*/ true).base != xzr) { // instead of spilling the register to stack, we can reload it from VM stack/constants // we still need to record the spill for restore(start) to work @@ -329,6 +331,9 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init spills.push_back(s); def.needsReload = true; + + if (stats) + stats->spillsToRestore++; } else { @@ -345,6 +350,14 @@ size_t IrRegAllocA64::spill(AssemblyBuilderA64& build, uint32_t index, std::init spills.push_back(s); def.spilled = true; + + if (stats) + { + stats->spillsToSlot++; + + if (slot != kInvalidSpill && unsigned(slot + 1) > stats->maxSpillSlotsUsed) + stats->maxSpillSlotsUsed = slot + 1; + } } def.regA64 = noreg; @@ -411,11 +424,6 @@ void IrRegAllocA64::restoreReg(AssemblyBuilderA64& build, IrInst& inst) LUAU_ASSERT(!"Expected to find a spill record"); } -void IrRegAllocA64::assertNoSpills() const -{ - LUAU_ASSERT(spills.empty()); -} - IrRegAllocA64::Set& IrRegAllocA64::getSet(KindA64 kind) { switch (kind) diff --git a/CodeGen/src/IrRegAllocA64.h b/CodeGen/src/IrRegAllocA64.h index 854a9f10..d16fa19f 100644 --- a/CodeGen/src/IrRegAllocA64.h +++ b/CodeGen/src/IrRegAllocA64.h @@ -12,6 +12,9 @@ namespace Luau { namespace CodeGen { + +struct LoweringStats; + namespace A64 { @@ -19,7 +22,7 @@ class AssemblyBuilderA64; struct IrRegAllocA64 { - IrRegAllocA64(IrFunction& function, std::initializer_list> regs); + IrRegAllocA64(IrFunction& function, LoweringStats* stats, std::initializer_list> regs); RegisterA64 allocReg(KindA64 kind, uint32_t index); RegisterA64 allocTemp(KindA64 kind); @@ -43,8 +46,6 @@ struct IrRegAllocA64 // Restores register for a single instruction; may not assign the previously used register! void restoreReg(AssemblyBuilderA64& build, IrInst& inst); - void assertNoSpills() const; - struct Set { // which registers are in the set that the allocator manages (initialized at construction) @@ -71,6 +72,7 @@ struct IrRegAllocA64 Set& getSet(KindA64 kind); IrFunction& function; + LoweringStats* stats = nullptr; Set gpr, simd; std::vector spills; diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index b81aec8c..81cf2f4c 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/IrRegAllocX64.h" +#include "Luau/CodeGen.h" #include "Luau/IrUtils.h" #include "EmitCommonX64.h" @@ -14,9 +15,11 @@ namespace X64 static const RegisterX64 kGprAllocOrder[] = {rax, rdx, rcx, rbx, rsi, rdi, r8, r9, r10, r11}; -IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function) +IrRegAllocX64::IrRegAllocX64(AssemblyBuilderX64& build, IrFunction& function, LoweringStats* stats) : build(build) , function(function) + , stats(stats) + , usableXmmRegCount(getXmmRegisterCount(build.abi)) { freeGprMap.fill(true); gprInstUsers.fill(kInvalidInstIdx); @@ -28,7 +31,7 @@ RegisterX64 IrRegAllocX64::allocReg(SizeX64 size, uint32_t instIdx) { if (size == SizeX64::xmmword) { - for (size_t i = 0; i < freeXmmMap.size(); ++i) + for (size_t i = 0; i < usableXmmRegCount; ++i) { if (freeXmmMap[i]) { @@ -54,7 +57,12 @@ RegisterX64 IrRegAllocX64::allocReg(SizeX64 size, uint32_t instIdx) // Out of registers, spill the value with the furthest next use const std::array& regInstUsers = size == SizeX64::xmmword ? xmmInstUsers : gprInstUsers; if (uint32_t furthestUseTarget = findInstructionWithFurthestNextUse(regInstUsers); furthestUseTarget != kInvalidInstIdx) - return takeReg(function.instructions[furthestUseTarget].regX64, instIdx); + { + RegisterX64 reg = function.instructions[furthestUseTarget].regX64; + reg.size = size; // Adjust size to the requested + + return takeReg(reg, instIdx); + } LUAU_ASSERT(!"Out of registers to allocate"); return noreg; @@ -219,10 +227,16 @@ void IrRegAllocX64::preserve(IrInst& inst) spill.stackSlot = uint8_t(i); inst.spilled = true; + + if (stats) + stats->spillsToSlot++; } else { inst.needsReload = true; + + if (stats) + stats->spillsToRestore++; } spills.push_back(spill); @@ -332,7 +346,9 @@ unsigned IrRegAllocX64::findSpillStackSlot(IrValueKind valueKind) IrOp IrRegAllocX64::getRestoreOp(const IrInst& inst) const { - if (IrOp location = function.findRestoreOp(inst); location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst) + // When restoring the value, we allow cross-block restore because we have commited to the target location at spill time + if (IrOp location = function.findRestoreOp(inst, /*limitToCurrentBlock*/ false); + location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst) return location; return IrOp(); @@ -340,11 +356,16 @@ IrOp IrRegAllocX64::getRestoreOp(const IrInst& inst) const bool IrRegAllocX64::hasRestoreOp(const IrInst& inst) const { - return getRestoreOp(inst).kind != IrOpKind::None; + // When checking if value has a restore operation to spill it, we only allow it in the same block + IrOp location = function.findRestoreOp(inst, /*limitToCurrentBlock*/ true); + + return location.kind == IrOpKind::VmReg || location.kind == IrOpKind::VmConst; } OperandX64 IrRegAllocX64::getRestoreAddress(const IrInst& inst, IrOp restoreOp) { + LUAU_ASSERT(restoreOp.kind != IrOpKind::None); + switch (getCmdValueKind(inst.cmd)) { case IrValueKind::Unknown: diff --git a/CodeGen/src/IrTranslateBuiltins.cpp b/CodeGen/src/IrTranslateBuiltins.cpp index 8392ad84..279eabea 100644 --- a/CodeGen/src/IrTranslateBuiltins.cpp +++ b/CodeGen/src/IrTranslateBuiltins.cpp @@ -411,7 +411,7 @@ static BuiltinImplResult translateBuiltinBit32BinaryOp( IrOp falsey = build.block(IrBlockKind::Internal); IrOp truthy = build.block(IrBlockKind::Internal); IrOp exit = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_EQ_INT, res, build.constInt(0), falsey, truthy); + build.inst(IrCmd::JUMP_CMP_INT, res, build.constInt(0), build.cond(IrCondition::Equal), falsey, truthy); build.beginBlock(falsey); build.inst(IrCmd::STORE_INT, build.vmReg(ra), build.constInt(0)); @@ -484,7 +484,7 @@ static BuiltinImplResult translateBuiltinBit32Shift( if (!knownGoodShift) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, vbi, build.constInt(32), fallback, block); + build.inst(IrCmd::JUMP_CMP_INT, vbi, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); build.beginBlock(block); } @@ -549,36 +549,56 @@ static BuiltinImplResult translateBuiltinBit32Extract( IrOp vb = builtinLoadDouble(build, args); IrOp n = build.inst(IrCmd::NUM_TO_UINT, va); - IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); IrOp value; if (nparams == 2) { - IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); - build.beginBlock(block); + if (vb.kind == IrOpKind::Constant) + { + int f = int(build.function.doubleOp(vb)); - // TODO: this can be optimized using a bit-select instruction (bt on x86) - IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); - value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); + if (unsigned(f) >= 32) + build.inst(IrCmd::JUMP, fallback); + + // TODO: this pair can be optimized using a bit-select instruction (bt on x86) + if (f) + value = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); + + if ((f + 1) < 32) + value = build.inst(IrCmd::BITAND_UINT, value, build.constInt(1)); + } + else + { + IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); + + IrOp block = build.block(IrBlockKind::Internal); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); + build.beginBlock(block); + + // TODO: this pair can be optimized using a bit-select instruction (bt on x86) + IrOp shift = build.inst(IrCmd::BITRSHIFT_UINT, n, f); + value = build.inst(IrCmd::BITAND_UINT, shift, build.constInt(1)); + } } else { + IrOp f = build.inst(IrCmd::NUM_TO_INT, vb); + builtinCheckDouble(build, build.vmReg(args.index + 1), pcpos); IrOp vc = builtinLoadDouble(build, build.vmReg(args.index + 1)); IrOp w = build.inst(IrCmd::NUM_TO_INT, vc); IrOp block1 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1); build.beginBlock(block1); IrOp block2 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2); + build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2); build.beginBlock(block2); IrOp block3 = build.block(IrBlockKind::Internal); IrOp fw = build.inst(IrCmd::ADD_INT, f, w); - build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); + build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback); build.beginBlock(block3); IrOp shift = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); @@ -615,10 +635,15 @@ static BuiltinImplResult translateBuiltinBit32ExtractK( uint32_t m = ~(0xfffffffeu << w1); - IrOp nf = build.inst(IrCmd::BITRSHIFT_UINT, n, build.constInt(f)); - IrOp and_ = build.inst(IrCmd::BITAND_UINT, nf, build.constInt(m)); + IrOp result = n; - IrOp value = build.inst(IrCmd::UINT_TO_NUM, and_); + if (f) + result = build.inst(IrCmd::BITRSHIFT_UINT, result, build.constInt(f)); + + if ((f + w1 + 1) < 32) + result = build.inst(IrCmd::BITAND_UINT, result, build.constInt(m)); + + IrOp value = build.inst(IrCmd::UINT_TO_NUM, result); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), value); if (ra != arg) @@ -673,7 +698,7 @@ static BuiltinImplResult translateBuiltinBit32Replace( if (nparams == 3) { IrOp block = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_GE_UINT, f, build.constInt(32), fallback, block); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(32), build.cond(IrCondition::UnsignedGreaterEqual), fallback, block); build.beginBlock(block); // TODO: this can be optimized using a bit-select instruction (btr on x86) @@ -694,16 +719,16 @@ static BuiltinImplResult translateBuiltinBit32Replace( IrOp w = build.inst(IrCmd::NUM_TO_INT, vd); IrOp block1 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, f, build.constInt(0), fallback, block1); + build.inst(IrCmd::JUMP_CMP_INT, f, build.constInt(0), build.cond(IrCondition::Less), fallback, block1); build.beginBlock(block1); IrOp block2 = build.block(IrBlockKind::Internal); - build.inst(IrCmd::JUMP_LT_INT, w, build.constInt(1), fallback, block2); + build.inst(IrCmd::JUMP_CMP_INT, w, build.constInt(1), build.cond(IrCondition::Less), fallback, block2); build.beginBlock(block2); IrOp block3 = build.block(IrBlockKind::Internal); IrOp fw = build.inst(IrCmd::ADD_INT, f, w); - build.inst(IrCmd::JUMP_LT_INT, fw, build.constInt(33), block3, fallback); + build.inst(IrCmd::JUMP_CMP_INT, fw, build.constInt(33), build.cond(IrCondition::Less), block3, fallback); build.beginBlock(block3); IrOp shift1 = build.inst(IrCmd::BITLSHIFT_UINT, build.constInt(0xfffffffe), build.inst(IrCmd::SUB_INT, w, build.constInt(1))); @@ -748,6 +773,28 @@ static BuiltinImplResult translateBuiltinVector(IrBuilder& build, int nparams, i return {BuiltinImplType::Full, 1}; } +static BuiltinImplResult translateBuiltinTableInsert(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) +{ + if (nparams != 2 || nresults > 0) + return {BuiltinImplType::None, -1}; + + build.loadAndCheckTag(build.vmReg(arg), LUA_TTABLE, build.vmExit(pcpos)); + + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(arg)); + build.inst(IrCmd::CHECK_READONLY, table, build.vmExit(pcpos)); + + IrOp pos = build.inst(IrCmd::ADD_INT, build.inst(IrCmd::TABLE_LEN, table), build.constInt(1)); + + IrOp setnum = build.inst(IrCmd::TABLE_SETNUM, table, pos); + + IrOp va = build.inst(IrCmd::LOAD_TVALUE, args); + build.inst(IrCmd::STORE_TVALUE, setnum, va); + + build.inst(IrCmd::BARRIER_TABLE_FORWARD, table, args, build.undef()); + + return {BuiltinImplType::Full, 0}; +} + static BuiltinImplResult translateBuiltinStringLen(IrBuilder& build, int nparams, int ra, int arg, IrOp args, int nresults, int pcpos) { if (nparams < 1 || nresults > 1) @@ -849,6 +896,8 @@ BuiltinImplResult translateBuiltin(IrBuilder& build, int bfid, int ra, int arg, return translateBuiltinTypeof(build, nparams, ra, arg, args, nresults); case LBF_VECTOR: return translateBuiltinVector(build, nparams, ra, arg, args, nresults, pcpos); + case LBF_TABLE_INSERT: + return translateBuiltinTableInsert(build, nparams, ra, arg, args, nresults, pcpos); case LBF_STRING_LEN: return translateBuiltinStringLen(build, nparams, ra, arg, args, nresults, pcpos); default: diff --git a/CodeGen/src/IrTranslation.cpp b/CodeGen/src/IrTranslation.cpp index 363a1cdb..13256789 100644 --- a/CodeGen/src/IrTranslation.cpp +++ b/CodeGen/src/IrTranslation.cpp @@ -2,6 +2,7 @@ #include "IrTranslation.h" #include "Luau/Bytecode.h" +#include "Luau/BytecodeUtils.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" @@ -11,6 +12,8 @@ #include "lstate.h" #include "ltm.h" +LUAU_FASTFLAGVARIABLE(LuauImproveForN, false) + namespace Luau { namespace CodeGen @@ -169,7 +172,7 @@ void translateInstJumpIfEq(IrBuilder& build, const Instruction* pc, int pcpos, b build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(IrCondition::Equal)); - build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), not_ ? target : next, not_ ? next : target); + build.inst(IrCmd::JUMP_CMP_INT, result, build.constInt(0), build.cond(IrCondition::Equal), not_ ? target : next, not_ ? next : target); build.beginBlock(next); } @@ -217,7 +220,7 @@ void translateInstJumpIfCond(IrBuilder& build, const Instruction* pc, int pcpos, } IrOp result = build.inst(IrCmd::CMP_ANY, build.vmReg(ra), build.vmReg(rb), build.cond(cond)); - build.inst(IrCmd::JUMP_EQ_INT, result, build.constInt(0), reverse ? target : next, reverse ? next : target); + build.inst(IrCmd::JUMP_CMP_INT, result, build.constInt(0), build.cond(IrCondition::Equal), reverse ? target : next, reverse ? next : target); build.beginBlock(next); } @@ -261,7 +264,7 @@ void translateInstJumpxEqB(IrBuilder& build, const Instruction* pc, int pcpos) build.beginBlock(checkValue); IrOp va = build.inst(IrCmd::LOAD_INT, build.vmReg(ra)); - build.inst(IrCmd::JUMP_EQ_INT, va, build.constInt(aux & 0x1), not_ ? next : target, not_ ? target : next); + build.inst(IrCmd::JUMP_CMP_INT, va, build.constInt(aux & 0x1), build.cond(IrCondition::Equal), not_ ? next : target, not_ ? target : next); // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(next)) @@ -381,6 +384,9 @@ static void translateInstBinaryNumeric(IrBuilder& build, int ra, int rb, int rc, case TM_DIV: result = build.inst(IrCmd::DIV_NUM, vb, vc); break; + case TM_IDIV: + result = build.inst(IrCmd::IDIV_NUM, vb, vc); + break; case TM_MOD: result = build.inst(IrCmd::MOD_NUM, vb, vc); break; @@ -471,8 +477,9 @@ void translateInstLength(IrBuilder& build, const Instruction* pc, int pcpos) build.inst(IrCmd::CHECK_NO_METATABLE, vb, fallback); IrOp va = build.inst(IrCmd::TABLE_LEN, vb); + IrOp vai = build.inst(IrCmd::INT_TO_NUM, va); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), va); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), vai); build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER)); IrOp next = build.blockAtInst(pcpos + 1); @@ -526,7 +533,7 @@ void translateInstSetUpval(IrBuilder& build, const Instruction* pc, int pcpos) int ra = LUAU_INSN_A(*pc); int up = LUAU_INSN_B(*pc); - build.inst(IrCmd::SET_UPVALUE, build.vmUpvalue(up), build.vmReg(ra)); + build.inst(IrCmd::SET_UPVALUE, build.vmUpvalue(up), build.vmReg(ra), build.undef()); } void translateInstCloseUpvals(IrBuilder& build, const Instruction* pc) @@ -553,7 +560,7 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool IrOp builtinArgs = args; - if (customArgs.kind == IrOpKind::VmConst) + if (customArgs.kind == IrOpKind::VmConst && bfid != LBF_TABLE_INSERT) { TValue protok = build.function.proto->k[customArgs.index]; @@ -602,56 +609,131 @@ IrOp translateFastCallN(IrBuilder& build, const Instruction* pc, int pcpos, bool return fallback; } +// numeric for loop always ends with the computation of step that targets ra+1 +// any conditionals would result in a split basic block, so we can recover the step constants by pattern matching the IR we generated for LOADN/K +static IrOp getLoopStepK(IrBuilder& build, int ra) +{ + IrBlock& active = build.function.blocks[build.activeBlockIdx]; + + if (active.start + 2 < build.function.instructions.size()) + { + IrInst& sv = build.function.instructions[build.function.instructions.size() - 2]; + IrInst& st = build.function.instructions[build.function.instructions.size() - 1]; + + // We currently expect to match IR generated from LOADN/LOADK so we match a particular sequence of opcodes + // In the future this can be extended to cover opposite STORE order as well as STORE_SPLIT_TVALUE + if (sv.cmd == IrCmd::STORE_DOUBLE && sv.a.kind == IrOpKind::VmReg && sv.a.index == ra + 1 && sv.b.kind == IrOpKind::Constant && + st.cmd == IrCmd::STORE_TAG && st.a.kind == IrOpKind::VmReg && st.a.index == ra + 1 && build.function.tagOp(st.b) == LUA_TNUMBER) + return sv.b; + } + + return build.undef(); +} + void translateInstForNPrep(IrBuilder& build, const Instruction* pc, int pcpos) { int ra = LUAU_INSN_A(*pc); IrOp loopStart = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); IrOp loopExit = build.blockAtInst(getJumpTarget(*pc, pcpos)); - IrOp fallback = build.block(IrBlockKind::Fallback); - IrOp nextStep = build.block(IrBlockKind::Internal); - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); + if (FFlag::LuauImproveForN) + { + IrOp stepK = getLoopStepK(build, ra); + build.loopStepStack.push_back(stepK); - IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); - build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), fallback); - IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); - build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), fallback); - IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); - build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), fallback); - build.inst(IrCmd::JUMP, nextStep); + // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails + // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant + // To avoid that overhead for an extremely rare case (that doesn't even typecheck), we exit to VM to handle it + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - // After successful conversion of arguments to number in a fallback, we return here - build.beginBlock(nextStep); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - IrOp zero = build.constDouble(0.0); - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + if (stepK.kind == IrOpKind::Undef) + { + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); - // step <= 0 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); - // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation + IrOp zero = build.constDouble(0.0); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); - // step <= 0 is false, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); + // step > 0 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); - // step <= 0 is true, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx + // We invert the condition so that loopStart is the fallthrough (false) label - // Fallback will try to convert loop variables to numbers or throw an error - build.beginBlock(fallback); - build.inst(IrCmd::SET_SAVEDPC, build.constUint(pcpos + 1)); - build.inst(IrCmd::PREPARE_FORN, build.vmReg(ra + 0), build.vmReg(ra + 1), build.vmReg(ra + 2)); - build.inst(IrCmd::JUMP, nextStep); + // step > 0 is false, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + + // step > 0 is true, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + } + else + { + double stepN = build.function.doubleOp(stepK); + + // Condition to start the loop: step > 0 ? idx <= limit : limit <= idx + // We invert the condition so that loopStart is the fallthrough (false) label + if (stepN > 0) + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + else + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::NotLessEqual), loopExit, loopStart); + } + } + else + { + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + // When loop parameters are not numbers, VM tries to perform type coercion from string and raises an exception if that fails + // Performing that fallback in native code increases code size and complicates CFG, obscuring the values when they are constant + // To avoid that overhead for an extreemely rare case (that doesn't even typecheck), we exit to VM to handle it + IrOp tagLimit = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 0)); + build.inst(IrCmd::CHECK_TAG, tagLimit, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagStep = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 1)); + build.inst(IrCmd::CHECK_TAG, tagStep, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + IrOp tagIdx = build.inst(IrCmd::LOAD_TAG, build.vmReg(ra + 2)); + build.inst(IrCmd::CHECK_TAG, tagIdx, build.constTag(LUA_TNUMBER), build.vmExit(pcpos)); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // TODO: target branches can probably be arranged better, but we need tests for NaN behavior preservation + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopStart, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopStart, loopExit); + } // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopStart)) build.beginBlock(loopStart); + + // VM places interrupt in FORNLOOP, but that creates a likely spill point for short loops that use loop index as INTERRUPT always spills + // We place the interrupt at the beginning of the loop body instead; VM uses FORNLOOP because it doesn't want to waste an extra instruction. + // Because loop block may not have been started yet (as it's started when lowering the first instruction!), we need to defer INTERRUPT placement. + if (FFlag::LuauImproveForN) + build.interruptRequested = true; } void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) @@ -661,29 +743,76 @@ void translateInstForNLoop(IrBuilder& build, const Instruction* pc, int pcpos) IrOp loopRepeat = build.blockAtInst(getJumpTarget(*pc, pcpos)); IrOp loopExit = build.blockAtInst(pcpos + getOpLength(LuauOpcode(LUAU_INSN_OP(*pc)))); - build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + if (FFlag::LuauImproveForN) + { + LUAU_ASSERT(!build.loopStepStack.empty()); + IrOp stepK = build.loopStepStack.back(); + build.loopStepStack.pop_back(); - IrOp zero = build.constDouble(0.0); - IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); - IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = stepK.kind == IrOpKind::Undef ? build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)) : stepK; - IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); - idx = build.inst(IrCmd::ADD_NUM, idx, step); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); - IrOp direct = build.block(IrBlockKind::Internal); - IrOp reverse = build.block(IrBlockKind::Internal); + if (stepK.kind == IrOpKind::Undef) + { + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); - // step <= 0 - build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + // step > 0 + // note: equivalent to 0 < step, but lowers into one instruction on both X64 and A64 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::Greater), direct, reverse); - // step <= 0 is false, check idx <= limit - build.beginBlock(direct); - build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx - // step <= 0 is true, check limit <= idx - build.beginBlock(reverse); - build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + // step > 0 is false, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step > 0 is true, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + } + else + { + double stepN = build.function.doubleOp(stepK); + + // Condition to continue the loop: step > 0 ? idx <= limit : limit <= idx + if (stepN > 0) + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + else + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + } + } + else + { + build.inst(IrCmd::INTERRUPT, build.constUint(pcpos)); + + IrOp zero = build.constDouble(0.0); + IrOp limit = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 0)); + IrOp step = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 1)); + + IrOp idx = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(ra + 2)); + idx = build.inst(IrCmd::ADD_NUM, idx, step); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra + 2), idx); + + IrOp direct = build.block(IrBlockKind::Internal); + IrOp reverse = build.block(IrBlockKind::Internal); + + // step <= 0 + build.inst(IrCmd::JUMP_CMP_NUM, step, zero, build.cond(IrCondition::LessEqual), reverse, direct); + + // step <= 0 is false, check idx <= limit + build.beginBlock(direct); + build.inst(IrCmd::JUMP_CMP_NUM, idx, limit, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + + // step <= 0 is true, check limit <= idx + build.beginBlock(reverse); + build.inst(IrCmd::JUMP_CMP_NUM, limit, idx, build.cond(IrCondition::LessEqual), loopRepeat, loopExit); + } // Fallthrough in original bytecode is implicit, so we start next internal block here if (build.isInternalBlock(loopExit)) @@ -984,11 +1113,11 @@ void translateInstGetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); - IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, vb, build.constUint(pcpos)); + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, vb, build.constUint(pcpos), build.vmConst(aux)); build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); - IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl); + IrOp tvn = build.inst(IrCmd::LOAD_TVALUE, addrSlotEl, build.constInt(offsetof(LuaNode, val))); build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn); IrOp next = build.blockAtInst(pcpos + 2); @@ -1011,13 +1140,13 @@ void translateInstSetTableKS(IrBuilder& build, const Instruction* pc, int pcpos) IrOp vb = build.inst(IrCmd::LOAD_POINTER, build.vmReg(rb)); - IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, vb, build.constUint(pcpos)); + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, vb, build.constUint(pcpos), build.vmConst(aux)); build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); build.inst(IrCmd::CHECK_READONLY, vb, fallback); IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); - build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva); + build.inst(IrCmd::STORE_TVALUE, addrSlotEl, tva, build.constInt(offsetof(LuaNode, val))); build.inst(IrCmd::BARRIER_TABLE_FORWARD, vb, build.vmReg(ra), build.undef()); @@ -1036,11 +1165,11 @@ void translateInstGetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp fallback = build.block(IrBlockKind::Fallback); IrOp env = build.inst(IrCmd::LOAD_ENV); - IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, env, build.constUint(pcpos)); + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, env, build.constUint(pcpos), build.vmConst(aux)); build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); - IrOp tvn = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrSlotEl); + IrOp tvn = build.inst(IrCmd::LOAD_TVALUE, addrSlotEl, build.constInt(offsetof(LuaNode, val))); build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), tvn); IrOp next = build.blockAtInst(pcpos + 2); @@ -1058,13 +1187,13 @@ void translateInstSetGlobal(IrBuilder& build, const Instruction* pc, int pcpos) IrOp fallback = build.block(IrBlockKind::Fallback); IrOp env = build.inst(IrCmd::LOAD_ENV); - IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, env, build.constUint(pcpos)); + IrOp addrSlotEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, env, build.constUint(pcpos), build.vmConst(aux)); build.inst(IrCmd::CHECK_SLOT_MATCH, addrSlotEl, build.vmConst(aux), fallback); build.inst(IrCmd::CHECK_READONLY, env, fallback); IrOp tva = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(ra)); - build.inst(IrCmd::STORE_NODE_VALUE_TV, addrSlotEl, tva); + build.inst(IrCmd::STORE_TVALUE, addrSlotEl, tva, build.constInt(offsetof(LuaNode, val))); build.inst(IrCmd::BARRIER_TABLE_FORWARD, env, build.vmReg(ra), build.undef()); @@ -1136,7 +1265,7 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 1), table); build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TTABLE)); - IrOp nodeEl = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrNodeEl); + IrOp nodeEl = build.inst(IrCmd::LOAD_TVALUE, addrNodeEl, build.constInt(offsetof(LuaNode, val))); build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), nodeEl); build.inst(IrCmd::JUMP, next); @@ -1149,7 +1278,7 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.loadAndCheckTag(indexPtr, LUA_TTABLE, fallback); IrOp index = build.inst(IrCmd::LOAD_POINTER, indexPtr); - IrOp addrIndexNodeEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, index, build.constUint(pcpos)); + IrOp addrIndexNodeEl = build.inst(IrCmd::GET_SLOT_NODE_ADDR, index, build.constUint(pcpos), build.vmConst(aux)); build.inst(IrCmd::CHECK_SLOT_MATCH, addrIndexNodeEl, build.vmConst(aux), fallback); // TODO: original 'table' was clobbered by a call inside 'FASTGETTM' @@ -1158,7 +1287,7 @@ void translateInstNamecall(IrBuilder& build, const Instruction* pc, int pcpos) build.inst(IrCmd::STORE_POINTER, build.vmReg(ra + 1), table2); build.inst(IrCmd::STORE_TAG, build.vmReg(ra + 1), build.constTag(LUA_TTABLE)); - IrOp indexNodeEl = build.inst(IrCmd::LOAD_NODE_VALUE_TV, addrIndexNodeEl); + IrOp indexNodeEl = build.inst(IrCmd::LOAD_TVALUE, addrIndexNodeEl, build.constInt(offsetof(LuaNode, val))); build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), indexNodeEl); build.inst(IrCmd::JUMP, next); diff --git a/CodeGen/src/IrTranslation.h b/CodeGen/src/IrTranslation.h index aff18b30..3736b44d 100644 --- a/CodeGen/src/IrTranslation.h +++ b/CodeGen/src/IrTranslation.h @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Bytecode.h" - #include #include "ltm.h" @@ -67,38 +65,5 @@ void translateInstAndX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp void translateInstOrX(IrBuilder& build, const Instruction* pc, int pcpos, IrOp c); void translateInstNewClosure(IrBuilder& build, const Instruction* pc, int pcpos); -inline int getOpLength(LuauOpcode op) -{ - switch (op) - { - case LOP_GETGLOBAL: - case LOP_SETGLOBAL: - case LOP_GETIMPORT: - case LOP_GETTABLEKS: - case LOP_SETTABLEKS: - case LOP_NAMECALL: - case LOP_JUMPIFEQ: - case LOP_JUMPIFLE: - case LOP_JUMPIFLT: - case LOP_JUMPIFNOTEQ: - case LOP_JUMPIFNOTLE: - case LOP_JUMPIFNOTLT: - case LOP_NEWTABLE: - case LOP_SETLIST: - case LOP_FORGLOOP: - case LOP_LOADKX: - case LOP_FASTCALL2: - case LOP_FASTCALL2K: - case LOP_JUMPXEQKNIL: - case LOP_JUMPXEQKB: - case LOP_JUMPXEQKN: - case LOP_JUMPXEQKS: - return 2; - - default: - return 1; - } -} - } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index c2d3e1a8..687b0c2a 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -32,7 +32,6 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::LOAD_INT: return IrValueKind::Int; case IrCmd::LOAD_TVALUE: - case IrCmd::LOAD_NODE_VALUE_TV: return IrValueKind::Tvalue; case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: @@ -46,7 +45,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::STORE_INT: case IrCmd::STORE_VECTOR: case IrCmd::STORE_TVALUE: - case IrCmd::STORE_NODE_VALUE_TV: + case IrCmd::STORE_SPLIT_TVALUE: return IrValueKind::None; case IrCmd::ADD_INT: case IrCmd::SUB_INT: @@ -55,6 +54,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: + case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: @@ -72,15 +72,15 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::JUMP_IF_TRUTHY: case IrCmd::JUMP_IF_FALSY: case IrCmd::JUMP_EQ_TAG: - case IrCmd::JUMP_EQ_INT: - case IrCmd::JUMP_LT_INT: - case IrCmd::JUMP_GE_UINT: + case IrCmd::JUMP_CMP_INT: case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_CMP_NUM: case IrCmd::JUMP_SLOT_MATCH: return IrValueKind::None; case IrCmd::TABLE_LEN: - return IrValueKind::Double; + return IrValueKind::Int; + case IrCmd::TABLE_SETNUM: + return IrValueKind::Pointer; case IrCmd::STRING_LEN: return IrValueKind::Int; case IrCmd::NEW_TABLE: @@ -112,7 +112,6 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CONCAT: case IrCmd::GET_UPVALUE: case IrCmd::SET_UPVALUE: - case IrCmd::PREPARE_FORN: case IrCmd::CHECK_TAG: case IrCmd::CHECK_TRUTHY: case IrCmd::CHECK_READONLY: @@ -121,6 +120,7 @@ IrValueKind getCmdValueKind(IrCmd cmd) case IrCmd::CHECK_ARRAY_SIZE: case IrCmd::CHECK_SLOT_MATCH: case IrCmd::CHECK_NODE_NO_NEXT: + case IrCmd::CHECK_NODE_VALUE: case IrCmd::INTERRUPT: case IrCmd::CHECK_GC: case IrCmd::BARRIER_OBJ: @@ -216,6 +216,8 @@ void removeUse(IrFunction& function, IrOp op) bool isGCO(uint8_t tag) { + LUAU_ASSERT(tag < LUA_T_COUNT); + // mirrors iscollectable(o) from VM/lobject.h return tag >= LUA_TSTRING; } @@ -387,6 +389,38 @@ void applySubstitutions(IrFunction& function, IrInst& inst) } bool compare(double a, double b, IrCondition cond) +{ + // Note: redundant bool() casts work around invalid MSVC optimization that merges cases in this switch, violating IEEE754 comparison semantics + switch (cond) + { + case IrCondition::Equal: + return a == b; + case IrCondition::NotEqual: + return a != b; + case IrCondition::Less: + return a < b; + case IrCondition::NotLess: + return !bool(a < b); + case IrCondition::LessEqual: + return a <= b; + case IrCondition::NotLessEqual: + return !bool(a <= b); + case IrCondition::Greater: + return a > b; + case IrCondition::NotGreater: + return !bool(a > b); + case IrCondition::GreaterEqual: + return a >= b; + case IrCondition::NotGreaterEqual: + return !bool(a >= b); + default: + LUAU_ASSERT(!"Unsupported condition"); + } + + return false; +} + +bool compare(int a, int b, IrCondition cond) { switch (cond) { @@ -410,6 +444,14 @@ bool compare(double a, double b, IrCondition cond) return a >= b; case IrCondition::NotGreaterEqual: return !(a >= b); + case IrCondition::UnsignedLess: + return unsigned(a) < unsigned(b); + case IrCondition::UnsignedLessEqual: + return unsigned(a) <= unsigned(b); + case IrCondition::UnsignedGreater: + return unsigned(a) > unsigned(b); + case IrCondition::UnsignedGreaterEqual: + return unsigned(a) >= unsigned(b); default: LUAU_ASSERT(!"Unsupported condition"); } @@ -463,6 +505,10 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(function.doubleOp(inst.a) / function.doubleOp(inst.b))); break; + case IrCmd::IDIV_NUM: + if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) + substitute(function, inst, build.constDouble(luai_numidiv(function.doubleOp(inst.a), function.doubleOp(inst.b)))); + break; case IrCmd::MOD_NUM: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) substitute(function, inst, build.constDouble(luai_nummod(function.doubleOp(inst.a), function.doubleOp(inst.b)))); @@ -531,31 +577,13 @@ void foldConstants(IrBuilder& build, IrFunction& function, IrBlock& block, uint3 replace(function, block, index, {IrCmd::JUMP, inst.d}); } break; - case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_CMP_INT: if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) { - if (function.intOp(inst.a) == function.intOp(inst.b)) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else + if (compare(function.intOp(inst.a), function.intOp(inst.b), conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - case IrCmd::JUMP_LT_INT: - if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) - { - if (function.intOp(inst.a) < function.intOp(inst.b)) - replace(function, block, index, {IrCmd::JUMP, inst.c}); else - replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - case IrCmd::JUMP_GE_UINT: - if (inst.a.kind == IrOpKind::Constant && inst.b.kind == IrOpKind::Constant) - { - if (unsigned(function.intOp(inst.a)) >= unsigned(function.intOp(inst.b))) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else - replace(function, block, index, {IrCmd::JUMP, inst.d}); + replace(function, block, index, {IrCmd::JUMP, inst.e}); } break; case IrCmd::JUMP_CMP_NUM: @@ -832,5 +860,43 @@ void killUnusedBlocks(IrFunction& function) } } +std::vector getSortedBlockOrder(IrFunction& function) +{ + std::vector sortedBlocks; + sortedBlocks.reserve(function.blocks.size()); + for (uint32_t i = 0; i < function.blocks.size(); i++) + sortedBlocks.push_back(i); + + std::sort(sortedBlocks.begin(), sortedBlocks.end(), [&](uint32_t idxA, uint32_t idxB) { + const IrBlock& a = function.blocks[idxA]; + const IrBlock& b = function.blocks[idxB]; + + // Place fallback blocks at the end + if ((a.kind == IrBlockKind::Fallback) != (b.kind == IrBlockKind::Fallback)) + return (a.kind == IrBlockKind::Fallback) < (b.kind == IrBlockKind::Fallback); + + // Try to order by instruction order + if (a.sortkey != b.sortkey) + return a.sortkey < b.sortkey; + + // Chains of blocks are merged together by having the same sort key and consecutive chain key + return a.chainkey < b.chainkey; + }); + + return sortedBlocks; +} + +IrBlock& getNextBlock(IrFunction& function, std::vector& sortedBlocks, IrBlock& dummy, size_t i) +{ + for (size_t j = i + 1; j < sortedBlocks.size(); ++j) + { + IrBlock& block = function.blocks[sortedBlocks[j]]; + if (block.kind != IrBlockKind::Dead) + return block; + } + + return dummy; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrValueLocationTracking.cpp b/CodeGen/src/IrValueLocationTracking.cpp index 4536630b..04c210d9 100644 --- a/CodeGen/src/IrValueLocationTracking.cpp +++ b/CodeGen/src/IrValueLocationTracking.cpp @@ -28,6 +28,7 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::STORE_INT: case IrCmd::STORE_VECTOR: case IrCmd::STORE_TVALUE: + case IrCmd::STORE_SPLIT_TVALUE: invalidateRestoreOp(inst.a); break; case IrCmd::ADJUST_STACK_TO_REG: @@ -53,11 +54,6 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::GET_UPVALUE: invalidateRestoreOp(inst.a); break; - case IrCmd::PREPARE_FORN: - invalidateRestoreOp(inst.a); - invalidateRestoreOp(inst.b); - invalidateRestoreOp(inst.c); - break; case IrCmd::CALL: // Even if result count is limited, all registers starting from function (ra) might be modified invalidateRestoreVmRegs(vmRegOp(inst.a), -1); @@ -112,13 +108,14 @@ void IrValueLocationTracking::beforeInstLowering(IrInst& inst) case IrCmd::FINDUPVAL: break; - // These instrucitons read VmReg only after optimizeMemoryOperandsX64 + // These instructions read VmReg only after optimizeMemoryOperandsX64 case IrCmd::CHECK_TAG: case IrCmd::CHECK_TRUTHY: case IrCmd::ADD_NUM: case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: + case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: diff --git a/CodeGen/src/NativeState.cpp b/CodeGen/src/NativeState.cpp index e7a0c424..7b2f068b 100644 --- a/CodeGen/src/NativeState.cpp +++ b/CodeGen/src/NativeState.cpp @@ -14,16 +14,21 @@ #include #include +LUAU_FASTINTVARIABLE(LuauCodeGenBlockSize, 4 * 1024 * 1024) +LUAU_FASTINTVARIABLE(LuauCodeGenMaxTotalSize, 256 * 1024 * 1024) + namespace Luau { namespace CodeGen { -constexpr unsigned kBlockSize = 4 * 1024 * 1024; -constexpr unsigned kMaxTotalSize = 256 * 1024 * 1024; - NativeState::NativeState() - : codeAllocator(kBlockSize, kMaxTotalSize) + : NativeState(nullptr, nullptr) +{ +} + +NativeState::NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext) + : codeAllocator{size_t(FInt::LuauCodeGenBlockSize), size_t(FInt::LuauCodeGenMaxTotalSize), allocationCallback, allocationCallbackContext} { } @@ -39,7 +44,6 @@ void initFunctions(NativeState& data) data.context.luaV_equalval = luaV_equalval; data.context.luaV_doarith = luaV_doarith; data.context.luaV_dolen = luaV_dolen; - data.context.luaV_prepareFORN = luaV_prepareFORN; data.context.luaV_gettable = luaV_gettable; data.context.luaV_settable = luaV_settable; data.context.luaV_getimport = luaV_getimport; @@ -49,6 +53,7 @@ void initFunctions(NativeState& data) data.context.luaH_new = luaH_new; data.context.luaH_clone = luaH_clone; data.context.luaH_resizearray = luaH_resizearray; + data.context.luaH_setnum = luaH_setnum; data.context.luaC_barriertable = luaC_barriertable; data.context.luaC_barrierf = luaC_barrierf; diff --git a/CodeGen/src/NativeState.h b/CodeGen/src/NativeState.h index 4aa5c8a2..f0b8561c 100644 --- a/CodeGen/src/NativeState.h +++ b/CodeGen/src/NativeState.h @@ -35,7 +35,6 @@ struct NativeContext int (*luaV_equalval)(lua_State* L, const TValue* t1, const TValue* t2) = nullptr; void (*luaV_doarith)(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) = nullptr; void (*luaV_dolen)(lua_State* L, StkId ra, const TValue* rb) = nullptr; - void (*luaV_prepareFORN)(lua_State* L, StkId plimit, StkId pstep, StkId pinit) = nullptr; void (*luaV_gettable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_settable)(lua_State* L, const TValue* t, TValue* key, StkId val) = nullptr; void (*luaV_getimport)(lua_State* L, Table* env, TValue* k, StkId res, uint32_t id, bool propagatenil) = nullptr; @@ -45,6 +44,7 @@ struct NativeContext Table* (*luaH_new)(lua_State* L, int narray, int lnhash) = nullptr; Table* (*luaH_clone)(lua_State* L, Table* tt) = nullptr; void (*luaH_resizearray)(lua_State* L, Table* t, int nasize) = nullptr; + TValue* (*luaH_setnum)(lua_State* L, Table* t, int key); void (*luaC_barriertable)(lua_State* L, Table* t, GCObject* v) = nullptr; void (*luaC_barrierf)(lua_State* L, GCObject* o, GCObject* v) = nullptr; @@ -112,6 +112,7 @@ using GateFn = int (*)(lua_State*, Proto*, uintptr_t, NativeContext*); struct NativeState { NativeState(); + NativeState(AllocationCallback* allocationCallback, void* allocationCallbackContext); ~NativeState(); CodeAllocator codeAllocator; diff --git a/CodeGen/src/OptimizeConstProp.cpp b/CodeGen/src/OptimizeConstProp.cpp index a5e20b16..de839061 100644 --- a/CodeGen/src/OptimizeConstProp.cpp +++ b/CodeGen/src/OptimizeConstProp.cpp @@ -2,17 +2,22 @@ #include "Luau/OptimizeConstProp.h" #include "Luau/DenseHash.h" -#include "Luau/IrAnalysis.h" +#include "Luau/IrData.h" #include "Luau/IrBuilder.h" #include "Luau/IrUtils.h" #include "lua.h" #include +#include #include LUAU_FASTINTVARIABLE(LuauCodeGenMinLinearBlockPath, 3) +LUAU_FASTINTVARIABLE(LuauCodeGenReuseSlotLimit, 64) LUAU_FASTFLAGVARIABLE(DebugLuauAbortingChecks, false) +LUAU_FASTFLAGVARIABLE(LuauReuseHashSlots2, false) +LUAU_FASTFLAGVARIABLE(LuauKeepVmapLinear, false) +LUAU_FASTFLAGVARIABLE(LuauMergeTagLoads, false) namespace Luau { @@ -31,6 +36,7 @@ struct RegisterInfo bool knownNotReadonly = false; bool knownNoMetatable = false; + int knownTableArraySize = -1; }; // Load instructions are linked to target register to carry knowledge about the target @@ -95,6 +101,7 @@ struct ConstPropState info->value = value; info->knownNotReadonly = false; info->knownNoMetatable = false; + info->knownTableArraySize = -1; info->version++; } } @@ -112,6 +119,7 @@ struct ConstPropState reg.value = {}; reg.knownNotReadonly = false; reg.knownNoMetatable = false; + reg.knownTableArraySize = -1; } reg.version++; @@ -170,12 +178,17 @@ struct ConstPropState { for (int i = 0; i <= maxReg; ++i) invalidateHeap(regs[i]); + + // If table memory has changed, we can't reuse previously computed and validated table slot lookups + getSlotNodeCache.clear(); + checkSlotMatchCache.clear(); } void invalidateHeap(RegisterInfo& reg) { reg.knownNotReadonly = false; reg.knownNoMetatable = false; + reg.knownTableArraySize = -1; } void invalidateUserCall() @@ -185,6 +198,21 @@ struct ConstPropState inSafeEnv = false; } + void invalidateTableArraySize() + { + for (int i = 0; i <= maxReg; ++i) + invalidateTableArraySize(regs[i]); + + // If table memory has changed, we can't reuse previously computed and validated table slot lookups + getSlotNodeCache.clear(); + checkSlotMatchCache.clear(); + } + + void invalidateTableArraySize(RegisterInfo& reg) + { + reg.knownTableArraySize = -1; + } + void createRegLink(uint32_t instIdx, IrOp regOp) { LUAU_ASSERT(!instLink.contains(instIdx)); @@ -236,28 +264,66 @@ struct ConstPropState return IrInst{loadCmd, op}; } + uint32_t* getPreviousInstIndex(const IrInst& inst) + { + LUAU_ASSERT(useValueNumbering); + + if (uint32_t* prevIdx = valueMap.find(inst)) + { + // Previous load might have been removed as unused + if (function.instructions[*prevIdx].useCount != 0) + return prevIdx; + } + + return nullptr; + } + + uint32_t* getPreviousVersionedLoadIndex(IrCmd cmd, IrOp vmReg) + { + LUAU_ASSERT(vmReg.kind == IrOpKind::VmReg); + return getPreviousInstIndex(versionedVmRegLoad(cmd, vmReg)); + } + + std::pair getPreviousVersionedLoadForTag(uint8_t tag, IrOp vmReg) + { + if (useValueNumbering && !function.cfg.captured.regs.test(vmRegOp(vmReg))) + { + if (tag == LUA_TBOOLEAN) + { + if (uint32_t* prevIdx = getPreviousVersionedLoadIndex(IrCmd::LOAD_INT, vmReg)) + return std::make_pair(IrCmd::LOAD_INT, *prevIdx); + } + else if (tag == LUA_TNUMBER) + { + if (uint32_t* prevIdx = getPreviousVersionedLoadIndex(IrCmd::LOAD_DOUBLE, vmReg)) + return std::make_pair(IrCmd::LOAD_DOUBLE, *prevIdx); + } + else if (isGCO(tag)) + { + if (uint32_t* prevIdx = getPreviousVersionedLoadIndex(IrCmd::LOAD_POINTER, vmReg)) + return std::make_pair(IrCmd::LOAD_POINTER, *prevIdx); + } + } + + return std::make_pair(IrCmd::NOP, kInvalidInstIdx); + } + // Find existing value of the instruction that is exactly the same, or record current on for future lookups void substituteOrRecord(IrInst& inst, uint32_t instIdx) { if (!useValueNumbering) return; - if (uint32_t* prevIdx = valueMap.find(inst)) + if (uint32_t* prevIdx = getPreviousInstIndex(inst)) { - const IrInst& prev = function.instructions[*prevIdx]; - - // Previous load might have been removed as unused - if (prev.useCount != 0) - { - substitute(function, inst, IrOp{IrOpKind::Inst, *prevIdx}); - return; - } + substitute(function, inst, IrOp{IrOpKind::Inst, *prevIdx}); + return; } valueMap[inst] = instIdx; } - // Vm register load can be replaced by a previous load of the same version of the register + // VM register load can be replaced by a previous load of the same version of the register // If there is no previous load, we record the current one for future lookups void substituteOrRecordVmRegLoad(IrInst& loadInst) { @@ -274,22 +340,16 @@ struct ConstPropState IrInst versionedLoad = versionedVmRegLoad(loadInst.cmd, loadInst.a); // Check if there is a value that already has this version of the register - if (uint32_t* prevIdx = valueMap.find(versionedLoad)) + if (uint32_t* prevIdx = getPreviousInstIndex(versionedLoad)) { - const IrInst& prev = function.instructions[*prevIdx]; + // Previous value might not be linked to a register yet + // For example, it could be a NEW_TABLE stored into a register and we might need to track guards made with this value + if (!instLink.contains(*prevIdx)) + createRegLink(*prevIdx, loadInst.a); - // Previous load might have been removed as unused - if (prev.useCount != 0) - { - // Previous value might not be linked to a register yet - // For example, it could be a NEW_TABLE stored into a register and we might need to track guards made with this value - if (!instLink.contains(*prevIdx)) - createRegLink(*prevIdx, loadInst.a); - - // Substitute load instructon with the previous value - substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); - return; - } + // Substitute load instructon with the previous value + substitute(function, loadInst, IrOp{IrOpKind::Inst, *prevIdx}); + return; } uint32_t instIdx = function.getInstIndex(loadInst); @@ -330,6 +390,8 @@ struct ConstPropState instLink.clear(); valueMap.clear(); + getSlotNodeCache.clear(); + checkSlotMatchCache.clear(); } IrFunction& function; @@ -347,6 +409,9 @@ struct ConstPropState DenseHashMap instLink{~0u}; DenseHashMap valueMap; + + std::vector getSlotNodeCache; + std::vector checkSlotMatchCache; }; static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid, uint32_t firstReturnReg, int nresults) @@ -403,10 +468,8 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_MATH_CLAMP: case LBF_MATH_SIGN: case LBF_MATH_ROUND: - case LBF_RAWSET: case LBF_RAWGET: case LBF_RAWEQUAL: - case LBF_TABLE_INSERT: case LBF_TABLE_UNPACK: case LBF_VECTOR: case LBF_BIT32_COUNTLZ: @@ -418,6 +481,12 @@ static void handleBuiltinEffects(ConstPropState& state, LuauBuiltinFunction bfid case LBF_TONUMBER: case LBF_TOSTRING: break; + case LBF_TABLE_INSERT: + state.invalidateHeap(); + return; // table.insert does not modify result registers. + case LBF_RAWSET: + state.invalidateHeap(); + break; case LBF_SETMETATABLE: state.invalidateHeap(); // TODO: only knownNoMetatable is affected and we might know which one break; @@ -434,9 +503,16 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { case IrCmd::LOAD_TAG: if (uint8_t tag = state.tryGetTag(inst.a); tag != 0xff) + { substitute(function, inst, build.constTag(tag)); + } else if (inst.a.kind == IrOpKind::VmReg) - state.createRegLink(index, inst.a); + { + if (FFlag::LuauMergeTagLoads) + state.substituteOrRecordVmRegLoad(inst); + else + state.createRegLink(index, inst.a); + } break; case IrCmd::LOAD_POINTER: if (inst.a.kind == IrOpKind::VmReg) @@ -470,21 +546,18 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& if (inst.a.kind == IrOpKind::VmReg) { const IrOp source = inst.a; - uint32_t activeLoadDoubleValue = kInvalidInstIdx; + + IrCmd activeLoadCmd = IrCmd::NOP; + uint32_t activeLoadValue = kInvalidInstIdx; if (inst.b.kind == IrOpKind::Constant) { uint8_t value = function.tagOp(inst.b); // STORE_TAG usually follows a store of the value, but it also bumps the version of the whole register - // To be able to propagate STORE_DOUBLE into LOAD_DOUBLE, we find active LOAD_DOUBLE value and recreate it with updated version + // To be able to propagate STORE_*** into LOAD_***, we find active LOAD_*** value and recreate it with updated version // Register in this optimization cannot be captured to avoid complications in lowering (IrValueLocationTracking doesn't model it) - // If stored tag is not a number, we can skip the lookup as there won't be future loads of this register as a number - if (value == LUA_TNUMBER && !function.cfg.captured.regs.test(vmRegOp(source))) - { - if (uint32_t* prevIdx = state.valueMap.find(state.versionedVmRegLoad(IrCmd::LOAD_DOUBLE, source))) - activeLoadDoubleValue = *prevIdx; - } + std::tie(activeLoadCmd, activeLoadValue) = state.getPreviousVersionedLoadForTag(value, source); if (state.tryGetTag(source) == value) { @@ -503,9 +576,9 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateTag(source); } - // Future LOAD_DOUBLE instructions can re-use previous register version load - if (activeLoadDoubleValue != kInvalidInstIdx) - state.valueMap[state.versionedVmRegLoad(IrCmd::LOAD_DOUBLE, source)] = activeLoadDoubleValue; + // Future LOAD_*** instructions can re-use previous register version load + if (activeLoadValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(activeLoadCmd, source)] = activeLoadValue; } break; case IrCmd::STORE_POINTER: @@ -520,6 +593,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& { info->knownNotReadonly = true; info->knownNoMetatable = true; + info->knownTableArraySize = function.uintOp(instOp->a); } } } @@ -562,17 +636,62 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateValue(inst.a); break; case IrCmd::STORE_TVALUE: + if (inst.a.kind == IrOpKind::VmReg || inst.a.kind == IrOpKind::Inst) + { + if (inst.a.kind == IrOpKind::VmReg) + state.invalidate(inst.a); + + uint8_t tag = state.tryGetTag(inst.b); + IrOp value = state.tryGetValue(inst.b); + + if (inst.a.kind == IrOpKind::VmReg) + { + if (tag != 0xff) + state.saveTag(inst.a, tag); + + if (value.kind != IrOpKind::None) + state.saveValue(inst.a, value); + } + + IrCmd activeLoadCmd = IrCmd::NOP; + uint32_t activeLoadValue = kInvalidInstIdx; + + if (tag != 0xff) + { + // If we know the tag, try to extract the value from a register used by LOAD_TVALUE + if (IrInst* arg = function.asInstOp(inst.b); arg && arg->cmd == IrCmd::LOAD_TVALUE && arg->a.kind == IrOpKind::VmReg) + { + std::tie(activeLoadCmd, activeLoadValue) = state.getPreviousVersionedLoadForTag(tag, arg->a); + + if (activeLoadValue != kInvalidInstIdx) + value = IrOp{IrOpKind::Inst, activeLoadValue}; + } + } + + // If we have constant tag and value, replace TValue store with tag/value pair store + if (tag != 0xff && value.kind != IrOpKind::None && (tag == LUA_TBOOLEAN || tag == LUA_TNUMBER || isGCO(tag))) + { + replace(function, block, index, {IrCmd::STORE_SPLIT_TVALUE, inst.a, build.constTag(tag), value, inst.c}); + + // Value can be propagated to future loads of the same register + if (inst.a.kind == IrOpKind::VmReg && activeLoadValue != kInvalidInstIdx) + state.valueMap[state.versionedVmRegLoad(activeLoadCmd, inst.a)] = activeLoadValue; + } + else if (inst.a.kind == IrOpKind::VmReg) + { + state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); + } + } + break; + case IrCmd::STORE_SPLIT_TVALUE: if (inst.a.kind == IrOpKind::VmReg) { state.invalidate(inst.a); - if (uint8_t tag = state.tryGetTag(inst.b); tag != 0xff) - state.saveTag(inst.a, tag); + state.saveTag(inst.a, function.tagOp(inst.b)); - if (IrOp value = state.tryGetValue(inst.b); value.kind != IrOpKind::None) - state.saveValue(inst.a, value); - - state.forwardVmRegStoreToLoad(inst, IrCmd::LOAD_TVALUE); + if (inst.c.kind == IrOpKind::Constant) + state.saveValue(inst.a, inst.c); } break; case IrCmd::JUMP_IF_TRUTHY: @@ -605,44 +724,20 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& else replace(function, block, index, {IrCmd::JUMP, inst.d}); } + else if (FFlag::LuauMergeTagLoads && inst.a == inst.b) + { + replace(function, block, index, {IrCmd::JUMP, inst.c}); + } break; } - case IrCmd::JUMP_EQ_INT: + case IrCmd::JUMP_CMP_INT: { std::optional valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); std::optional valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); if (valueA && valueB) { - if (*valueA == *valueB) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else - replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - } - case IrCmd::JUMP_LT_INT: - { - std::optional valueA = function.asIntOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); - std::optional valueB = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); - - if (valueA && valueB) - { - if (*valueA < *valueB) - replace(function, block, index, {IrCmd::JUMP, inst.c}); - else - replace(function, block, index, {IrCmd::JUMP, inst.d}); - } - break; - } - case IrCmd::JUMP_GE_UINT: - { - std::optional valueA = function.asUintOp(inst.a.kind == IrOpKind::Constant ? inst.a : state.tryGetValue(inst.a)); - std::optional valueB = function.asUintOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); - - if (valueA && valueB) - { - if (*valueA >= *valueB) + if (compare(*valueA, *valueB, conditionOp(inst.c))) replace(function, block, index, {IrCmd::JUMP, inst.c}); else replace(function, block, index, {IrCmd::JUMP, inst.d}); @@ -666,6 +761,15 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::GET_UPVALUE: state.invalidate(inst.a); break; + case IrCmd::SET_UPVALUE: + if (inst.b.kind == IrOpKind::VmReg) + { + if (uint8_t tag = state.tryGetTag(inst.b); tag != 0xff) + { + replace(function, inst.c, build.constTag(tag)); + } + } + break; case IrCmd::CHECK_TAG: { uint8_t b = function.tagOp(inst.b); @@ -768,11 +872,27 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& // These instructions don't have an effect on register/memory state we are tracking case IrCmd::NOP: - case IrCmd::LOAD_NODE_VALUE_TV: - case IrCmd::STORE_NODE_VALUE_TV: case IrCmd::LOAD_ENV: case IrCmd::GET_ARR_ADDR: + break; case IrCmd::GET_SLOT_NODE_ADDR: + if (!FFlag::LuauReuseHashSlots2) + break; + + for (uint32_t prevIdx : state.getSlotNodeCache) + { + const IrInst& prev = function.instructions[prevIdx]; + + if (prev.a == inst.a && prev.c == inst.c) + { + substitute(function, inst, IrOp{IrOpKind::Inst, prevIdx}); + return; // Break out from both the loop and the switch + } + } + + if (int(state.getSlotNodeCache.size()) < FInt::LuauCodeGenReuseSlotLimit) + state.getSlotNodeCache.push_back(index); + break; case IrCmd::GET_HASH_NODE_ADDR: case IrCmd::GET_CLOSURE_UPVAL_ADDR: break; @@ -782,6 +902,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: + case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: @@ -801,6 +922,10 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& case IrCmd::JUMP_EQ_POINTER: case IrCmd::JUMP_SLOT_MATCH: case IrCmd::TABLE_LEN: + break; + case IrCmd::TABLE_SETNUM: + state.invalidateTableArraySize(); + break; case IrCmd::STRING_LEN: case IrCmd::NEW_TABLE: case IrCmd::DUP_TABLE: @@ -824,12 +949,52 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.substituteOrRecord(inst, index); break; case IrCmd::CHECK_ARRAY_SIZE: + { + std::optional arrayIndex = function.asIntOp(inst.b.kind == IrOpKind::Constant ? inst.b : state.tryGetValue(inst.b)); + + if (RegisterInfo* info = state.tryGetRegisterInfo(inst.a); info && arrayIndex) + { + if (info->knownTableArraySize >= 0) + { + if (unsigned(*arrayIndex) < unsigned(info->knownTableArraySize)) + { + if (FFlag::DebugLuauAbortingChecks) + replace(function, inst.c, build.undef()); + else + kill(function, inst); + } + else + { + replace(function, block, index, {IrCmd::JUMP, inst.c}); + } + } + } + break; + } case IrCmd::CHECK_SLOT_MATCH: + if (!FFlag::LuauReuseHashSlots2) + break; + + for (uint32_t prevIdx : state.checkSlotMatchCache) + { + const IrInst& prev = function.instructions[prevIdx]; + + if (prev.a == inst.a && prev.b == inst.b) + { + // Only a check for 'nil' value is left + replace(function, block, index, {IrCmd::CHECK_NODE_VALUE, inst.a, inst.c}); + return; // Break out from both the loop and the switch + } + } + + if (int(state.checkSlotMatchCache.size()) < FInt::LuauCodeGenReuseSlotLimit) + state.checkSlotMatchCache.push_back(index); + break; case IrCmd::CHECK_NODE_NO_NEXT: + case IrCmd::CHECK_NODE_VALUE: case IrCmd::BARRIER_TABLE_BACK: case IrCmd::RETURN: case IrCmd::COVERAGE: - case IrCmd::SET_UPVALUE: case IrCmd::SET_SAVEDPC: // TODO: we may be able to remove some updates to PC case IrCmd::CLOSE_UPVALS: // Doesn't change memory that we track case IrCmd::CAPTURE: @@ -880,32 +1045,34 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction& state.invalidateRegisterRange(vmRegOp(inst.a), function.uintOp(inst.b)); state.invalidateUserCall(); // TODO: if only strings and numbers are concatenated, there will be no user calls break; - case IrCmd::PREPARE_FORN: - state.invalidateValue(inst.a); - state.saveTag(inst.a, LUA_TNUMBER); - state.invalidateValue(inst.b); - state.saveTag(inst.b, LUA_TNUMBER); - state.invalidateValue(inst.c); - state.saveTag(inst.c, LUA_TNUMBER); - break; case IrCmd::INTERRUPT: state.invalidateUserCall(); break; case IrCmd::SETLIST: - state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstSetList becomes aware of register allocator + if (RegisterInfo* info = state.tryGetRegisterInfo(inst.b); info && info->knownTableArraySize >= 0) + replace(function, inst.f, build.constUint(info->knownTableArraySize)); + + // TODO: this can be relaxed when x64 emitInstSetList becomes aware of register allocator + state.valueMap.clear(); + state.getSlotNodeCache.clear(); + state.checkSlotMatchCache.clear(); break; case IrCmd::CALL: state.invalidateRegistersFrom(vmRegOp(inst.a)); state.invalidateUserCall(); - // We cannot guarantee right now that all live values can be remeterialized from non-stack memory locations + // We cannot guarantee right now that all live values can be rematerialized from non-stack memory locations // To prevent earlier values from being propagated to after the call, we have to clear the map // TODO: remove only the values that don't have a guaranteed restore location state.valueMap.clear(); break; case IrCmd::FORGLOOP: state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified - state.valueMap.clear(); // TODO: this can be relaxed when x64 emitInstForGLoop becomes aware of register allocator + + // TODO: this can be relaxed when x64 emitInstForGLoop becomes aware of register allocator + state.valueMap.clear(); + state.getSlotNodeCache.clear(); + state.checkSlotMatchCache.clear(); break; case IrCmd::FORGLOOP_FALLBACK: state.invalidateRegistersFrom(vmRegOp(inst.a) + 2); // Rn and Rn+1 are not modified @@ -969,8 +1136,15 @@ static void constPropInBlock(IrBuilder& build, IrBlock& block, ConstPropState& s constPropInInst(state, build, function, block, inst, index); } - // Value numbering and load/store propagation is not performed between blocks - state.valueMap.clear(); + if (!FFlag::LuauKeepVmapLinear) + { + // Value numbering and load/store propagation is not performed between blocks + state.valueMap.clear(); + + // Same for table slot data propagation + state.getSlotNodeCache.clear(); + state.checkSlotMatchCache.clear(); + } } static void constPropInBlockChain(IrBuilder& build, std::vector& visited, IrBlock* block, ConstPropState& state) @@ -979,6 +1153,9 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite state.clear(); + const uint32_t startSortkey = block->sortkey; + uint32_t chainPos = 0; + while (block) { uint32_t blockIdx = function.getBlockIndex(*block); @@ -987,19 +1164,40 @@ static void constPropInBlockChain(IrBuilder& build, std::vector& visite constPropInBlock(build, *block, state); + if (FFlag::LuauKeepVmapLinear) + { + // Value numbering and load/store propagation is not performed between blocks right now + // This is because cross-block value uses limit creation of linear block (restriction in collectDirectBlockJumpPath) + state.valueMap.clear(); + + // Same for table slot data propagation + state.getSlotNodeCache.clear(); + state.checkSlotMatchCache.clear(); + } + + // Blocks in a chain are guaranteed to follow each other + // We force that by giving all blocks the same sorting key, but consecutive chain keys + block->sortkey = startSortkey; + block->chainkey = chainPos++; + IrInst& termInst = function.instructions[block->finish]; IrBlock* nextBlock = nullptr; // Unconditional jump into a block with a single user (current block) allows us to continue optimization // with the information we have gathered so far (unless we have already visited that block earlier) - if (termInst.cmd == IrCmd::JUMP && termInst.a.kind != IrOpKind::VmExit) + if (termInst.cmd == IrCmd::JUMP && termInst.a.kind == IrOpKind::Block) { IrBlock& target = function.blockOp(termInst.a); uint32_t targetIdx = function.getBlockIndex(target); if (target.useCount == 1 && !visited[targetIdx] && target.kind != IrBlockKind::Fallback) + { + // Make sure block ordering guarantee is checked at lowering time + block->expectedNextBlock = function.getBlockIndex(target); + nextBlock = ⌖ + } } block = nextBlock; @@ -1027,7 +1225,7 @@ static std::vector collectDirectBlockJumpPath(IrFunction& function, st IrBlock* nextBlock = nullptr; // A chain is made from internal blocks that were not a part of bytecode CFG - if (termInst.cmd == IrCmd::JUMP && termInst.a.kind != IrOpKind::VmExit) + if (termInst.cmd == IrCmd::JUMP && termInst.a.kind == IrOpKind::Block) { IrBlock& target = function.blockOp(termInst.a); uint32_t targetIdx = function.getBlockIndex(target); @@ -1068,8 +1266,8 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited if (termInst.cmd != IrCmd::JUMP) return; - // And it can't be jump to a VM exit - if (termInst.a.kind == IrOpKind::VmExit) + // And it can't be jump to a VM exit or undef + if (termInst.a.kind != IrOpKind::Block) return; // And it has to jump to a block with more than one user @@ -1089,14 +1287,14 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited // Initialize state with the knowledge of our current block state.clear(); - // TODO: using values from the first block can cause 'live out' of the linear block predecessor to not have all required registers constPropInBlock(build, startingBlock, state); // Verify that target hasn't changed LUAU_ASSERT(function.instructions[startingBlock.finish].a.index == targetBlockIdx); // Note: using startingBlock after this line is unsafe as the reference may be reallocated by build.block() below - uint32_t startingInsn = startingBlock.start; + const uint32_t startingSortKey = startingBlock.sortkey; + const uint32_t startingChainKey = startingBlock.chainkey; // Create new linearized block into which we are going to redirect starting block jump IrOp newBlock = build.block(IrBlockKind::Linearized); @@ -1106,7 +1304,11 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited // By default, blocks are ordered according to start instruction; we alter sort order to make sure linearized block is placed right after the // starting block - function.blocks[newBlock.index].sortkey = startingInsn + 1; + function.blocks[newBlock.index].sortkey = startingSortKey; + function.blocks[newBlock.index].chainkey = startingChainKey + 1; + + // Make sure block ordering guarantee is checked at lowering time + function.blocks[blockIdx].expectedNextBlock = newBlock.index; replace(function, termInst.a, newBlock); @@ -1145,6 +1347,12 @@ static void tryCreateLinearBlock(IrBuilder& build, std::vector& visited def.varargStart = pathDef.varargStart; } } + + // Update predecessors + function.cfg.predecessorsOffsets.push_back(uint32_t(function.cfg.predecessors.size())); + function.cfg.predecessors.push_back(blockIdx); + + // Updating successors will require visiting the instructions again and we don't have a current use for linearized block successor list } // Optimize our linear block diff --git a/CodeGen/src/OptimizeFinalX64.cpp b/CodeGen/src/OptimizeFinalX64.cpp index 63642c46..3f0469fa 100644 --- a/CodeGen/src/OptimizeFinalX64.cpp +++ b/CodeGen/src/OptimizeFinalX64.cpp @@ -58,6 +58,7 @@ static void optimizeMemoryOperandsX64(IrFunction& function, IrBlock& block) case IrCmd::SUB_NUM: case IrCmd::MUL_NUM: case IrCmd::DIV_NUM: + case IrCmd::IDIV_NUM: case IrCmd::MOD_NUM: case IrCmd::MIN_NUM: case IrCmd::MAX_NUM: diff --git a/CodeGen/src/UnwindBuilderDwarf2.cpp b/CodeGen/src/UnwindBuilderDwarf2.cpp index e9df184d..08c8e831 100644 --- a/CodeGen/src/UnwindBuilderDwarf2.cpp +++ b/CodeGen/src/UnwindBuilderDwarf2.cpp @@ -225,9 +225,10 @@ void UnwindBuilderDwarf2::prologueA64(uint32_t prologueSize, uint32_t stackSize, } } -void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) { - LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + LUAU_ASSERT(stackSize > 0 && stackSize < 4096 && stackSize % 8 == 0); unsigned int stackOffset = 8; // Return address was pushed by calling the function unsigned int prologueOffset = 0; @@ -247,7 +248,7 @@ void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, } // push reg - for (X64::RegisterX64 reg : regs) + for (X64::RegisterX64 reg : gpr) { LUAU_ASSERT(reg.size == X64::SizeX64::qword); @@ -258,9 +259,11 @@ void UnwindBuilderDwarf2::prologueX64(uint32_t prologueSize, uint32_t stackSize, pos = defineSavedRegisterLocation(pos, regIndexToDwRegX64[reg.index], stackOffset); } + LUAU_ASSERT(simd.empty()); + // sub rsp, stackSize stackOffset += stackSize; - prologueOffset += 4; + prologueOffset += stackSize >= 128 ? 7 : 4; pos = advanceLocation(pos, 4); pos = defineCfaExpressionOffset(pos, stackOffset); diff --git a/CodeGen/src/UnwindBuilderWin.cpp b/CodeGen/src/UnwindBuilderWin.cpp index f9b927c5..336a4e3f 100644 --- a/CodeGen/src/UnwindBuilderWin.cpp +++ b/CodeGen/src/UnwindBuilderWin.cpp @@ -82,7 +82,7 @@ void UnwindBuilderWin::finishFunction(uint32_t beginOffset, uint32_t endOffset) if (!unwindCodes.empty()) { // Copy unwind codes in reverse order - // Some unwind codes take up two array slots, but we don't use those atm + // Some unwind codes take up two array slots, we write those in reverse order uint8_t* unwindCodePos = rawDataPos + sizeof(UnwindCodeWin) * (unwindCodes.size() - 1); LUAU_ASSERT(unwindCodePos <= rawData + kRawDataLimit); @@ -109,9 +109,10 @@ void UnwindBuilderWin::prologueA64(uint32_t prologueSize, uint32_t stackSize, st LUAU_ASSERT(!"Not implemented"); } -void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list regs) +void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bool setupFrame, std::initializer_list gpr, + const std::vector& simd) { - LUAU_ASSERT(stackSize > 0 && stackSize <= 128 && stackSize % 8 == 0); + LUAU_ASSERT(stackSize > 0 && stackSize < 4096 && stackSize % 8 == 0); LUAU_ASSERT(prologueSize < 256); unsigned int stackOffset = 8; // Return address was pushed by calling the function @@ -132,7 +133,7 @@ void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bo } // push reg - for (X64::RegisterX64 reg : regs) + for (X64::RegisterX64 reg : gpr) { LUAU_ASSERT(reg.size == X64::SizeX64::qword); @@ -141,10 +142,51 @@ void UnwindBuilderWin::prologueX64(uint32_t prologueSize, uint32_t stackSize, bo unwindCodes.push_back({uint8_t(prologueOffset), UWOP_PUSH_NONVOL, reg.index}); } + // If frame pointer is used, simd register storage is not implemented, it will require reworking store offsets + LUAU_ASSERT(!setupFrame || simd.size() == 0); + + unsigned int simdStorageSize = unsigned(simd.size()) * 16; + + // It's the responsibility of the caller to provide simd register storage in 'stackSize', including alignment to 16 bytes + if (!simd.empty() && stackOffset % 16 == 8) + simdStorageSize += 8; + // sub rsp, stackSize - stackOffset += stackSize; - prologueOffset += 4; - unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)}); + if (stackSize <= 128) + { + stackOffset += stackSize; + prologueOffset += stackSize == 128 ? 7 : 4; + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_SMALL, uint8_t((stackSize - 8) / 8)}); + } + else + { + // This command can handle allocations up to 512K-8 bytes, but that potentially requires stack probing + LUAU_ASSERT(stackSize < 4096); + + stackOffset += stackSize; + prologueOffset += 7; + + uint16_t encodedOffset = stackSize / 8; + unwindCodes.push_back(UnwindCodeWin()); + memcpy(&unwindCodes.back(), &encodedOffset, sizeof(encodedOffset)); + + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_ALLOC_LARGE, 0}); + } + + // It's the responsibility of the caller to provide simd register storage in 'stackSize' + unsigned int xmmStoreOffset = stackSize - simdStorageSize; + + // vmovaps [rsp+n], xmm + for (X64::RegisterX64 reg : simd) + { + LUAU_ASSERT(reg.size == X64::SizeX64::xmmword); + LUAU_ASSERT(xmmStoreOffset % 16 == 0 && "simd stores have to be performed to aligned locations"); + + prologueOffset += xmmStoreOffset >= 128 ? 10 : 7; + unwindCodes.push_back({uint8_t(xmmStoreOffset / 16), 0, 0}); + unwindCodes.push_back({uint8_t(prologueOffset), UWOP_SAVE_XMM128, reg.index}); + xmmStoreOffset += 16; + } LUAU_ASSERT(stackOffset % 16 == 0); LUAU_ASSERT(prologueOffset == prologueSize); diff --git a/Common/include/Luau/Bytecode.h b/Common/include/Luau/Bytecode.h index 976dd04f..36b570ce 100644 --- a/Common/include/Luau/Bytecode.h +++ b/Common/include/Luau/Bytecode.h @@ -44,7 +44,7 @@ // Version 1: Baseline version for the open-source release. Supported until 0.521. // Version 2: Adds Proto::linedefined. Supported until 0.544. // Version 3: Adds FORGPREP/JUMPXEQK* and enhances AUX encoding for FORGLOOP. Removes FORGLOOP_NEXT/INEXT and JUMPIFEQK/JUMPIFNOTEQK. Currently supported. -// Version 4: Adds Proto::flags and typeinfo. Currently supported. +// Version 4: Adds Proto::flags, typeinfo, and floor division opcodes IDIV/IDIVK. Currently supported. // Bytecode opcode, part of the instruction header enum LuauOpcode @@ -390,6 +390,18 @@ enum LuauOpcode LOP_JUMPXEQKN, LOP_JUMPXEQKS, + // IDIV: compute floor division between two source registers and put the result into target register + // A: target register + // B: source register 1 + // C: source register 2 + LOP_IDIV, + + // IDIVK compute floor division between the source register and a constant and put the result into target register + // A: target register + // B: source register + // C: constant table index (0..255) + LOP_IDIVK, + // Enum entry for number of opcodes, not a valid opcode by itself! LOP__COUNT }; diff --git a/Common/include/Luau/BytecodeUtils.h b/Common/include/Luau/BytecodeUtils.h new file mode 100644 index 00000000..957c804c --- /dev/null +++ b/Common/include/Luau/BytecodeUtils.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Bytecode.h" + +namespace Luau +{ + +inline int getOpLength(LuauOpcode op) +{ + switch (op) + { + case LOP_GETGLOBAL: + case LOP_SETGLOBAL: + case LOP_GETIMPORT: + case LOP_GETTABLEKS: + case LOP_SETTABLEKS: + case LOP_NAMECALL: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_NEWTABLE: + case LOP_SETLIST: + case LOP_FORGLOOP: + case LOP_LOADKX: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + case LOP_JUMPXEQKNIL: + case LOP_JUMPXEQKB: + case LOP_JUMPXEQKN: + case LOP_JUMPXEQKS: + return 2; + + default: + return 1; + } +} + +} // namespace Luau diff --git a/Common/include/Luau/DenseHash.h b/Common/include/Luau/DenseHash.h index 997e090f..72aa6ec5 100644 --- a/Common/include/Luau/DenseHash.h +++ b/Common/include/Luau/DenseHash.h @@ -282,6 +282,13 @@ public: class const_iterator { public: + using value_type = Item; + using reference = Item&; + using pointer = Item*; + using iterator = pointer; + using difference_type = size_t; + using iterator_category = std::input_iterator_tag; + const_iterator() : set(0) , index(0) diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index fd074be1..94d13ea8 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -13,6 +13,7 @@ inline bool isFlagExperimental(const char* flag) static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code "LuauTinyControlFlowAnalysis", // waiting for updates to packages depended by internal builtin plugins + "LuauFixIndexerSubtypingOrdering", // requires some small fixes to lua-apps code since this fixes a false negative // makes sure we always have at least one entry nullptr, }; diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index 3044e448..f5098d17 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -15,7 +15,7 @@ class BytecodeEncoder public: virtual ~BytecodeEncoder() {} - virtual uint8_t encodeOp(uint8_t op) = 0; + virtual void encode(uint32_t* data, size_t count) = 0; }; class BytecodeBuilder diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 36a21a72..9a661db2 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -39,7 +39,7 @@ struct CompileOptions const char* vectorType = nullptr; // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these - const char** mutableGlobals = nullptr; + const char* const* mutableGlobals = nullptr; }; class CompileError : public std::exception diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h index 7c59ce0b..60cdeee7 100644 --- a/Compiler/include/luacode.h +++ b/Compiler/include/luacode.h @@ -35,7 +35,7 @@ struct lua_CompileOptions const char* vectorType; // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these - const char** mutableGlobals; + const char* const* mutableGlobals; }; // compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 4ec083bb..a15c8f08 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,9 +4,6 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinTonumber, false) -LUAU_FASTFLAGVARIABLE(LuauCompileBuiltinTostring, false) - namespace Luau { namespace Compile @@ -72,9 +69,9 @@ static int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& op if (builtin.isGlobal("setmetatable")) return LBF_SETMETATABLE; - if (FFlag::LuauCompileBuiltinTonumber && builtin.isGlobal("tonumber")) + if (builtin.isGlobal("tonumber")) return LBF_TONUMBER; - if (FFlag::LuauCompileBuiltinTostring && builtin.isGlobal("tostring")) + if (builtin.isGlobal("tostring")) return LBF_TOSTRING; if (builtin.object == "math") diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index eeb9c10e..4ebed69a 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BytecodeBuilder.h" +#include "Luau/BytecodeUtils.h" #include "Luau/StringUtils.h" #include @@ -8,6 +9,8 @@ LUAU_FASTFLAGVARIABLE(BytecodeVersion4, false) +LUAU_FASTFLAG(LuauFloorDivision) + namespace Luau { @@ -55,39 +58,6 @@ static void writeVarInt(std::string& ss, unsigned int value) } while (value); } -static int getOpLength(LuauOpcode op) -{ - switch (op) - { - case LOP_GETGLOBAL: - case LOP_SETGLOBAL: - case LOP_GETIMPORT: - case LOP_GETTABLEKS: - case LOP_SETTABLEKS: - case LOP_NAMECALL: - case LOP_JUMPIFEQ: - case LOP_JUMPIFLE: - case LOP_JUMPIFLT: - case LOP_JUMPIFNOTEQ: - case LOP_JUMPIFNOTLE: - case LOP_JUMPIFNOTLT: - case LOP_NEWTABLE: - case LOP_SETLIST: - case LOP_FORGLOOP: - case LOP_LOADKX: - case LOP_FASTCALL2: - case LOP_FASTCALL2K: - case LOP_JUMPXEQKNIL: - case LOP_JUMPXEQKB: - case LOP_JUMPXEQKN: - case LOP_JUMPXEQKS: - return 2; - - default: - return 1; - } -} - inline bool isJumpD(LuauOpcode op) { switch (op) @@ -262,17 +232,20 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues, uin validate(); #endif + // this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed + if (dumpFunctionPtr) + func.dump = (this->*dumpFunctionPtr)(func.dumpinstoffs); + // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants plus overhead func.data.reserve(32 + insns.size() * 7); + if (encoder) + encoder->encode(insns.data(), insns.size()); + writeFunction(func.data, currentFunction, flags); currentFunction = ~0u; - // this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed - if (dumpFunctionPtr) - func.dump = (this->*dumpFunctionPtr)(func.dumpinstoffs); - insns.clear(); lines.clear(); constants.clear(); @@ -653,21 +626,8 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id, uint8_t flags) // instructions writeVarInt(ss, uint32_t(insns.size())); - for (size_t i = 0; i < insns.size();) - { - uint8_t op = LUAU_INSN_OP(insns[i]); - LUAU_ASSERT(op < LOP__COUNT); - - int oplen = getOpLength(LuauOpcode(op)); - uint8_t openc = encoder ? encoder->encodeOp(op) : op; - - writeInt(ss, openc | (insns[i] & ~0xff)); - - for (int j = 1; j < oplen; ++j) - writeInt(ss, insns[i + j]); - - i += oplen; - } + for (uint32_t insn : insns) + writeInt(ss, insn); // constants writeVarInt(ss, uint32_t(constants.size())); @@ -1326,8 +1286,11 @@ void BytecodeBuilder::validateInstructions() const case LOP_SUB: case LOP_MUL: case LOP_DIV: + case LOP_IDIV: case LOP_MOD: case LOP_POW: + LUAU_ASSERT(FFlag::LuauFloorDivision || op != LOP_IDIV); + VREG(LUAU_INSN_A(insn)); VREG(LUAU_INSN_B(insn)); VREG(LUAU_INSN_C(insn)); @@ -1337,8 +1300,11 @@ void BytecodeBuilder::validateInstructions() const case LOP_SUBK: case LOP_MULK: case LOP_DIVK: + case LOP_IDIVK: case LOP_MODK: case LOP_POWK: + LUAU_ASSERT(FFlag::LuauFloorDivision || op != LOP_IDIVK); + VREG(LUAU_INSN_A(insn)); VREG(LUAU_INSN_B(insn)); VCONST(LUAU_INSN_C(insn), Number); @@ -1905,6 +1871,12 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, formatAppend(result, "DIV R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); break; + case LOP_IDIV: + LUAU_ASSERT(FFlag::LuauFloorDivision); + + formatAppend(result, "IDIV R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + case LOP_MOD: formatAppend(result, "MOD R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); break; @@ -1937,6 +1909,14 @@ void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, result.append("]\n"); break; + case LOP_IDIVK: + LUAU_ASSERT(FFlag::LuauFloorDivision); + + formatAppend(result, "IDIVK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + dumpConstant(result, LUAU_INSN_C(insn)); + result.append("]\n"); + break; + case LOP_MODK: formatAppend(result, "MODK R%d R%d K%d [", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); dumpConstant(result, LUAU_INSN_C(insn)); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6ae31825..4fdf659b 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -26,12 +26,7 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) -LUAU_FASTFLAGVARIABLE(LuauCompileFunctionType, false) -LUAU_FASTFLAGVARIABLE(LuauCompileNativeComment, false) - -LUAU_FASTFLAGVARIABLE(LuauCompileFixBuiltinArity, false) - -LUAU_FASTFLAGVARIABLE(LuauCompileFoldMathK, false) +LUAU_FASTFLAG(LuauFloorDivision) namespace Luau { @@ -209,12 +204,9 @@ struct Compiler setDebugLine(func); - if (FFlag::LuauCompileFunctionType) - { - // note: we move types out of typeMap which is safe because compileFunction is only called once per function - if (std::string* funcType = typeMap.find(func)) - bytecode.setFunctionTypeInfo(std::move(*funcType)); - } + // note: we move types out of typeMap which is safe because compileFunction is only called once per function + if (std::string* funcType = typeMap.find(func)) + bytecode.setFunctionTypeInfo(std::move(*funcType)); if (func->vararg) bytecode.emitABC(LOP_PREPVARARGS, uint8_t(self + func->args.size), 0, 0); @@ -799,15 +791,10 @@ struct Compiler return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); else if (options.optimizationLevel >= 2) { - if (FFlag::LuauCompileFixBuiltinArity) - { - // when a builtin is none-safe with matching arity, even if the last expression returns 0 or >1 arguments, - // we can rely on the behavior of the function being the same (none-safe means nil and none are interchangeable) - BuiltinInfo info = getBuiltinInfo(bfid); - if (int(expr->args.size) == info.params && (info.flags & BuiltinInfo::Flag_NoneSafe) != 0) - return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); - } - else if (int(expr->args.size) == getBuiltinInfo(bfid).params) + // when a builtin is none-safe with matching arity, even if the last expression returns 0 or >1 arguments, + // we can rely on the behavior of the function being the same (none-safe means nil and none are interchangeable) + BuiltinInfo info = getBuiltinInfo(bfid); + if (int(expr->args.size) == info.params && (info.flags & BuiltinInfo::Flag_NoneSafe) != 0) return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); } } @@ -1034,6 +1021,11 @@ struct Compiler case AstExprBinary::Div: return k ? LOP_DIVK : LOP_DIV; + case AstExprBinary::FloorDiv: + LUAU_ASSERT(FFlag::LuauFloorDivision); + + return k ? LOP_IDIVK : LOP_IDIV; + case AstExprBinary::Mod: return k ? LOP_MODK : LOP_MOD; @@ -1484,9 +1476,12 @@ struct Compiler case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: { + LUAU_ASSERT(FFlag::LuauFloorDivision || expr->op != AstExprBinary::FloorDiv); + int32_t rc = getConstantNumber(expr->right); if (rc >= 0 && rc <= 255) @@ -3207,9 +3202,12 @@ struct Compiler case AstExprBinary::Sub: case AstExprBinary::Mul: case AstExprBinary::Div: + case AstExprBinary::FloorDiv: case AstExprBinary::Mod: case AstExprBinary::Pow: { + LUAU_ASSERT(FFlag::LuauFloorDivision || stat->op != AstExprBinary::FloorDiv); + if (var.kind != LValue::Kind_Local) compileLValueUse(var, target, /* set= */ false); @@ -3620,9 +3618,8 @@ struct Compiler { node->body->visit(this); - if (FFlag::LuauCompileFunctionType) - for (AstLocal* arg : node->args) - hasTypes |= arg->annotation != nullptr; + for (AstLocal* arg : node->args) + hasTypes |= arg->annotation != nullptr; // this makes sure all functions that are used when compiling this one have been already added to the vector functions.push_back(node); @@ -3863,7 +3860,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c if (hc.header && hc.content.compare(0, 9, "optimize ") == 0) options.optimizationLevel = std::max(0, std::min(2, atoi(hc.content.c_str() + 9))); - if (FFlag::LuauCompileNativeComment && hc.header && hc.content == "native") + if (hc.header && hc.content == "native") { mainFlags |= LPF_NATIVE_MODULE; options.optimizationLevel = 2; // note: this might be removed in the future in favor of --!optimize @@ -3885,9 +3882,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c { compiler.builtinsFold = &compiler.builtins; - if (FFlag::LuauCompileFoldMathK) - if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) - compiler.builtinsFoldMathK = true; + if (AstName math = names.get("math"); math.value && getGlobalState(compiler.globals, math) == Global::Default) + compiler.builtinsFoldMathK = true; } if (options.optimizationLevel >= 1) @@ -3916,7 +3912,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c root->visit(&functionVisitor); // computes type information for all functions based on type annotations - if (FFlag::LuauCompileFunctionType && functionVisitor.hasTypes) + if (functionVisitor.hasTypes) buildTypeMap(compiler.typeMap, root, options.vectorType); for (AstExprFunction* expr : functions) diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index a49a7748..5b098a11 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -104,6 +104,14 @@ static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& l } break; + case AstExprBinary::FloorDiv: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = floor(la.valueNumber / ra.valueNumber); + } + break; + case AstExprBinary::Mod: if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) { diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 2f7af6ea..04adf3e3 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -6,8 +6,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauAssignmentHasCost, false) - namespace Luau { namespace Compile @@ -308,10 +306,7 @@ struct CostVisitor : AstVisitor { // unconditional 'else' may require a jump after the 'if' body // note: this ignores cases when 'then' always terminates and also assumes comparison requires an extra instruction which may be false - if (!FFlag::LuauAssignmentHasCost) - result += 2; - else - result += 1 + (node->elsebody && !node->elsebody->is()); + result += 1 + (node->elsebody && !node->elsebody->is()); return true; } @@ -337,9 +332,6 @@ struct CostVisitor : AstVisitor for (size_t i = 0; i < node->vars.size; ++i) assign(node->vars.data[i]); - if (!FFlag::LuauAssignmentHasCost) - return true; - for (size_t i = 0; i < node->vars.size || i < node->values.size; ++i) { Cost ac; diff --git a/Compiler/src/Types.cpp b/Compiler/src/Types.cpp index e85cc92c..8ac74d02 100644 --- a/Compiler/src/Types.cpp +++ b/Compiler/src/Types.cpp @@ -1,8 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/BytecodeBuilder.h" - #include "Types.h" +#include "Luau/BytecodeBuilder.h" + namespace Luau { diff --git a/Compiler/src/Types.h b/Compiler/src/Types.h index cad55ab5..62f9e916 100644 --- a/Compiler/src/Types.h +++ b/Compiler/src/Types.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/Ast.h" +#include "Luau/DenseHash.h" #include diff --git a/Compiler/src/ValueTracking.cpp b/Compiler/src/ValueTracking.cpp index 0bfaf9b3..616fca99 100644 --- a/Compiler/src/ValueTracking.cpp +++ b/Compiler/src/ValueTracking.cpp @@ -82,13 +82,13 @@ struct ValueVisitor : AstVisitor } }; -void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals) +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char* const* mutableGlobals) { if (AstName name = names.get("_G"); name.value) globals[name] = Global::Mutable; if (mutableGlobals) - for (const char** ptr = mutableGlobals; *ptr; ++ptr) + for (const char* const* ptr = mutableGlobals; *ptr; ++ptr) if (AstName name = names.get(*ptr); name.value) globals[name] = Global::Mutable; } diff --git a/Compiler/src/ValueTracking.h b/Compiler/src/ValueTracking.h index fc74c84a..f8ecc6b8 100644 --- a/Compiler/src/ValueTracking.h +++ b/Compiler/src/ValueTracking.h @@ -28,7 +28,7 @@ struct Variable bool constant = false; // is the variable's value a compile-time constant? filled by constantFold }; -void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals); +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char* const* mutableGlobals); void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root); inline Global getGlobalState(const DenseHashMap& globals, AstName name) diff --git a/Analysis/include/Luau/Config.h b/Config/include/Luau/Config.h similarity index 97% rename from Analysis/include/Luau/Config.h rename to Config/include/Luau/Config.h index 88c10554..ae865b64 100644 --- a/Analysis/include/Luau/Config.h +++ b/Config/include/Luau/Config.h @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Linter.h" +#include "Luau/LinterConfig.h" #include "Luau/ParseOptions.h" #include diff --git a/Config/include/Luau/LinterConfig.h b/Config/include/Luau/LinterConfig.h new file mode 100644 index 00000000..6b24db17 --- /dev/null +++ b/Config/include/Luau/LinterConfig.h @@ -0,0 +1,124 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" + +#include +#include +#include + +#include + +namespace Luau +{ + +struct HotComment; + +struct LintWarning +{ + // Make sure any new lint codes are documented here: https://luau-lang.org/lint + // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints + enum Code + { + Code_Unknown = 0, + + Code_UnknownGlobal = 1, // superseded by type checker + Code_DeprecatedGlobal = 2, + Code_GlobalUsedAsLocal = 3, + Code_LocalShadow = 4, // disabled in Studio + Code_SameLineStatement = 5, // disabled in Studio + Code_MultiLineStatement = 6, + Code_LocalUnused = 7, // disabled in Studio + Code_FunctionUnused = 8, // disabled in Studio + Code_ImportUnused = 9, // disabled in Studio + Code_BuiltinGlobalWrite = 10, + Code_PlaceholderRead = 11, + Code_UnreachableCode = 12, + Code_UnknownType = 13, + Code_ForRange = 14, + Code_UnbalancedAssignment = 15, + Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode + Code_DuplicateLocal = 17, + Code_FormatString = 18, + Code_TableLiteral = 19, + Code_UninitializedLocal = 20, + Code_DuplicateFunction = 21, + Code_DeprecatedApi = 22, + Code_TableOperations = 23, + Code_DuplicateCondition = 24, + Code_MisleadingAndOr = 25, + Code_CommentDirective = 26, + Code_IntegerParsing = 27, + Code_ComparisonPrecedence = 28, + + Code__Count + }; + + Code code; + Location location; + std::string text; + + static const char* getName(Code code); + static Code parseName(const char* name); + static uint64_t parseMask(const std::vector& hotcomments); +}; + +struct LintOptions +{ + uint64_t warningMask = 0; + + void enableWarning(LintWarning::Code code) + { + warningMask |= 1ull << code; + } + void disableWarning(LintWarning::Code code) + { + warningMask &= ~(1ull << code); + } + + bool isEnabled(LintWarning::Code code) const + { + return 0 != (warningMask & (1ull << code)); + } + + void setDefaults(); +}; + +// clang-format off +static const char* kWarningNames[] = { + "Unknown", + + "UnknownGlobal", + "DeprecatedGlobal", + "GlobalUsedAsLocal", + "LocalShadow", + "SameLineStatement", + "MultiLineStatement", + "LocalUnused", + "FunctionUnused", + "ImportUnused", + "BuiltinGlobalWrite", + "PlaceholderRead", + "UnreachableCode", + "UnknownType", + "ForRange", + "UnbalancedAssignment", + "ImplicitReturn", + "DuplicateLocal", + "FormatString", + "TableLiteral", + "UninitializedLocal", + "DuplicateFunction", + "DeprecatedApi", + "TableOperations", + "DuplicateCondition", + "MisleadingAndOr", + "CommentDirective", + "IntegerParsing", + "ComparisonPrecedence", +}; +// clang-format on + +static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?"); + +} // namespace Luau diff --git a/Analysis/src/Config.cpp b/Config/src/Config.cpp similarity index 95% rename from Analysis/src/Config.cpp rename to Config/src/Config.cpp index 9369743e..8e9802cf 100644 --- a/Analysis/src/Config.cpp +++ b/Config/src/Config.cpp @@ -4,6 +4,7 @@ #include "Luau/Lexer.h" #include "Luau/StringUtils.h" +LUAU_FASTFLAG(LuauFloorDivision) namespace Luau { @@ -112,14 +113,23 @@ static void next(Lexer& lexer) lexer.next(); // skip C-style comments as Lexer only understands Lua-style comments atm - while (lexer.current().type == '/') + + if (FFlag::LuauFloorDivision) { - Lexeme peek = lexer.lookahead(); + while (lexer.current().type == Luau::Lexeme::FloorDiv) + lexer.nextline(); + } + else + { + while (lexer.current().type == '/') + { + Lexeme peek = lexer.lookahead(); - if (peek.type != '/' || peek.location.begin != lexer.current().location.end) - break; + if (peek.type != '/' || peek.location.begin != lexer.current().location.end) + break; - lexer.nextline(); + lexer.nextline(); + } } } diff --git a/Config/src/LinterConfig.cpp b/Config/src/LinterConfig.cpp new file mode 100644 index 00000000..d63969d5 --- /dev/null +++ b/Config/src/LinterConfig.cpp @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/LinterConfig.h" + +#include "Luau/ParseResult.h" + +namespace Luau +{ + +void LintOptions::setDefaults() +{ + // By default, we enable all warnings + warningMask = ~0ull; +} + +const char* LintWarning::getName(Code code) +{ + LUAU_ASSERT(unsigned(code) < Code__Count); + + return kWarningNames[code]; +} + +LintWarning::Code LintWarning::parseName(const char* name) +{ + for (int code = Code_Unknown; code < Code__Count; ++code) + if (strcmp(name, getName(Code(code))) == 0) + return Code(code); + + return Code_Unknown; +} + +uint64_t LintWarning::parseMask(const std::vector& hotcomments) +{ + uint64_t result = 0; + + for (const HotComment& hc : hotcomments) + { + if (!hc.header) + continue; + + if (hc.content.compare(0, 6, "nolint") != 0) + continue; + + size_t name = hc.content.find_first_not_of(" \t", 6); + + // --!nolint disables everything + if (name == std::string::npos) + return ~0ull; + + // --!nolint needs to be followed by a whitespace character + if (name == 6) + continue; + + // --!nolint name disables the specific lint + LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); + + if (code != LintWarning::Code_Unknown) + result |= 1ull << int(code); + } + + return result; +} + +} // namespace Luau diff --git a/LICENSE.txt b/LICENSE.txt index 40fc04c1..34496ced 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019-2022 Roblox Corporation +Copyright (c) 2019-2023 Roblox Corporation Copyright (c) 1994–2019 Lua.org, PUC-Rio. Permission is hereby granted, free of charge, to any person obtaining a copy of @@ -19,4 +19,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +SOFTWARE. diff --git a/Makefile b/Makefile index 17bae919..2e5e9791 100644 --- a/Makefile +++ b/Makefile @@ -16,6 +16,10 @@ COMPILER_SOURCES=$(wildcard Compiler/src/*.cpp) COMPILER_OBJECTS=$(COMPILER_SOURCES:%=$(BUILD)/%.o) COMPILER_TARGET=$(BUILD)/libluaucompiler.a +CONFIG_SOURCES=$(wildcard Config/src/*.cpp) +CONFIG_OBJECTS=$(CONFIG_SOURCES:%=$(BUILD)/%.o) +CONFIG_TARGET=$(BUILD)/libluauconfig.a + ANALYSIS_SOURCES=$(wildcard Analysis/src/*.cpp) ANALYSIS_OBJECTS=$(ANALYSIS_SOURCES:%=$(BUILD)/%.o) ANALYSIS_TARGET=$(BUILD)/libluauanalysis.a @@ -59,7 +63,7 @@ ifneq ($(opt),) TESTS_ARGS+=-O$(opt) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(CONFIG_OBJECTS) $(ANALYSIS_OBJECTS) $(CODEGEN_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(REPL_CLI_OBJECTS) $(ANALYZE_CLI_OBJECTS) $(COMPILE_CLI_OBJECTS) $(FUZZ_OBJECTS) EXECUTABLE_ALIASES = luau luau-analyze luau-compile luau-tests # common flags @@ -117,27 +121,27 @@ ifeq ($(protobuf),download) endif ifneq ($(native),) - CXXFLAGS+=-DLUA_CUSTOM_EXECUTION=1 TESTS_ARGS+=--codegen endif ifneq ($(nativelj),) - CXXFLAGS+=-DLUA_CUSTOM_EXECUTION=1 -DLUA_USE_LONGJMP=1 + CXXFLAGS+=-DLUA_USE_LONGJMP=1 TESTS_ARGS+=--codegen endif # target-specific flags $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -ICommon/include -IAst/include -$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include +$(CONFIG_OBJECTS): CXXFLAGS+=-std=c++17 -IConfig/include -ICommon/include -IAst/include +$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include $(CODEGEN_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -ICodeGen/include -IVM/include -IVM/src # Code generation needs VM internals $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -ICommon/include -IVM/include $(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IConfig/include -IAnalysis/include -ICodeGen/include -IVM/include -ICLI -Iextern -DDOCTEST_CONFIG_DOUBLE_STRINGIFY $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -Iextern -Iextern/isocline/include -$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -Iextern +$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -IAnalysis/include -IConfig/include -Iextern $(COMPILE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IVM/include -ICodeGen/include -$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include +$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICodeGen/include -IConfig/include $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread @@ -205,30 +209,31 @@ luau-tests: $(TESTS_TARGET) ln -fs $^ $@ # executable targets -$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) +$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) -$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(COMPILE_CLI_TARGET): $(COMPILE_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(COMPILE_CLI_TARGET): $(CXX) $^ $(LDFLAGS) -o $@ # executable targets for fuzzing -fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) +fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(CXX) $^ $(LDFLAGS) -o $@ -fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(CONFIG_TARGET) $(VM_TARGET) | build/libprotobuf-mutator # static library targets $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) +$(CONFIG_TARGET): $(CONFIG_OBJECTS) $(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS) $(CODEGEN_TARGET): $(CODEGEN_OBJECTS) $(VM_TARGET): $(VM_OBJECTS) $(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS) -$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): +$(AST_TARGET) $(COMPILER_TARGET) $(CONFIG_TARGET) $(ANALYSIS_TARGET) $(CODEGEN_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): ar rcs $@ $^ # object file targets @@ -250,6 +255,7 @@ $(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator + git -C build/libprotobuf-mutator checkout 212a7be1eb08e7f9c79732d2aab9b2097085d936 CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator $(DPROTOBUF) make -C build/libprotobuf-mutator -j8 diff --git a/Sources.cmake b/Sources.cmake index 88f3bb01..ddd514f5 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -4,6 +4,7 @@ if(NOT ${CMAKE_VERSION} VERSION_LESS "3.19") target_sources(Luau.Common PRIVATE Common/include/Luau/Common.h Common/include/Luau/Bytecode.h + Common/include/Luau/BytecodeUtils.h Common/include/Luau/DenseHash.h Common/include/Luau/ExperimentalFlags.h ) @@ -55,6 +56,15 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/ValueTracking.h ) +# Luau.Config Sources +target_sources(Luau.Config PRIVATE + Config/include/Luau/Config.h + Config/include/Luau/LinterConfig.h + + Config/src/Config.cpp + Config/src/LinterConfig.cpp +) + # Luau.CodeGen Sources target_sources(Luau.CodeGen PRIVATE CodeGen/include/Luau/AddressA64.h @@ -145,7 +155,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/BuiltinDefinitions.h Analysis/include/Luau/Cancellation.h Analysis/include/Luau/Clone.h - Analysis/include/Luau/Config.h Analysis/include/Luau/Constraint.h Analysis/include/Luau/ConstraintGraphBuilder.h Analysis/include/Luau/ConstraintSolver.h @@ -158,6 +167,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Error.h Analysis/include/Luau/FileResolver.h Analysis/include/Luau/Frontend.h + Analysis/include/Luau/GlobalTypes.h Analysis/include/Luau/InsertionOrderedMap.h Analysis/include/Luau/Instantiation.h Analysis/include/Luau/IostreamHelpers.h @@ -176,6 +186,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Scope.h Analysis/include/Luau/Simplify.h Analysis/include/Luau/Substitution.h + Analysis/include/Luau/Subtyping.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/ToDot.h Analysis/include/Luau/TopoSortStatements.h @@ -206,7 +217,6 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Autocomplete.cpp Analysis/src/BuiltinDefinitions.cpp Analysis/src/Clone.cpp - Analysis/src/Config.cpp Analysis/src/Constraint.cpp Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintSolver.cpp @@ -217,6 +227,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/EmbeddedBuiltinDefinitions.cpp Analysis/src/Error.cpp Analysis/src/Frontend.cpp + Analysis/src/GlobalTypes.cpp Analysis/src/Instantiation.cpp Analysis/src/IostreamHelpers.cpp Analysis/src/JsonEmitter.cpp @@ -230,6 +241,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Scope.cpp Analysis/src/Simplify.cpp Analysis/src/Substitution.cpp + Analysis/src/Subtyping.cpp Analysis/src/Symbol.cpp Analysis/src/ToDot.cpp Analysis/src/TopoSortStatements.cpp @@ -348,33 +360,34 @@ endif() if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE - tests/AstQueryDsl.cpp - tests/AstQueryDsl.h - tests/ClassFixture.cpp - tests/ClassFixture.h - tests/ConstraintGraphBuilderFixture.cpp - tests/ConstraintGraphBuilderFixture.h - tests/Fixture.cpp - tests/Fixture.h - tests/IostreamOptional.h - tests/ScopedFlags.h tests/AssemblyBuilderA64.test.cpp tests/AssemblyBuilderX64.test.cpp tests/AstJsonEncoder.test.cpp tests/AstQuery.test.cpp + tests/AstQueryDsl.cpp + tests/AstQueryDsl.h tests/AstVisitor.test.cpp + tests/RegisterCallbacks.h + tests/RegisterCallbacks.cpp tests/Autocomplete.test.cpp tests/BuiltinDefinitions.test.cpp + tests/ClassFixture.cpp + tests/ClassFixture.h tests/CodeAllocator.test.cpp tests/Compiler.test.cpp tests/Config.test.cpp + tests/ConstraintGraphBuilderFixture.cpp + tests/ConstraintGraphBuilderFixture.h tests/ConstraintSolver.test.cpp tests/CostModel.test.cpp tests/DataFlowGraph.test.cpp tests/DenseHash.test.cpp - tests/Differ.test.cpp + tests/Differ.test.cpp tests/Error.test.cpp + tests/Fixture.cpp + tests/Fixture.h tests/Frontend.test.cpp + tests/IostreamOptional.h tests/IrBuilder.test.cpp tests/IrCallWrapperX64.test.cpp tests/IrRegAllocX64.test.cpp @@ -389,8 +402,10 @@ if(TARGET Luau.UnitTest) tests/Parser.test.cpp tests/RequireTracer.test.cpp tests/RuntimeLimits.test.cpp + tests/ScopedFlags.h tests/Simplify.test.cpp tests/StringUtils.test.cpp + tests/Subtyping.test.cpp tests/Symbol.test.cpp tests/ToDot.test.cpp tests/TopoSort.test.cpp @@ -436,6 +451,8 @@ endif() if(TARGET Luau.Conformance) # Luau.Conformance Sources target_sources(Luau.Conformance PRIVATE + tests/RegisterCallbacks.h + tests/RegisterCallbacks.cpp tests/Conformance.test.cpp tests/main.cpp) endif() @@ -453,6 +470,8 @@ if(TARGET Luau.CLI.Test) CLI/Profiler.cpp CLI/Repl.cpp + tests/RegisterCallbacks.h + tests/RegisterCallbacks.cpp tests/Repl.test.cpp tests/main.cpp) endif() diff --git a/VM/include/lua.h b/VM/include/lua.h index e49e6ad9..43c60d77 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -443,7 +443,7 @@ typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); /****************************************************************************** - * Copyright (c) 2019-2022 Roblox Corporation + * Copyright (c) 2019-2023 Roblox Corporation * Copyright (C) 1994-2008 Lua.org, PUC-Rio. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index dcb785b6..7a1bbb95 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -121,11 +121,6 @@ #define LUA_MAXCAPTURES 32 #endif -// enables callbacks to redirect code execution from Luau VM to a custom implementation -#ifndef LUA_CUSTOM_EXECUTION -#define LUA_CUSTOM_EXECUTION 1 -#endif - // }================================================================== /* diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 054faa7c..2b98e47d 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -38,7 +38,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; -const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" +const char* luau_ident = "$Luau: Copyright (C) 2019-2023 Roblox Corporation $\n" "$URL: luau-lang.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 951b3028..63da1810 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,8 +11,6 @@ #include -LUAU_FASTFLAG(LuauFasterInterp) - // convert a stack index to positive #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -524,19 +522,10 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { if (luaL_callmeta(L, idx, "__tostring")) // is there a metafield? { - if (FFlag::LuauFasterInterp) - { - const char* s = lua_tolstring(L, -1, len); - if (!s) - luaL_error(L, "'__tostring' must return a string"); - return s; - } - else - { - if (!lua_isstring(L, -1)) - luaL_error(L, "'__tostring' must return a string"); - return lua_tolstring(L, -1, len); - } + const char* s = lua_tolstring(L, -1, len); + if (!s) + luaL_error(L, "'__tostring' must return a string"); + return s; } switch (lua_type(L, idx)) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index c893d603..a916f73a 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -23,8 +23,6 @@ #endif #endif -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauFastcallGC, false) - // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -832,7 +830,7 @@ static int luauF_char(lua_State* L, StkId res, TValue* arg0, int nresults, StkId if (nparams < int(sizeof(buffer)) && nresults <= 1) { - if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + if (luaC_needsGC(L)) return -1; // we can't call luaC_checkGC so fall back to C implementation if (nparams >= 1) @@ -904,7 +902,7 @@ static int luauF_sub(lua_State* L, StkId res, TValue* arg0, int nresults, StkId int i = int(nvalue(args)); int j = int(nvalue(args + 1)); - if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + if (luaC_needsGC(L)) return -1; // we can't call luaC_checkGC so fall back to C implementation if (i >= 1 && j >= i && unsigned(j - 1) < unsigned(ts->len)) @@ -1300,7 +1298,7 @@ static int luauF_tostring(lua_State* L, StkId res, TValue* arg0, int nresults, S } case LUA_TNUMBER: { - if (DFFlag::LuauFastcallGC && luaC_needsGC(L)) + if (luaC_needsGC(L)) return -1; // we can't call luaC_checkGC so fall back to C implementation char s[LUAI_MAXNUM2STR]; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index e5fde4d4..69d3c128 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauPCallDebuggerFix, false) + /* ** {====================================================== ** Error-recovery functions @@ -576,11 +578,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e if (!oldactive) L->isactive = false; + bool yieldable = L->nCcalls <= L->baseCcalls; // Inlined logic from 'lua_isyieldable' to avoid potential for an out of line call. + // restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback - if (L->global->cb.debugprotectederror) + if ((!FFlag::LuauPCallDebuggerFix || yieldable) && L->global->cb.debugprotectederror) { L->global->cb.debugprotectederror(L); diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index fe7b1a12..5a817f25 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -7,8 +7,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauFasterNoise, false) - #undef PI #define PI (3.14159265358979323846) #define RADIANS_PER_DEGREE (PI / 180.0) @@ -277,105 +275,37 @@ static int math_randomseed(lua_State* L) return 0; } -// TODO: Delete with LuauFasterNoise -static const unsigned char kPerlin[512] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, - 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, - 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, - 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, - 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, - 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, - 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, - 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, - 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, - - 151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, 37, 240, 21, 10, 23, 190, 6, 148, 247, - 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, 20, 125, 136, 171, 168, 68, 175, 74, - 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, 55, 46, 245, 40, 244, 102, 143, 54, 65, - 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, 164, 100, 109, 198, 173, 186, 3, 64, 52, - 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, 182, 189, 28, 42, 223, 183, 170, 213, - 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, 79, 113, 224, 232, 178, 185, 112, - 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, 239, 107, 49, 192, 214, 31, 181, 199, - 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, 72, 243, 141, 128, 195, 78, 66, 215, 61, - 156, 180}; - -static const unsigned char kPerlinHash[257] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, - 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, - 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, - 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, - 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, - 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, - 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, - 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, - 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, 151}; +static const unsigned char kPerlinHash[257] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, + 99, 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, + 174, 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, + 41, 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, + 86, 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, + 17, 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, + 110, 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, + 14, 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, + 24, 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, 151}; const float kPerlinGrad[16][3] = {{1, 1, 0}, {-1, 1, 0}, {1, -1, 0}, {-1, -1, 0}, {1, 0, 1}, {-1, 0, 1}, {1, 0, -1}, {-1, 0, -1}, {0, 1, 1}, {0, -1, 1}, {0, 1, -1}, {0, -1, -1}, {1, 1, 0}, {0, -1, 1}, {-1, 1, 0}, {0, -1, -1}}; -static float perlin_fade(float t) +inline float perlin_fade(float t) { return t * t * t * (t * (t * 6 - 15) + 10); } -static float perlin_lerp(float t, float a, float b) +inline float perlin_lerp(float t, float a, float b) { return a + t * (b - a); } -static float grad(unsigned char hash, float x, float y, float z) -{ - LUAU_ASSERT(!FFlag::LuauFasterNoise); - unsigned char h = hash & 15; - float u = (h < 8) ? x : y; - float v = (h < 4) ? y : (h == 12 || h == 14) ? x : z; - - return (h & 1 ? -u : u) + (h & 2 ? -v : v); -} - -static float perlin_grad(int hash, float x, float y, float z) +inline float perlin_grad(int hash, float x, float y, float z) { const float* g = kPerlinGrad[hash & 15]; return g[0] * x + g[1] * y + g[2] * z; } -static float perlin_dep(float x, float y, float z) -{ - LUAU_ASSERT(!FFlag::LuauFasterNoise); - float xflr = floorf(x); - float yflr = floorf(y); - float zflr = floorf(z); - - int xi = int(xflr) & 255; - int yi = int(yflr) & 255; - int zi = int(zflr) & 255; - - float xf = x - xflr; - float yf = y - yflr; - float zf = z - zflr; - - float u = perlin_fade(xf); - float v = perlin_fade(yf); - float w = perlin_fade(zf); - - const unsigned char* p = kPerlin; - - int a = p[xi] + yi; - int aa = p[a] + zi; - int ab = p[a + 1] + zi; - - int b = p[xi + 1] + yi; - int ba = p[b] + zi; - int bb = p[b + 1] + zi; - - return perlin_lerp(w, - perlin_lerp(v, perlin_lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), - perlin_lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), - perlin_lerp(v, perlin_lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), - perlin_lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); -} - static float perlin(float x, float y, float z) { - LUAU_ASSERT(FFlag::LuauFasterNoise); float xflr = floorf(x); float yflr = floorf(y); float zflr = floorf(z); @@ -412,33 +342,19 @@ static float perlin(float x, float y, float z) static int math_noise(lua_State* L) { - if (FFlag::LuauFasterNoise) - { - int nx, ny, nz; - double x = lua_tonumberx(L, 1, &nx); - double y = lua_tonumberx(L, 2, &ny); - double z = lua_tonumberx(L, 3, &nz); + int nx, ny, nz; + double x = lua_tonumberx(L, 1, &nx); + double y = lua_tonumberx(L, 2, &ny); + double z = lua_tonumberx(L, 3, &nz); - luaL_argexpected(L, nx, 1, "number"); - luaL_argexpected(L, ny || lua_isnoneornil(L, 2), 2, "number"); - luaL_argexpected(L, nz || lua_isnoneornil(L, 3), 3, "number"); + luaL_argexpected(L, nx, 1, "number"); + luaL_argexpected(L, ny || lua_isnoneornil(L, 2), 2, "number"); + luaL_argexpected(L, nz || lua_isnoneornil(L, 3), 3, "number"); - double r = perlin((float)x, (float)y, (float)z); + double r = perlin((float)x, (float)y, (float)z); - lua_pushnumber(L, r); - return 1; - } - else - { - double x = luaL_checknumber(L, 1); - double y = luaL_optnumber(L, 2, 0.0); - double z = luaL_optnumber(L, 3, 0.0); - - double r = perlin_dep((float)x, (float)y, (float)z); - - lua_pushnumber(L, r); - return 1; - } + lua_pushnumber(L, r); + return 1; } static int math_clamp(lua_State* L) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 3cbdafff..25ca999d 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -98,6 +98,8 @@ */ #if defined(__APPLE__) #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : gcc32) +#elif defined(__i386__) && defined(__MINGW32__) && !defined(__MINGW64__) +#define ABISWITCH(x64, ms32, gcc32) (ms32) #elif defined(__i386__) && !defined(_MSC_VER) #define ABISWITCH(x64, ms32, gcc32) (gcc32) #else diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 5b27e2b8..38bfb322 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -40,6 +40,13 @@ inline double luai_nummod(double a, double b) } LUAU_FASTMATH_END +LUAU_FASTMATH_BEGIN +inline double luai_numidiv(double a, double b) +{ + return floor(a / b); +} +LUAU_FASTMATH_END + #define luai_num2int(i, d) ((i) = (int)(d)) // On MSVC in 32-bit, double to unsigned cast compiles into a call to __dtoui3, so we invoke x87->int64 conversion path manually diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index d9ce71f9..ca57786e 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauFasterInterp, false) - // macro to `unsign' a character #define uchar(c) ((unsigned char)(c)) @@ -968,7 +966,7 @@ static int str_format(lua_State* L) luaL_addchar(&b, *strfrmt++); else if (*++strfrmt == L_ESC) luaL_addchar(&b, *strfrmt++); // %% - else if (FFlag::LuauFasterInterp && *strfrmt == '*') + else if (*strfrmt == '*') { strfrmt++; if (++arg > top) @@ -1028,12 +1026,10 @@ static int str_format(lua_State* L) { size_t l; const char* s = luaL_checklstring(L, arg, &l); - if (!strchr(form, '.') && l >= 100) + // no precision and string is too long to be formatted, or no format necessary to begin with + if (form[2] == '\0' || (!strchr(form, '.') && l >= 100)) { - /* no precision and string is too long to be formatted; - keep original string */ - lua_pushvalue(L, arg); - luaL_addvalue(&b); + luaL_addlstring(&b, s, l, -1); continue; // skip the `luaL_addlstring' at the end } else @@ -1044,16 +1040,8 @@ static int str_format(lua_State* L) } case '*': { - if (FFlag::LuauFasterInterp || formatItemSize != 1) - luaL_error(L, "'%%*' does not take a form"); - - size_t length; - const char* string = luaL_tolstring(L, arg, &length); - - luaL_addlstring(&b, string, length, -2); - lua_pop(L, 1); - - continue; // skip the `luaL_addlstring' at the end + // %* is parsed above, so if we got here we must have %...* + luaL_error(L, "'%%*' does not take a form"); } default: { // also treat cases `pnLlh' diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index fbf03deb..cc44ee19 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,8 +10,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauFasterTableConcat, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -222,7 +220,7 @@ static int tmove(lua_State* L) static void addfield(lua_State* L, luaL_Buffer* b, int i) { int tt = lua_rawgeti(L, 1, i); - if (FFlag::LuauFasterTableConcat ? (tt != LUA_TSTRING && tt != LUA_TNUMBER) : !lua_isstring(L, -1)) + if (tt != LUA_TSTRING && tt != LUA_TNUMBER) luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); luaL_addvalue(b); } diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index d753e8a4..d8c69a70 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -48,6 +48,7 @@ const char* const luaT_eventname[] = { "__sub", "__mul", "__div", + "__idiv", "__mod", "__pow", "__unm", diff --git a/VM/src/ltm.h b/VM/src/ltm.h index 4b1c2818..7dafd4ed 100644 --- a/VM/src/ltm.h +++ b/VM/src/ltm.h @@ -27,6 +27,7 @@ typedef enum TM_SUB, TM_MUL, TM_DIV, + TM_IDIV, TM_MOD, TM_POW, TM_UNM, diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 2909d477..086ed649 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -103,7 +103,8 @@ VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_DEP_JUMPIFEQK), VM_DISPATCH_OP(LOP_DEP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \ - VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), + VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), VM_DISPATCH_OP(LOP_IDIV), \ + VM_DISPATCH_OP(LOP_IDIVK), #if defined(__GNUC__) || defined(__clang__) #define VM_USE_CGOTO 1 @@ -132,7 +133,9 @@ #endif // Does VM support native execution via ExecutionCallbacks? We mostly assume it does but keep the define to make it easy to quantify the cost. -#define VM_HAS_NATIVE LUA_CUSTOM_EXECUTION +#define VM_HAS_NATIVE 1 + +void (*lua_iter_call_telemetry)(lua_State* L, int gtt, int stt, int itt) = NULL; LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) { @@ -1660,6 +1663,54 @@ reentry: } } + VM_CASE(LOP_IDIV) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (LUAU_LIKELY(ttisnumber(rb) && ttisnumber(rc))) + { + setnvalue(ra, luai_numidiv(nvalue(rb), nvalue(rc))); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisnumber(rc)) + { + const float* vb = vvalue(rb); + float vc = cast_to(float, nvalue(rc)); + setvvalue(ra, float(luai_numidiv(vb[0], vc)), float(luai_numidiv(vb[1], vc)), float(luai_numidiv(vb[2], vc)), + float(luai_numidiv(vb[3], vc))); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + StkId rbc = ttisnumber(rb) ? rc : rb; + const TValue* fn = 0; + if (ttisuserdata(rbc) && (fn = luaT_gettmbyobj(L, rbc, TM_IDIV)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_IDIV)); + VM_NEXT(); + } + } + } + VM_CASE(LOP_MOD) { Instruction insn = *pc++; @@ -1838,6 +1889,53 @@ reentry: } } + VM_CASE(LOP_IDIVK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (LUAU_LIKELY(ttisnumber(rb))) + { + setnvalue(ra, luai_numidiv(nvalue(rb), nvalue(kv))); + VM_NEXT(); + } + else if (ttisvector(rb)) + { + const float* vb = vvalue(rb); + float vc = cast_to(float, nvalue(kv)); + setvvalue(ra, float(luai_numidiv(vb[0], vc)), float(luai_numidiv(vb[1], vc)), float(luai_numidiv(vb[2], vc)), + float(luai_numidiv(vb[3], vc))); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_IDIV)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + VM_PROTECT(luaV_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_IDIV)); + VM_NEXT(); + } + } + } + VM_CASE(LOP_MODK) { Instruction insn = *pc++; @@ -2193,6 +2291,10 @@ reentry: { // table or userdata with __call, will be called during FORGLOOP // TODO: we might be able to stop supporting this depending on whether it's used in practice + void (*telemetrycb)(lua_State * L, int gtt, int stt, int itt) = lua_iter_call_telemetry; + + if (telemetrycb) + telemetrycb(L, ttype(ra), ttype(ra + 1), ttype(ra + 2)); } else if (ttistable(ra)) { @@ -2380,7 +2482,7 @@ reentry: else goto exit; #else - LUAU_ASSERT(!"Opcode is only valid when LUA_CUSTOM_EXECUTION is defined"); + LUAU_ASSERT(!"Opcode is only valid when VM_HAS_NATIVE is defined"); LUAU_UNREACHABLE(); #endif } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 7a065383..365aa5d3 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLoadCheckGC, false) - // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -181,8 +179,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size } // we will allocate a fair amount of memory so check GC before we do - if (FFlag::LuauLoadCheckGC) - luaC_checkGC(L); + luaC_checkGC(L); // pause GC for the duration of deserialization - some objects we're creating aren't rooted // TODO: if an allocation error happens mid-load, we do not unpause GC! diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index b77207da..df30db2f 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -394,6 +394,9 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM case TM_DIV: setnvalue(ra, luai_numdiv(nb, nc)); break; + case TM_IDIV: + setnvalue(ra, luai_numidiv(nb, nc)); + break; case TM_MOD: setnvalue(ra, luai_nummod(nb, nc)); break; @@ -410,7 +413,12 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM } else { - // vector operations that we support: v + v, v - v, v * v, s * v, v * s, v / v, s / v, v / s, -v + // vector operations that we support: + // v+v v-v -v (add/sub/neg) + // v*v s*v v*s (mul) + // v/v s/v v/s (div) + // v//v s//v v//s (floor div) + const float* vb = luaV_tovector(rb); const float* vc = luaV_tovector(rc); @@ -430,6 +438,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM case TM_DIV: setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(vb[0], vc[0])), float(luai_numidiv(vb[1], vc[1])), float(luai_numidiv(vb[2], vc[2])), + float(luai_numidiv(vb[3], vc[3]))); + return; case TM_UNM: setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); return; @@ -453,6 +465,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM case TM_DIV: setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(vb[0], nc)), float(luai_numidiv(vb[1], nc)), float(luai_numidiv(vb[2], nc)), + float(luai_numidiv(vb[3], nc))); + return; default: break; } @@ -474,6 +490,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM case TM_DIV: setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); return; + case TM_IDIV: + setvvalue(ra, float(luai_numidiv(nb, vc[0])), float(luai_numidiv(nb, vc[1])), float(luai_numidiv(nb, vc[2])), + float(luai_numidiv(nb, vc[3]))); + return; default: break; } diff --git a/bench/micro_tests/test_StringInterp.lua b/bench/micro_tests/test_StringInterp.lua index 33d5ecea..1e7ccbc7 100644 --- a/bench/micro_tests/test_StringInterp.lua +++ b/bench/micro_tests/test_StringInterp.lua @@ -36,6 +36,13 @@ bench.runCode(function() end end, "interp: interp number") +bench.runCode(function() + local ok = "hello!" + for j=1,1e6 do + local _ = string.format("j=%s", ok) + end +end, "interp: %s format") + bench.runCode(function() local ok = "hello!" for j=1,1e6 do diff --git a/bench/micro_tests/test_TableSort.lua b/bench/micro_tests/test_TableSort.lua new file mode 100644 index 00000000..80031d1c --- /dev/null +++ b/bench/micro_tests/test_TableSort.lua @@ -0,0 +1,22 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local arr_months = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"} + +local arr_num = {} +for i=1,100 do table.insert(arr_num, math.sin(i)) end + +local arr_numk = {} +for i=1,10000 do table.insert(arr_numk, math.sin(i)) end + +function test(arr) + local t = table.create(#arr) + + for i=1,1e6/#arr do + table.move(arr, 1, #arr, 1, t) + table.sort(t) + end +end + +bench.runCode(function() test(arr_months) end, "table.sort: 12 strings") +bench.runCode(function() test(arr_num) end, "table.sort: 100 numbers") +bench.runCode(function() test(arr_numk) end, "table.sort: 10k numbers") diff --git a/bench/tabulate.py b/bench/tabulate.py index fc034171..0b77209f 100644 --- a/bench/tabulate.py +++ b/bench/tabulate.py @@ -45,7 +45,7 @@ class TablePrinter(object): def _print_horizontal_separator(self): for i, align_width in enumerate(self._widths): if i > 0: - print('-+-', end='') + print('-|-', end='') print('-' * (align_width+1), end='') print() pass diff --git a/docs/_pages/compatibility.md b/docs/_pages/compatibility.md index b57a9af6..94f8809b 100644 --- a/docs/_pages/compatibility.md +++ b/docs/_pages/compatibility.md @@ -87,7 +87,7 @@ Ephemeron tables may be implemented at some point since they do have valid uses | bitwise operators | ❌ | `bit32` library covers this in absence of 64-bit integers | | basic utf-8 support | ✔️ | we include `utf8` library and other UTF8 features | | functions for packing and unpacking values (string.pack/unpack/packsize) | ✔️ | | -| floor division | 🔜 | | +| floor division | ✔️ | | | `ipairs` and the `table` library respect metamethods | ❌ | no strong use cases, performance implications | | new function `table.move` | ✔️ | | | `collectgarbage("count")` now returns only one result | ✔️ | | diff --git a/docs/_pages/performance.md b/docs/_pages/performance.md index d4dc0e8e..c2b4ced3 100644 --- a/docs/_pages/performance.md +++ b/docs/_pages/performance.md @@ -88,7 +88,7 @@ For this mechanism to work, function call must be "obvious" to the compiler - it The mechanism works by directly invoking a highly specialized and optimized implementation of a builtin function from the interpreter core loop without setting up a stack frame and omitting other work; additionally, some fastcall specializations are partial in that they don't support all types of arguments, for example all `math` library builtins are only specialized for numeric arguments, so calling `math.abs` with a string argument will fall back to the slower implementation that will do string->number coercion. -As a result, builtin calls are very fast in Luau - they are still slightly slower than core instructions such as arithmetic operations, but only slightly so. The set of fastcall builtins is slowly expanding over time and as of this writing contains `assert`, `type`, `typeof`, `rawget`/`rawset`/`rawequal`, `getmetatable`/`setmetatable`, all functions from `math` and `bit32`, and some functions from `string` and `table` library. +As a result, builtin calls are very fast in Luau - they are still slightly slower than core instructions such as arithmetic operations, but only slightly so. The set of fastcall builtins is slowly expanding over time and as of this writing contains `assert`, `type`, `typeof`, `rawget`/`rawset`/`rawequal`, `getmetatable`/`setmetatable`, `tonumber`/`tostring`, all functions from `math` (except `noise` and `random`/`randomseed`) and `bit32`, and some functions from `string` and `table` library. Some builtin functions have partial specializations that reduce the cost of the common case further. Notably: @@ -98,7 +98,7 @@ Some builtin functions have partial specializations that reduce the cost of the Some functions from `math` library like `math.floor` can additionally take advantage of advanced SIMD instruction sets like SSE4.1 when available. -In addition to runtime optimizations for builtin calls, many builtin calls can also be constant-folded by the bytecode compiler when using aggressive optimizations (level 2); this currently applies to most builtin calls with constant arguments and a single return value. For builtin calls that can not be constant folded, compiler assumes knowledge of argument/return count (level 2) to produce more efficient bytecode instructions. +In addition to runtime optimizations for builtin calls, many builtin calls, as well as constants like `math.pi`/`math.huge`, can also be constant-folded by the bytecode compiler when using aggressive optimizations (level 2); this currently applies to most builtin calls with constant arguments and a single return value. For builtin calls that can not be constant folded, compiler assumes knowledge of argument/return count (level 2) to produce more efficient bytecode instructions. ## Optimized table iteration diff --git a/docs/_pages/syntax.md b/docs/_pages/syntax.md index 595a4804..42fbf897 100644 --- a/docs/_pages/syntax.md +++ b/docs/_pages/syntax.md @@ -273,3 +273,9 @@ This restriction is made to prevent developers using other programming languages Luau currently does not support backtick string literals as a type annotation, so `` type Foo = `Foo` `` is invalid. Function calls with a backtick string literal without parenthesis is not supported, so `` print`hello` `` is invalid. + +## Floor division (`//`) + +Luau implements support for floor division operator (`//`) for numbers as well as support for `__idiv` metamethod. The syntax and semantics follow [Lua 5.3](https://www.lua.org/manual/5.3/manual.html#3.4.1). + +For numbers, `a // b` is equal to `math.floor(a / b)`; when `b` is 0, `a // b` results in infinity or NaN as appropriate. diff --git a/docs/assets/js/luau_mode.js b/docs/assets/js/luau_mode.js index a5dcdc97..3bd86505 100644 --- a/docs/assets/js/luau_mode.js +++ b/docs/assets/js/luau_mode.js @@ -17,10 +17,10 @@ var indentUnit = 4; function prefixRE(words) { - return new RegExp("^(?:" + words.join("|") + ")", "i"); + return new RegExp("^(?:" + words.join("|") + ")"); } function wordRE(words) { - return new RegExp("^(?:" + words.join("|") + ")$", "i"); + return new RegExp("^(?:" + words.join("|") + ")$"); } var specials = wordRE(parserConfig.specials || ["type"]); diff --git a/fuzz/luau.proto b/fuzz/luau.proto index e51d687b..5fed9ddc 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -135,17 +135,18 @@ message ExprBinary { Sub = 1; Mul = 2; Div = 3; - Mod = 4; - Pow = 5; - Concat = 6; - CompareNe = 7; - CompareEq = 8; - CompareLt = 9; - CompareLe = 10; - CompareGt = 11; - CompareGe = 12; - And = 13; - Or = 14; + FloorDiv = 4; + Mod = 5; + Pow = 6; + Concat = 7; + CompareNe = 8; + CompareEq = 9; + CompareLt = 10; + CompareLe = 11; + CompareGt = 12; + CompareGe = 13; + And = 14; + Or = 15; } required Op op = 1; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index e1ecaf65..1ad7852d 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -7,6 +7,7 @@ #include "Luau/CodeGen.h" #include "Luau/Common.h" #include "Luau/Compiler.h" +#include "Luau/Config.h" #include "Luau/Frontend.h" #include "Luau/Linter.h" #include "Luau/ModuleResolver.h" diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 5c7c5bf6..4eab2893 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -495,6 +495,8 @@ struct ProtoToLuau source += " * "; else if (expr.op() == luau::ExprBinary::Div) source += " / "; + else if (expr.op() == luau::ExprBinary::FloorDiv) + source += " // "; else if (expr.op() == luau::ExprBinary::Mod) source += " % "; else if (expr.op() == luau::ExprBinary::Pow) diff --git a/papers/hatra23/bibliography.bib b/papers/hatra23/bibliography.bib index 067bf967..3fbae85e 100644 --- a/papers/hatra23/bibliography.bib +++ b/papers/hatra23/bibliography.bib @@ -1,6 +1,6 @@ @InProceedings{BFJ21:GoalsLuau, author = {L. Brown and A. Friesen and A. S. A. Jeffrey}, - title = {Goals of the Luau Type System}, + title = {Position Paper: Goals of the Luau Type System}, booktitle = {Proc. Human Aspects of Types and Reasoning Assistants}, year = {2021}, url = {https://asaj.org/papers/hatra21.pdf}, diff --git a/papers/hatra23/hatra23.pdf b/papers/hatra23/hatra23.pdf index 60201e65..b750ecae 100644 Binary files a/papers/hatra23/hatra23.pdf and b/papers/hatra23/hatra23.pdf differ diff --git a/papers/hatra23/hatra23.tex b/papers/hatra23/hatra23.tex index a5cf4359..1d73ee53 100644 --- a/papers/hatra23/hatra23.tex +++ b/papers/hatra23/hatra23.tex @@ -12,6 +12,7 @@ \newcommand{\ANY}{\mathtt{any}} \newcommand{\ERROR}{\mathtt{error}} +\newcommand{\NUMBER}{\mathtt{number}} \begin{document} @@ -48,7 +49,7 @@ error reporting). Roblox has hundreds of millions of users, and millions of creators, ranging from children learning to program for the first time to professional development studios. -In HATRA 2021, we presented \emph{The Goals Of The Luau Type +In HATRA 2021, we presented a position paper on the \emph{Goals Of The Luau Type System}~\cite{BFJ21:GoalsLuau}, describing the human factors issues with designing a type system for a language with a heterogeneous developer community. The design flows from the needs of the different @@ -58,9 +59,9 @@ quality concerns of more experienced developers; for all users type-driven tooling is important for productivity. These needs result in a design with two modes: \begin{itemize} \item \emph{non-strict mode}, aimed at non-professionals, focused on - minimizing false positives, and + minimizing false positives (that is, in non-strict mode, any program with a type error has a defect), and \item \emph{strict mode}, aimed at professionals, focused on - minimizing false negatives (i.e.~type soundness). + minimizing false negatives (that is, in strict mode, any program with a defect has a type error). \end{itemize} %% For both communities, type-driven tooling is important, so we %% provide \emph{infallible type inference}, which infers types @@ -81,21 +82,58 @@ inclusion~\cite{GF05:GentleIntroduction}. This is aligned with the \emph{minimize false positives} goal of Luau non-strict mode, since semantic subtyping only reports a failure of subtyping when there is a value which inhabits the candidate subtype, but not the candidate -supertype. This has the added benefit of improving error reporting in -the prototype implementation: since the prototype is constructive, the -witness for the failure of subtyping can help drive error reports. +supertype. +For example, the program: +\begin{verbatim} + local x : CFrame = CFrame.new() + local y : Vector3 | CFrame + if math.random() < 0.5 then y = CFrame.new() else y = Vector3.new() end + local z : Vector3 | CFrame = x * y +\end{verbatim} +cannot produce a run-time error, since multiplication of \verb|CFrame|s is overloaded: +\begin{verbatim} + ((CFrame, CFrame) -> CFrame) & ((CFrame, Vector3) -> Vector3) +\end{verbatim} +In order to typecheck this program, we check that that type is a subtype of: +\begin{verbatim} + (CFrame, Vector3 | CFrame) -> (Vector3 | CFrame) +\end{verbatim} +In the previous, syntax-driven, implementation of subtyping, this subtype check would fail, resulting in a false positive. +We have now released an implementation of semantic subtyping, which does not suffer from this defect. See our technical blog for more details~\cite{Jef22:SemanticSubtyping}. Rather than the gradual typing approach -of~\cite{ST07:GradualTyping}, which uses \emph{consistent +of Siek and Taha~\cite{ST07:GradualTyping}, which uses \emph{consistent subtyping} where $\ANY \lesssim T \lesssim \ANY$ for any type $T$, we adopt an approach based on \emph{error suppression}, where $\ANY$ is an error-suppressing type, and any failures of subtyping involving error-suppressing types are not reported. Users can explicitly suppress type errors by declaring variables with type $\ANY$, and since an expression with a type error has an error-suppressing type we -avoid cascading errors. Error suppression is in production Luau, and is -mechanically verified~\cite{BJ23:agda-typeck}. +avoid cascading errors. + +We do this by defining a \emph{infallible} typing judgment $\Gamma \vdash M : T$ +such that for any $\Gamma$ and $M$, there is a $T$ such that $\Gamma \vdash M : T$. +For example the rule for addition (ignoring overloads for simplicity) is: +\[ + \frac{\Gamma \vdash M : T \quad \Gamma \vdash M : U}{\Gamma \vdash M+N : \NUMBER} +\] +We define which judgments produce warnings, for example that rule produces a warning +when +\begin{itemize} + \item either $T \not<: \NUMBER$ and $T$ is not error-suppressing, + \item or $U \not<: \NUMBER$ and $U$ is not error-suppressing. +\end{itemize} +To retain type soundness (in the absence of user-supplied error-suppressing types) +we show that +if $\Gamma \vdash M : T$ and $T$ is error-suppressing, then either +\begin{itemize} + \item $\Gamma$ or $M$ contains an error-suppressing type, or + \item $\Gamma \vdash M : T$ produces a warning. +\end{itemize} +From this it is straightforward to show the usual ``well typed +programs don't go wrong'' type soundness result for programs without explicit +error-suppressing types~\cite{BJ23:agda-typeck}. \section{Further work} diff --git a/rfcs/local-type-inference.md b/rfcs/local-type-inference.md new file mode 100644 index 00000000..88cfe3cd --- /dev/null +++ b/rfcs/local-type-inference.md @@ -0,0 +1,180 @@ +# Local Type Inference + +## Summary + +We are going to supplant the current type solver with one based on Benjamin Pierce's Local Type Inference algorithm: + +https://www.cis.upenn.edu/~bcpierce/papers/lti-toplas.pdf + +## Motivation + +Luau's type inference algorithm is used for much more than typechecking scripts. It is also the backbone of an autocomplete algorithm which has to work even for people who don't know what types or type systems are. + +We originally implemented nonstrict mode by making some tactical adjustments to the type inference algorithm. This was great for reducing false positives in untyped code, but carried with it the drawback that the inference result was usually not good enough for the autocomplete system. In order to offer a high quality experience, we've found ourselves to run type inference on nonstrict scripts twice: once for error feedback, and once again to populate the autocomplete database. + +Separately, we would also like more accurate type inference in general. Our current type solver jumps to conclusions a little bit too quickly. For example, it cannot infer an accurate type for an ordinary search function: + +```lua +function index_of(tbl, el) + for i = 0, #tbl do + if tbl[i] == el then + return i + end + end + return nil +end +``` + +Our solver sees two `return` statements and assumes that, because the first statement yields a `number`, so too must the second. + +To fix this, we are going to move to an architecture where type inference and type checking are two separate steps. Whatever mode the user is programming with, we will run an accurate type inference pass over their code and then run one of two typechecking passes over it. + +## Notation + +We'll use the standard notation `A <: B` to indicate that `A` is a subtype of `B`. + +## Design + +At a very high level, local type inference is built around the idea that we track the lower bounds and their upper bounds. The lower bounds of a binding is the set of values that it might conceivably receive. If a binding receives a value outside of its upper bounds, the program will fail. + +At the implementation level, we reencode free types as the space between these bounds. + +Upper bounds arise only from type annotations and certain builtin operations whereas lower bounds arise from assignments, return statements, and uses. + +Free types all start out with bounds `never <: 't <: unknown`. Intuitively, we say that `'t` represents some set of values whose domain is at least `never` and at most `unknown`. This naturally could be any value at all. + +When dispatching a constraint `T <: 't`, we replace the lower bounds of `'t` by the union of its old lower bounds and `T`. When dispatching a constraint `'t <: T`, we replace the upper bounds by its upper bound intersected with `T`. In other words, lower bounds grow from nothing as we see the value used whereas the upper bound initially encompasses everything and shrinks as we constrain it. + +### Constraint Generation Rules + +A return statement expands the lower bounds of the enclosing function's return type. + +```lua +function f(): R + local x: X + return x + -- X <: R +end +``` + +An assignment adds to the lower bounds of the assignee. + +```lua +local a: A +local b: B +a = b +-- B <: A +``` + +A function call adds to the upper bounds of the function being called. + +Equivalently, passing a value to a function adds to the upper bounds of that value and to the lower bounds of its return value. + +```lua +local g +local h: H +local j = g(h) +-- G <: (H) -> I... +-- I... <: J +``` + +Property access is a constraint on a value's upper bounds. +```lua +local a: A +a.b = 2 +-- A <: {b: number} + +a[1] = 3 +-- A <: {number} +``` + +### Generalization + +Generalization is the process by which we infer that a function argument is generic. Broadly speaking, we solve constraints that arise from the function interior, we scan the signature of the function for types that are unconstrained, and we replace those types with generics. This much is all unchanged from the old solver. + +Unlike with the old solver, we never bind free types when dispatching a subtype constraint under local type inference. We only bind free types during generalization. + +If a type only appears in covariant positions in the function's signature, we can replace it by its lower bound. If it only appears in contravariant positions, we replace it by its upper bound. If it appears in both, we'll need to implement bounded generics to get it right. This is beyond the scope of this RFC. + +If a free type has neither upper nor lower bounds, we replace it with a generic. + +Some simple examples: + +```lua +function print_number(n: number) print(n) end + +function f(n) + print_number(n) +end +``` + +We arrive at the solution `never <: 'n <: number`. When we generalize, we can replace `'n` by its upper bound, namely `number`. We infer `f : (number) -> ()`. + +Next example: + +```lua +function index_of(tbl, el) -- index_of : ('a, 'b) -> 'r + for i = 0, #tbl do -- i : number + if tbl[i] == el then -- 'a <: {'c} + return i -- number <: 'r + end + end + return nil -- nil <: 'r +end +``` + +When typechecking this function, we have two constraints on `'r`, the return type. We can combine these constraints by taking the union of the lower bounds, leading us to `number | nil <: 'r <: unknown`. The type `'r` only appears in the return type of the function. The return type of this function is `number | nil`. + +At runtime, Luau allows any two values to be compared. Comparisons of values of mismatched types always return `false`. We therefore cannot produce any interesting constraints about `'b` or `'c`. + +We end up with these bounds: + +``` +never <: 'a <: {'c} +never <: 'b <: unknown +never <: 'c <: unknown +number | nil <: 'r <: unknown +``` + +`'a` appears in the argument position, so we replace it with its upper bound `{'c}`. `'b` and `'c` have no constraints at all so they are replaced by generics `B` and `C`. `'r` appears only in the return position and so is replaced by its lower bound `number | nil`. + +The final inferred type of `index_of` is `({C}, B) -> number | nil`. + +## Drawbacks + +This algorithm requires that we create a lot of union and intersection types. We need to be able to consistently pare down degenerate unions like `number | number`. + +Local type inference is also more permissive than what we have been doing up until now. For instance, the following is perfectly fine: + +```lua +local x = nil +if something then + x = 41 +else + x = "fourty one" +end +``` + +We'll infer `x : number | string | nil`. If the user wishes to constrain a value more tightly, they will have to write an annotation. + +## Alternatives + +### What TypeScript does + +TypeScript very clearly makes it work in what we would call a strict mode context, but we need more in order to offer a high quality nonstrict mode. For instance, TypeScript's autocomplete is completely helpless in the face of this code fragment: + +```ts +let x = null; +x = {a: "a", b: "pickles"}; +x. +``` + +TypeScript will complain that the assignment to `x` is illegal because `x` has type `null`. It will further offer no autocomplete suggestions at all when the user types the final `.`. + +It's not viable for us to require users to write type annotations. Many of our users do not yet know what types are but we are nevertheless committed to providing them a tool that is helpful to them. + +### Success Typing + +Success typing is the algorithm used by the Dialyzer inference engine for Erlang. Instead of attempting to prove that values always flow in sensible ways, it tries to prove that values _could_ flow in sensible ways. + +Success typing is quite nice in that it's very forgiving and can draw surprisingly useful information from untyped code, but that forgiving nature works against us in the case of strict mode. diff --git a/rfcs/syntax-floor-division-operator.md b/rfcs/syntax-floor-division-operator.md index 8ec3913f..b84ddd18 100644 --- a/rfcs/syntax-floor-division-operator.md +++ b/rfcs/syntax-floor-division-operator.md @@ -1,5 +1,7 @@ # Floor division operator +**Status**: Implemented + ## Summary Add floor division operator `//` to ease computing with integers. diff --git a/tests/AssemblyBuilderA64.test.cpp b/tests/AssemblyBuilderA64.test.cpp index ba8d40c2..56722c00 100644 --- a/tests/AssemblyBuilderA64.test.cpp +++ b/tests/AssemblyBuilderA64.test.cpp @@ -107,6 +107,13 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Binary") SINGLE_COMPARE(cmp(w0, 42), 0x7100A81F); } +TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryExtended") +{ + // reg, reg + SINGLE_COMPARE(add(x0, x1, w2, 3), 0x8B224C20); + SINGLE_COMPARE(sub(x0, x1, w2, 3), 0xCB224C20); +} + TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "BinaryImm") { // instructions @@ -524,6 +531,8 @@ TEST_CASE("LogTest") build.ldr(x0, mem(x1, 1, AddressKindA64::pre)); build.ldr(x0, mem(x1, 1, AddressKindA64::post)); + build.add(x1, x2, w3, 3); + build.setLabel(l); build.ret(); @@ -560,6 +569,7 @@ TEST_CASE("LogTest") ldr x0,[x1,#1] ldr x0,[x1,#1]! ldr x0,[x1]!,#1 + add x1,x2,w3 UXTW #3 .L1: ret )"; diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 25521e35..723e107e 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -295,4 +295,34 @@ TEST_CASE_FIXTURE(Fixture, "include_types_ancestry") CHECK(ancestryTypes.back()->asType()); } +TEST_CASE_FIXTURE(Fixture, "find_name_ancestry") +{ + ScopedFastFlag sff{"FixFindBindingAtFunctionName", true}; + check(R"( + local tbl = {} + function tbl:abc() end + )"); + const Position pos(2, 18); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), pos); + + REQUIRE(!ancestry.empty()); + CHECK(ancestry.back()->is()); +} + +TEST_CASE_FIXTURE(Fixture, "find_expr_ancestry") +{ + ScopedFastFlag sff{"FixFindBindingAtFunctionName", true}; + check(R"( + local tbl = {} + function tbl:abc() end + )"); + const Position pos(2, 29); + + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), pos); + + REQUIRE(!ancestry.empty()); + CHECK(ancestry.back()->is()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index daa1f81a..84dee432 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,7 +15,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauAutocompleteLastTypecheck) using namespace Luau; @@ -34,36 +33,27 @@ struct ACFixtureImpl : BaseType AutocompleteResult autocomplete(unsigned row, unsigned column) { - if (FFlag::LuauAutocompleteLastTypecheck) - { - FrontendOptions opts; - opts.forAutocomplete = true; - this->frontend.check("MainModule", opts); - } + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker, StringCompletionCallback callback = nullCallback) { - if (FFlag::LuauAutocompleteLastTypecheck) - { - FrontendOptions opts; - opts.forAutocomplete = true; - this->frontend.check("MainModule", opts); - } + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check("MainModule", opts); return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), callback); } AutocompleteResult autocomplete(const ModuleName& name, Position pos, StringCompletionCallback callback = nullCallback) { - if (FFlag::LuauAutocompleteLastTypecheck) - { - FrontendOptions opts; - opts.forAutocomplete = true; - this->frontend.check(name, opts); - } + FrontendOptions opts; + opts.forAutocomplete = true; + this->frontend.check(name, opts); return Luau::autocomplete(this->frontend, name, pos, callback); } @@ -80,7 +70,7 @@ struct ACFixtureImpl : BaseType { if (prevChar == '@') { - LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Illegal marker character" && ((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z'))); LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); markerPosition.insert(std::pair{c, curPos}); } @@ -126,7 +116,6 @@ struct ACFixtureImpl : BaseType LUAU_ASSERT(i != markerPosition.end()); return i->second; } - ScopedFastFlag flag{"LuauAutocompleteHideSelfArg", true}; // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; @@ -987,6 +976,33 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") CHECK_EQ(ac.context, AutocompleteContext::Statement); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_of_do_block") +{ + ScopedFastFlag sff{"LuauAutocompleteDoEnd", true}; + + check("do @1"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("end")); + + check(R"( + function f() + do + @1 + end + @2 + )"); + + ac = autocomplete('1'); + + CHECK(ac.entryMap.count("end")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("end")); +} + TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( @@ -3083,6 +3099,86 @@ TEST_CASE_FIXTURE(ACFixture, "string_singleton_as_table_key") CHECK(ac.entryMap.count("\"down\"")); } +// https://github.com/Roblox/luau/issues/858 +TEST_CASE_FIXTURE(ACFixture, "string_singleton_in_if_statement") +{ + ScopedFastFlag sff{"LuauAutocompleteStringLiteralBounds", true}; + + check(R"( + --!strict + + type Direction = "left" | "right" + + local dir: Direction = "left" + + if dir == @1"@2"@3 then end + local a: {[Direction]: boolean} = {[@4"@5"@6]} + + if dir == @7`@8`@9 then end + local a: {[Direction]: boolean} = {[@A`@B`@C]} + )"); + + auto ac = autocomplete('1'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('2'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('3'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('4'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('5'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('6'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('7'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('8'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('9'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('A'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); + + ac = autocomplete('B'); + + CHECK(ac.entryMap.count("left")); + CHECK(ac.entryMap.count("right")); + + ac = autocomplete('C'); + + CHECK(!ac.entryMap.count("left")); + CHECK(!ac.entryMap.count("right")); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_string_singleton_equality") { check(R"( @@ -3556,8 +3652,6 @@ TEST_CASE_FIXTURE(ACFixture, "frontend_use_correct_global_scope") TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") { - ScopedFastFlag flag{"LuauDisableCompletionOutsideQuotes", true}; - loadDefinition(R"( declare function require(path: string): any )"); @@ -3573,8 +3667,7 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") )"); StringCompletionCallback callback = [](std::string, std::optional, - std::optional contents) -> std::optional - { + std::optional contents) -> std::optional { Luau::AutocompleteEntryMap results = {{"test", Luau::AutocompleteEntry{Luau::AutocompleteEntryKind::String, std::nullopt, false, false}}}; return results; }; @@ -3595,8 +3688,6 @@ TEST_CASE_FIXTURE(ACFixture, "string_completion_outside_quotes") TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_empty") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: () -> ()) a() @@ -3618,8 +3709,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, string) -> ()) a() @@ -3641,8 +3730,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_single_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, string) -> (string)) a() @@ -3664,8 +3751,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_args_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, string) -> (string, number)) a() @@ -3687,8 +3772,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__noargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: () -> (string, number)) a() @@ -3710,8 +3793,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled__varargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (...number) -> (string, number)) a() @@ -3733,8 +3814,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (string, ...number) -> (string, number)) a() @@ -3756,8 +3835,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_varargs_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (string, ...number) -> ...number) a() @@ -3779,8 +3856,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_multi_varargs_multi_varargs_return") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (string, ...number) -> (boolean, ...number)) a() @@ -3802,8 +3877,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_named_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (foo: number, bar: string) -> (string, number)) a() @@ -3825,8 +3898,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (number, bar: string) -> (string, number)) a() @@ -3848,8 +3919,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_partially_args_last") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (foo: number, string) -> (string, number)) a() @@ -3871,8 +3940,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local t = { a = 1, b = 2 } @@ -3896,8 +3963,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (tbl: { x: number, y: number }) -> number) return a({x=2, y = 3}) end foo(@1) @@ -3916,8 +3981,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_returns") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local t = { a = 1, b = 2 } @@ -3941,8 +4004,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_table_literal_args") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: () -> { x: number, y: number }) return {x=2, y = 3} end foo(@1) @@ -3961,8 +4022,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_typeof_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local t = { a = 1, b = 2 } @@ -3986,8 +4045,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_type_pack_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (...A) -> number, ...: A) return a(...) @@ -4009,8 +4066,6 @@ foo(@1) TEST_CASE_FIXTURE(ACFixture, "anonymous_autofilled_generic_on_argument_type_pack_vararg") { - ScopedFastFlag flag{"LuauAnonymousAutofilled1", true}; - check(R"( local function foo(a: (...: T...) -> number) return a(4, 5, 6) diff --git a/tests/CodeAllocator.test.cpp b/tests/CodeAllocator.test.cpp index b1416736..9c99862d 100644 --- a/tests/CodeAllocator.test.cpp +++ b/tests/CodeAllocator.test.cpp @@ -47,6 +47,55 @@ TEST_CASE("CodeAllocation") CHECK(nativeEntry == nativeData + kCodeAlignment); } +TEST_CASE("CodeAllocationCallbacks") +{ + struct AllocationData + { + size_t bytesAllocated = 0; + size_t bytesFreed = 0; + }; + + AllocationData allocationData{}; + + const auto allocationCallback = [](void* context, void* oldPointer, size_t oldSize, void* newPointer, size_t newSize) { + AllocationData& allocationData = *static_cast(context); + if (oldPointer != nullptr) + { + CHECK(oldSize != 0); + + allocationData.bytesFreed += oldSize; + } + + if (newPointer != nullptr) + { + CHECK(newSize != 0); + + allocationData.bytesAllocated += newSize; + } + }; + + const size_t blockSize = 1024 * 1024; + const size_t maxTotalSize = 1024 * 1024; + + { + CodeAllocator allocator(blockSize, maxTotalSize, allocationCallback, &allocationData); + + uint8_t* nativeData = nullptr; + size_t sizeNativeData = 0; + uint8_t* nativeEntry = nullptr; + + std::vector code; + code.resize(128); + + REQUIRE(allocator.allocate(nullptr, 0, code.data(), code.size(), nativeData, sizeNativeData, nativeEntry)); + CHECK(allocationData.bytesAllocated == blockSize); + CHECK(allocationData.bytesFreed == 0); + } + + CHECK(allocationData.bytesAllocated == blockSize); + CHECK(allocationData.bytesFreed == blockSize); +} + TEST_CASE("CodeAllocationFailure") { size_t blockSize = 3000; @@ -137,7 +186,7 @@ TEST_CASE("WindowsUnwindCodesX64") unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}, {}); unwind.finishFunction(0x11223344, 0x55443322); unwind.finishInfo(); @@ -161,7 +210,7 @@ TEST_CASE("Dwarf2UnwindCodesX64") unwind.startInfo(UnwindBuilder::X64); unwind.startFunction(); - unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}); + unwind.prologueX64(/* prologueSize= */ 23, /* stackSize= */ 72, /* setupFrame= */ true, {rdi, rsi, rbx, r12, r13, r14, r15}, {}); unwind.finishFunction(0, 0); unwind.finishInfo(); @@ -259,6 +308,11 @@ static void throwing(int64_t arg) throw std::runtime_error("testing"); } +static void nonthrowing(int64_t arg) +{ + CHECK(arg == 25); +} + TEST_CASE("GeneratedCodeExecutionWithThrowX64") { using namespace X64; @@ -289,7 +343,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") uint32_t prologueSize = build.setLabel().location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}, {}); // Body build.mov(rNonVol1, rArg1); @@ -329,6 +383,8 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") using FunctionType = int64_t(int64_t, void (*)(int64_t)); FunctionType* f = (FunctionType*)nativeEntry; + f(10, nonthrowing); + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here try { @@ -340,6 +396,121 @@ TEST_CASE("GeneratedCodeExecutionWithThrowX64") } } +static void obscureThrowCase(int64_t (*f)(int64_t, void (*)(int64_t))) +{ + // To simplify debugging, CHECK_THROWS_WITH_AS is not used here + try + { + f(10, throwing); + } + catch (const std::runtime_error& error) + { + CHECK(strcmp(error.what(), "testing") == 0); + } +} + +TEST_CASE("GeneratedCodeExecutionWithThrowX64Simd") +{ + // This test requires AVX + if (!Luau::CodeGen::isSupported()) + return; + + using namespace X64; + + AssemblyBuilderX64 build(/* logText= */ false); + +#if defined(_WIN32) + std::unique_ptr unwind = std::make_unique(); +#else + std::unique_ptr unwind = std::make_unique(); +#endif + + unwind->startInfo(UnwindBuilder::X64); + + Label functionBegin = build.setLabel(); + unwind->startFunction(); + + int stackSize = 32 + 64; + int localsSize = 16; + + // Prologue + build.push(rNonVol1); + build.push(rNonVol2); + build.push(rbp); + build.sub(rsp, stackSize + localsSize); + + if (build.abi == ABIX64::Windows) + { + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x40)], xmm6); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x30)], xmm7); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x20)], xmm8); + build.vmovaps(xmmword[rsp + ((stackSize + localsSize) - 0x10)], xmm9); + } + + uint32_t prologueSize = build.setLabel().location; + + if (build.abi == ABIX64::Windows) + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rbp}, {xmm6, xmm7, xmm8, xmm9}); + else + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rbp}, {}); + + // Body + build.vxorpd(xmm0, xmm0, xmm0); + build.vmovsd(xmm6, xmm0, xmm0); + build.vmovsd(xmm7, xmm0, xmm0); + build.vmovsd(xmm8, xmm0, xmm0); + build.vmovsd(xmm9, xmm0, xmm0); + + build.mov(rNonVol1, rArg1); + build.mov(rNonVol2, rArg2); + + build.add(rNonVol1, 15); + build.mov(rArg1, rNonVol1); + build.call(rNonVol2); + + // Epilogue + if (build.abi == ABIX64::Windows) + { + build.vmovaps(xmm6, xmmword[rsp + ((stackSize + localsSize) - 0x40)]); + build.vmovaps(xmm7, xmmword[rsp + ((stackSize + localsSize) - 0x30)]); + build.vmovaps(xmm8, xmmword[rsp + ((stackSize + localsSize) - 0x20)]); + build.vmovaps(xmm9, xmmword[rsp + ((stackSize + localsSize) - 0x10)]); + } + + build.add(rsp, stackSize + localsSize); + build.pop(rbp); + build.pop(rNonVol2); + build.pop(rNonVol1); + build.ret(); + + unwind->finishFunction(build.getLabelOffset(functionBegin), ~0u); + + build.finalize(); + + unwind->finishInfo(); + + size_t blockSize = 1024 * 1024; + size_t maxTotalSize = 1024 * 1024; + CodeAllocator allocator(blockSize, maxTotalSize); + + allocator.context = unwind.get(); + allocator.createBlockUnwindInfo = createBlockUnwindInfo; + allocator.destroyBlockUnwindInfo = destroyBlockUnwindInfo; + + uint8_t* nativeData; + size_t sizeNativeData; + uint8_t* nativeEntry; + REQUIRE(allocator.allocate(build.data.data(), build.data.size(), build.code.data(), build.code.size(), nativeData, sizeNativeData, nativeEntry)); + REQUIRE(nativeEntry); + + using FunctionType = int64_t(int64_t, void (*)(int64_t)); + FunctionType* f = (FunctionType*)nativeEntry; + + f(10, nonthrowing); + + obscureThrowCase(f); +} + TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") { using namespace X64; @@ -375,7 +546,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") uint32_t prologueSize = build.setLabel().location - start1.location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {rNonVol1, rNonVol2}, {}); // Body build.mov(rNonVol1, rArg1); @@ -414,7 +585,7 @@ TEST_CASE("GeneratedCodeExecutionMultipleFunctionsWithThrowX64") uint32_t prologueSize = build.setLabel().location - start2.location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ false, {rNonVol1, rNonVol2, rNonVol3, rNonVol4}, {}); // Body build.mov(rNonVol3, rArg1); @@ -511,7 +682,7 @@ TEST_CASE("GeneratedCodeExecutionWithThrowOutsideTheGateX64") uint32_t prologueSize = build.setLabel().location; - unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}); + unwind->prologueX64(prologueSize, stackSize + localsSize, /* setupFrame= */ true, {r10, r11, r12, r13, r14, r15}, {}); // Body build.mov(rax, rArg1); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index d368af66..49498744 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -63,6 +63,35 @@ static std::string compileTypeTable(const char* source) TEST_SUITE_BEGIN("Compiler"); +TEST_CASE("BytecodeIsStable") +{ + // As noted in Bytecode.h, all enums used for bytecode storage and serialization are order-sensitive + // Adding entries in the middle will typically pass the tests but break compatibility + // This test codifies this by validating that in each enum, the last (or close-to-last) entry has a fixed encoding + + // This test will need to get occasionally revised to "move" the checked enum entries forward as we ship newer versions + // When doing so, please add *new* checks for more recent bytecode versions and keep existing checks in place. + + // Bytecode ops (serialized & in-memory) + CHECK(LOP_FASTCALL2K == 75); // bytecode v1 + CHECK(LOP_JUMPXEQKS == 80); // bytecode v3 + + // Bytecode fastcall ids (serialized & in-memory) + // Note: these aren't strictly bound to specific bytecode versions, but must monotonically increase to keep backwards compat + CHECK(LBF_VECTOR == 54); + CHECK(LBF_TOSTRING == 63); + + // Bytecode capture type (serialized & in-memory) + CHECK(LCT_UPVAL == 2); // bytecode v1 + + // Bytecode constants (serialized) + CHECK(LBC_CONSTANT_CLOSURE == 6); // bytecode v1 + + // Bytecode type encoding (serialized & in-memory) + // Note: these *can* change retroactively *if* type version is bumped, but probably shouldn't + LUAU_ASSERT(LBC_TYPE_VECTOR == 8); // type version 1 +} + TEST_CASE("CompileToBytecode") { Luau::BytecodeBuilder bcb; @@ -4213,7 +4242,7 @@ RETURN R0 0 bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::CompileOptions options; const char* mutableGlobals[] = {"Game", "Workspace", "game", "plugin", "script", "shared", "workspace", NULL}; - options.mutableGlobals = &mutableGlobals[0]; + options.mutableGlobals = mutableGlobals; Luau::compileOrThrow(bcb, source, options); CHECK_EQ("\n" + bcb.dumpFunction(0), R"( @@ -5085,7 +5114,7 @@ RETURN R1 1 )"); } -TEST_CASE("InlineBasicProhibited") +TEST_CASE("InlineProhibited") { // we can't inline variadic functions CHECK_EQ("\n" + compileFunction(R"( @@ -5125,6 +5154,66 @@ RETURN R1 1 )"); } +TEST_CASE("InlineProhibitedRecursion") +{ + // we can't inline recursive invocations of functions in the functions + // this is actually profitable in certain cases, but it complicates the compiler as it means a local has multiple registers/values + + // in this example, inlining is blocked because we're compiling fact() and we don't yet have the cost model / profitability data for fact() + CHECK_EQ("\n" + compileFunction(R"( +local function fact(n) + return if n <= 1 then 1 else fact(n-1)*n +end + +return fact +)", + 0, 2), + R"( +LOADN R2 1 +JUMPIFNOTLE R0 R2 L0 +LOADN R1 1 +RETURN R1 1 +L0: GETUPVAL R2 0 +SUBK R3 R0 K0 [1] +CALL R2 1 1 +MUL R1 R2 R0 +RETURN R1 1 +)"); + + // in this example, inlining of fact() succeeds, but the nested call to fact() fails since fact is already on the inline stack + CHECK_EQ("\n" + compileFunction(R"( +local function fact(n) + return if n <= 1 then 1 else fact(n-1)*n +end + +local function factsafe(n) + assert(n >= 1) + return fact(n) +end + +return factsafe +)", + 1, 2), + R"( +LOADN R3 1 +JUMPIFLE R3 R0 L0 +LOADB R2 0 +1 +L0: LOADB R2 1 +L1: FASTCALL1 1 R2 L2 +GETIMPORT R1 1 [assert] +CALL R1 1 0 +L2: LOADN R2 1 +JUMPIFNOTLE R0 R2 L3 +LOADN R1 1 +RETURN R1 1 +L3: GETUPVAL R2 0 +SUBK R3 R0 K2 [1] +CALL R2 1 1 +MUL R1 R2 R0 +RETURN R1 1 +)"); +} + TEST_CASE("InlineNestedLoops") { // functions with basic loops get inlined @@ -6978,8 +7067,6 @@ L3: RETURN R0 0 TEST_CASE("BuiltinArity") { - ScopedFastFlag sff("LuauCompileFixBuiltinArity", true); - // by default we can't assume that we know parameter/result count for builtins as they can be overridden at runtime CHECK_EQ("\n" + compileFunction(R"( return math.abs(unknown()) @@ -7100,8 +7187,6 @@ L1: RETURN R3 1 TEST_CASE("EncodedTypeTable") { - ScopedFastFlag sff("LuauCompileFunctionType", true); - CHECK_EQ("\n" + compileTypeTable(R"( function myfunc(test: string, num: number) print(test) @@ -7153,8 +7238,6 @@ Str:test(234) TEST_CASE("HostTypesAreUserdata") { - ScopedFastFlag sff("LuauCompileFunctionType", true); - CHECK_EQ("\n" + compileTypeTable(R"( function myfunc(test: string, num: number) print(test) @@ -7181,8 +7264,6 @@ end TEST_CASE("HostTypesVector") { - ScopedFastFlag sff("LuauCompileFunctionType", true); - CHECK_EQ("\n" + compileTypeTable(R"( function myfunc(test: Instance, pos: Vector3) end @@ -7206,8 +7287,6 @@ end TEST_CASE("TypeAliasScoping") { - ScopedFastFlag sff("LuauCompileFunctionType", true); - CHECK_EQ("\n" + compileTypeTable(R"( do type Part = number @@ -7242,8 +7321,6 @@ type Instance = string TEST_CASE("TypeAliasResolve") { - ScopedFastFlag sff("LuauCompileFunctionType", true); - CHECK_EQ("\n" + compileTypeTable(R"( type Foo1 = number type Foo2 = { number } @@ -7264,16 +7341,38 @@ end )"); } +TEST_CASE("TypeUnionIntersection") +{ + CHECK_EQ("\n" + compileTypeTable(R"( +function myfunc(test: string | nil, foo: nil) +end + +function myfunc2(test: string & nil, foo: nil) +end + +function myfunc3(test: string | number, foo: nil) +end + +function myfunc4(test: string & number, foo: nil) +end +)"), + R"( +0: function(string?, nil) +1: function(any, nil) +2: function(any, nil) +3: function(any, nil) +)"); +} + TEST_CASE("BuiltinFoldMathK") { - ScopedFastFlag sff("LuauCompileFoldMathK", true); - // we can fold math.pi at optimization level 2 CHECK_EQ("\n" + compileFunction(R"( function test() return math.pi * 2 end -)", 0, 2), +)", + 0, 2), R"( LOADK R0 K0 [6.2831853071795862] RETURN R0 1 @@ -7284,7 +7383,8 @@ RETURN R0 1 function test() return math.pi * 2 end -)", 0, 1), +)", + 0, 1), R"( GETIMPORT R1 3 [math.pi] MULK R0 R1 K0 [2] @@ -7298,7 +7398,8 @@ function test() end math = { pi = 4 } -)", 0, 2), +)", + 0, 2), R"( GETGLOBAL R2 K1 ['math'] GETTABLEKS R1 R2 K2 ['pi'] diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index f4a74b66..c8918954 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -8,7 +8,6 @@ #include "Luau/DenseHash.h" #include "Luau/ModuleResolver.h" #include "Luau/TypeInfer.h" -#include "Luau/StringUtils.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Frontend.h" @@ -24,6 +23,9 @@ extern bool verbose; extern bool codegen; extern int optimizationLevel; +LUAU_FASTFLAG(LuauPCallDebuggerFix); +LUAU_FASTFLAG(LuauFloorDivision); + static lua_CompileOptions defaultOptions() { lua_CompileOptions copts = {}; @@ -266,6 +268,12 @@ static void* limitedRealloc(void* ud, void* ptr, size_t osize, size_t nsize) TEST_SUITE_BEGIN("Conformance"); +TEST_CASE("CodegenSupported") +{ + if (codegen && !luau_codegen_supported()) + MESSAGE("Native code generation is not supported by the current configuration and will be disabled"); +} + TEST_CASE("Assert") { runConformance("assert.lua"); @@ -273,7 +281,8 @@ TEST_CASE("Assert") TEST_CASE("Basic") { - ScopedFastFlag sff("LuauCompileFixBuiltinArity", true); + ScopedFastFlag sffs{"LuauFloorDivision", true}; + ScopedFastFlag sfff{"LuauImproveForN", true}; runConformance("basic.lua"); } @@ -328,8 +337,6 @@ TEST_CASE("Clear") TEST_CASE("Strings") { - ScopedFastFlag sff("LuauCompileFixBuiltinArity", true); - runConformance("strings.lua"); } @@ -360,6 +367,7 @@ TEST_CASE("Errors") TEST_CASE("Events") { + ScopedFastFlag sffs{"LuauFloorDivision", true}; runConformance("events.lua"); } @@ -441,6 +449,8 @@ TEST_CASE("Pack") TEST_CASE("Vector") { + ScopedFastFlag sffs{"LuauFloorDivision", true}; + lua_CompileOptions copts = defaultOptions(); copts.vectorCtor = "vector"; @@ -1211,15 +1221,90 @@ TEST_CASE("IfElseExpression") runConformance("ifelseexpr.lua"); } +// Optionally returns debug info for the first Luau stack frame that is encountered on the callstack. +static std::optional getFirstLuauFrameDebugInfo(lua_State* L) +{ + static std::string_view kLua = "Lua"; + lua_Debug ar; + for (int i = 0; lua_getinfo(L, i, "sl", &ar); i++) + { + if (kLua == ar.what) + return ar; + } + return std::nullopt; +} + TEST_CASE("TagMethodError") { - runConformance("tmerror.lua", [](lua_State* L) { - auto* cb = lua_callbacks(L); + static std::vector expectedHits; - cb->debugprotectederror = [](lua_State* L) { - CHECK(lua_isyieldable(L)); - }; - }); + // Loop over two modes: + // when doLuaBreak is false the test only verifies that callbacks occur on the expected lines in the Luau source + // when doLuaBreak is true the test additionally calls lua_break to ensure breaking the debugger doesn't cause the VM to crash + for (bool doLuaBreak : {false, true}) + { + std::optional sff; + if (doLuaBreak) + { + // If doLuaBreak is true then LuauPCallDebuggerFix must be enabled to avoid crashing the tests. + sff = {"LuauPCallDebuggerFix", true}; + } + + if (FFlag::LuauPCallDebuggerFix) + { + expectedHits = {22, 32}; + } + else + { + expectedHits = { + 9, + 17, + 17, + 22, + 27, + 27, + 32, + 37, + }; + } + + static int index; + static bool luaBreak; + index = 0; + luaBreak = doLuaBreak; + + // 'yieldCallback' doesn't do anything, but providing the callback to runConformance + // ensures that the call to lua_break doesn't cause an error to be generated because + // runConformance doesn't expect the VM to be in the state LUA_BREAK. + auto yieldCallback = [](lua_State* L) {}; + + runConformance( + "tmerror.lua", + [](lua_State* L) { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) { + std::optional ar = getFirstLuauFrameDebugInfo(L); + + CHECK(lua_isyieldable(L)); + REQUIRE(ar.has_value()); + REQUIRE(index < int(std::size(expectedHits))); + CHECK(ar->currentline == expectedHits[index++]); + + if (luaBreak) + { + // Cause luau execution to break when 'error' is called via 'pcall' + // This call to lua_break is a regression test for an issue where debugprotectederror + // was called on a thread that couldn't be yielded even though lua_isyieldable was true. + lua_break(L); + } + }; + }, + yieldCallback); + + // Make sure the number of break points hit was the expected number + CHECK(index == std::size(expectedHits)); + } } TEST_CASE("Coverage") @@ -1538,6 +1623,9 @@ static void pushInt64(lua_State* L, int64_t value) TEST_CASE("Userdata") { + + ScopedFastFlag sffs{"LuauFloorDivision", true}; + runConformance("userdata.lua", [](lua_State* L) { // create metatable with all the metamethods lua_newtable(L); @@ -1657,6 +1745,19 @@ TEST_CASE("Userdata") nullptr); lua_setfield(L, -2, "__div"); + // __idiv + lua_pushcfunction( + L, + [](lua_State* L) { + // for testing we use different semantics here compared to __div: __idiv rounds to negative inf, __div truncates (rounds to zero) + // additionally, division loses precision here outside of 2^53 range + // we do not necessarily recommend this behavior in production code! + pushInt64(L, int64_t(floor(double(getInt64(L, 1)) / double(getInt64(L, 2))))); + return 1; + }, + nullptr); + lua_setfield(L, -2, "__idiv"); + // __mod lua_pushcfunction( L, @@ -1726,7 +1827,6 @@ TEST_CASE("Native") TEST_CASE("NativeTypeAnnotations") { ScopedFastFlag bytecodeVersion4("BytecodeVersion4", true); - ScopedFastFlag luauCompileFunctionType("LuauCompileFunctionType", true); // This tests requires code to run natively, otherwise all 'is_native' checks will fail if (!codegen || !luau_codegen_supported()) diff --git a/tests/CostModel.test.cpp b/tests/CostModel.test.cpp index 206b83a8..29fffb4f 100644 --- a/tests/CostModel.test.cpp +++ b/tests/CostModel.test.cpp @@ -133,8 +133,6 @@ end TEST_CASE("ControlFlow") { - ScopedFastFlag sff("LuauAssignmentHasCost", true); - uint64_t model = modelFunction(R"( function test(a) while a < 0 do @@ -244,8 +242,6 @@ end TEST_CASE("MultipleAssignments") { - ScopedFastFlag sff("LuauAssignmentHasCost", true); - uint64_t model = modelFunction(R"( function test(a) local x = 0 diff --git a/tests/Differ.test.cpp b/tests/Differ.test.cpp index c2b09bbd..84a383c1 100644 --- a/tests/Differ.test.cpp +++ b/tests/Differ.test.cpp @@ -1528,4 +1528,43 @@ TEST_CASE_FIXTURE(DifferFixture, "equal_generictp_cyclic") compareTypesEq("foo", "almostFoo"); } +TEST_CASE_FIXTURE(DifferFixture, "symbol_forward") +{ + CheckResult result = check(R"( + local foo = 5 + local almostFoo = "five" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at foo has type number, while the right type at almostFoo has type string)", + true); +} + +TEST_CASE_FIXTURE(DifferFixture, "newlines") +{ + CheckResult result = check(R"( + local foo = 5 + local almostFoo = "five" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + INFO(Luau::toString(requireType("foo"))); + INFO(Luau::toString(requireType("almostFoo"))); + + compareTypesNe("foo", "almostFoo", + R"(DiffError: these two types are not equal because the left type at + foo +has type + number, +while the right type at + almostFoo +has type + string)", + true, true); +} + TEST_SUITE_END(); diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp index 5ba5c112..0a71794f 100644 --- a/tests/Error.test.cpp +++ b/tests/Error.test.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Error.h" +#include "Fixture.h" #include "doctest.h" using namespace Luau; @@ -13,4 +14,24 @@ TEST_CASE("TypeError_code_should_return_nonzero_code") CHECK_GE(e.code(), 1000); } +TEST_CASE_FIXTURE(BuiltinsFixture, "metatable_names_show_instead_of_tables") +{ + frontend.options.retainFullTypeGraphs = false; + ScopedFastFlag sff{"LuauStacklessTypeClone2", true}; + CheckResult result = check(R"( +--!strict +local Account = {} +Account.__index = Account +function Account.deposit(self: Account, x: number) + self.balance += x +end +type Account = typeof(setmetatable({} :: { balance: number }, Account)) +local x: Account = 5 +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'number' could not be converted into 'Account'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index d4fa7178..f31727e0 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -9,6 +9,7 @@ #include "Luau/Parser.h" #include "Luau/Type.h" #include "Luau/TypeAttach.h" +#include "Luau/TypeInfer.h" #include "Luau/Transpiler.h" #include "doctest.h" @@ -144,8 +145,6 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend.globals); - Luau::freeze(frontend.globals.globalTypes); Luau::freeze(frontend.globalsForAutocomplete.globalTypes); diff --git a/tests/Fixture.h b/tests/Fixture.h index a9c5d9b0..8d997140 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -160,14 +160,37 @@ void createSomeClasses(Frontend* frontend); template struct DifferFixtureGeneric : BaseFixture { - void compareNe(TypeId left, TypeId right, const std::string& expectedMessage) + std::string normalizeWhitespace(std::string msg) + { + std::string normalizedMsg = ""; + bool wasWhitespace = true; + for (char c : msg) + { + bool isWhitespace = c == ' ' || c == '\n'; + if (wasWhitespace && isWhitespace) + continue; + normalizedMsg += isWhitespace ? ' ' : c; + wasWhitespace = isWhitespace; + } + if (wasWhitespace) + normalizedMsg.pop_back(); + return normalizedMsg; + } + + void compareNe(TypeId left, TypeId right, const std::string& expectedMessage, bool multiLine) + { + compareNe(left, std::nullopt, right, std::nullopt, expectedMessage, multiLine); + } + + void compareNe(TypeId left, std::optional symbolLeft, TypeId right, std::optional symbolRight, + const std::string& expectedMessage, bool multiLine) { std::string diffMessage; try { - DifferResult diffRes = diff(left, right); + DifferResult diffRes = diffWithSymbols(left, right, symbolLeft, symbolRight); REQUIRE_MESSAGE(diffRes.diffError.has_value(), "Differ did not report type error, even though types are unequal"); - diffMessage = diffRes.diffError->toString(); + diffMessage = diffRes.diffError->toString(multiLine); } catch (const InternalCompilerError& e) { @@ -176,9 +199,19 @@ struct DifferFixtureGeneric : BaseFixture CHECK_EQ(expectedMessage, diffMessage); } - void compareTypesNe(const std::string& leftSymbol, const std::string& rightSymbol, const std::string& expectedMessage) + void compareTypesNe(const std::string& leftSymbol, const std::string& rightSymbol, const std::string& expectedMessage, bool forwardSymbol = false, + bool multiLine = false) { - compareNe(BaseFixture::requireType(leftSymbol), BaseFixture::requireType(rightSymbol), expectedMessage); + if (forwardSymbol) + { + compareNe( + BaseFixture::requireType(leftSymbol), leftSymbol, BaseFixture::requireType(rightSymbol), rightSymbol, expectedMessage, multiLine); + } + else + { + compareNe( + BaseFixture::requireType(leftSymbol), std::nullopt, BaseFixture::requireType(rightSymbol), std::nullopt, expectedMessage, multiLine); + } } void compareEq(TypeId left, TypeId right) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index fda0a6f0..96032fd1 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -446,8 +446,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_longer") TEST_CASE_FIXTURE(FrontendFixture, "cycle_incremental_type_surface_exports") { - ScopedFastFlag luauFixCyclicModuleExports{"LuauFixCyclicModuleExports", true}; - fileResolver.source["game/A"] = R"( local b = require(game.B) export type atype = { x: b.btype } @@ -1224,4 +1222,28 @@ TEST_CASE_FIXTURE(FrontendFixture, "parse_only") CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(FrontendFixture, "markdirty_early_return") +{ + ScopedFastFlag fflag("CorrectEarlyReturnInMarkDirty", true); + + constexpr char moduleName[] = "game/Gui/Modules/A"; + fileResolver.source[moduleName] = R"( + return 1 + )"; + + { + std::vector markedDirty; + frontend.markDirty(moduleName, &markedDirty); + CHECK(markedDirty.empty()); + } + + frontend.parse(moduleName); + + { + std::vector markedDirty; + frontend.markDirty(moduleName, &markedDirty); + CHECK(!markedDirty.empty()); + } +} + TEST_SUITE_END(); diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 57103352..b511156d 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -5,6 +5,7 @@ #include "Luau/IrUtils.h" #include "Luau/OptimizeConstProp.h" #include "Luau/OptimizeFinalX64.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -620,11 +621,11 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") }); withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(0), a, b); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(0), build.cond(IrCondition::Equal), a, b); }); withTwoBlocks([this](IrOp a, IrOp b) { - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), a, b); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), a, b); }); updateUseCounts(build.function); @@ -883,8 +884,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") bb_0: STORE_TAG R0, tnumber STORE_DOUBLE R0, 0.5 - %2 = LOAD_TVALUE R0 - STORE_TVALUE R1, %2 + STORE_SPLIT_TVALUE R1, tnumber, 0.5 STORE_TAG R3, tnumber STORE_DOUBLE R3, 0.5 RETURN 0u @@ -991,6 +991,52 @@ bb_fallback_1: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "RememberNewTableState") +{ + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp newtable = build.inst(IrCmd::NEW_TABLE, build.constUint(16), build.constUint(32)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(0), newtable); + + IrOp table = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table, build.constInt(14), fallback); + + build.inst(IrCmd::SET_TABLE, build.vmReg(1), build.vmReg(0), build.constUint(13)); // Invalidate table knowledge + + build.inst(IrCmd::CHECK_NO_METATABLE, table, fallback); + build.inst(IrCmd::CHECK_READONLY, table, fallback); + build.inst(IrCmd::CHECK_ARRAY_SIZE, table, build.constInt(14), fallback); + + build.inst(IrCmd::RETURN, build.constUint(0)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = NEW_TABLE 16u, 32u + STORE_POINTER R0, %0 + SET_TABLE R1, R0, 13u + CHECK_NO_METATABLE %0, bb_fallback_1 + CHECK_READONLY %0, bb_fallback_1 + CHECK_ARRAY_SIZE %0, 14i, bb_fallback_1 + RETURN 0u + +bb_fallback_1: + RETURN 1u + +)"); +} + TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") { IrOp block = build.block(IrBlockKind::Internal); @@ -1313,7 +1359,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") build.beginBlock(block); build.inst(IrCmd::STORE_INT, build.vmReg(1), build.constInt(5)); IrOp value = build.inst(IrCmd::LOAD_INT, build.vmReg(1)); - build.inst(IrCmd::JUMP_EQ_INT, value, build.constInt(5), trueBlock, falseBlock); + build.inst(IrCmd::JUMP_CMP_INT, value, build.constInt(5), build.cond(IrCondition::Equal), trueBlock, falseBlock); build.beginBlock(trueBlock); build.inst(IrCmd::RETURN, build.constUint(1)); @@ -1510,7 +1556,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") IrOp repeat = build.block(IrBlockKind::Internal); build.beginBlock(entry); - build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); + build.inst(IrCmd::JUMP_CMP_INT, build.constInt(0), build.constInt(1), build.cond(IrCondition::Equal), block, exit1); build.beginBlock(exit1); build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); @@ -1586,7 +1632,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InvalidateReglinkVersion") build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tstring)); IrOp tv2 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(2)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), tv2); - IrOp ft = build.inst(IrCmd::NEW_TABLE); + IrOp ft = build.inst(IrCmd::NEW_TABLE, build.constUint(0), build.constUint(0)); build.inst(IrCmd::STORE_POINTER, build.vmReg(2), ft); build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(ttable)); IrOp tv1 = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); @@ -1606,7 +1652,7 @@ bb_0: STORE_TAG R2, tstring %1 = LOAD_TVALUE R2 STORE_TVALUE R1, %1 - %3 = NEW_TABLE + %3 = NEW_TABLE 0u, 0u STORE_POINTER R2, %3 STORE_TAG R2, ttable STORE_TVALUE R0, %1 @@ -1811,8 +1857,8 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PartialStoreInvalidation") build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded - build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tboolean)); - build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); // Should be reloaded + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(1), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(0))); build.inst(IrCmd::RETURN, build.constUint(0)); @@ -1826,9 +1872,8 @@ bb_0: STORE_DOUBLE R0, 0.5 %3 = LOAD_TVALUE R0 STORE_TVALUE R1, %3 - STORE_TAG R0, tboolean - %6 = LOAD_TVALUE R0 - STORE_TVALUE R1, %6 + STORE_TAG R0, tnumber + STORE_SPLIT_TVALUE R1, tnumber, 0.5 RETURN 0u )"); @@ -1886,6 +1931,135 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateHashSlotChecks") +{ + ScopedFastFlag luauReuseHashSlots{"LuauReuseHashSlots2", true}; + + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + // This roughly corresponds to 'return t.a + t.a' + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp slot1 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(3), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1, build.vmConst(1), fallback); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, slot1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + IrOp slot1b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(8), build.vmConst(1)); // This will be removed + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1b, build.vmConst(1), fallback); // Key will be replaced with undef here + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, slot1b, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(4), value1b); + + IrOp a = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(3)); + IrOp b = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(4)); + IrOp sum = build.inst(IrCmd::ADD_NUM, a, b); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(2), sum); + + build.inst(IrCmd::RETURN, build.vmReg(2), build.constUint(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(0), build.constUint(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + // In the future, we might even see duplicate identical TValue loads go away + // In the future, we might even see loads of different VM regs with the same value go away + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_SLOT_NODE_ADDR %0, 3u, K1 + CHECK_SLOT_MATCH %1, K1, bb_fallback_1 + %3 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %3 + CHECK_NODE_VALUE %1, bb_fallback_1 + %7 = LOAD_TVALUE %1, 0i + STORE_TVALUE R4, %7 + %9 = LOAD_DOUBLE R3 + %10 = LOAD_DOUBLE R4 + %11 = ADD_NUM %9, %10 + STORE_DOUBLE R2, %11 + RETURN R2, 1u + +bb_fallback_1: + RETURN R0, 1u + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "DuplicateHashSlotChecksAvoidNil") +{ + ScopedFastFlag luauReuseHashSlots{"LuauReuseHashSlots2", true}; + + IrOp block = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(block); + + IrOp table1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + IrOp slot1 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(3), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1, build.vmConst(1), fallback); + IrOp value1 = build.inst(IrCmd::LOAD_TVALUE, slot1, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1); + + IrOp table2 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(2)); + IrOp slot2 = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table2, build.constUint(6), build.vmConst(1)); + build.inst(IrCmd::CHECK_SLOT_MATCH, slot2, build.vmConst(1), fallback); + build.inst(IrCmd::CHECK_READONLY, table2, fallback); + + build.inst(IrCmd::STORE_TAG, build.vmReg(4), build.constTag(tnil)); + IrOp valueNil = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(4)); + build.inst(IrCmd::STORE_TVALUE, slot2, valueNil, build.constInt(0)); + + // In the future, we might get to track that value became 'nil' and that fallback will be taken + IrOp slot1b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table1, build.constUint(8), build.vmConst(1)); // This will be removed + build.inst(IrCmd::CHECK_SLOT_MATCH, slot1b, build.vmConst(1), fallback); // Key will be replaced with undef here + IrOp value1b = build.inst(IrCmd::LOAD_TVALUE, slot1b, build.constInt(0)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), value1b); + + IrOp slot2b = build.inst(IrCmd::GET_SLOT_NODE_ADDR, table2, build.constUint(11), build.vmConst(1)); // This will be removed + build.inst(IrCmd::CHECK_SLOT_MATCH, slot2b, build.vmConst(1), fallback); // Key will be replaced with undef here + build.inst(IrCmd::CHECK_READONLY, table2, fallback); + + build.inst(IrCmd::STORE_SPLIT_TVALUE, slot2b, build.constTag(tnumber), build.constDouble(1), build.constInt(0)); + + build.inst(IrCmd::RETURN, build.vmReg(3), build.constUint(2)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(1), build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_POINTER R1 + %1 = GET_SLOT_NODE_ADDR %0, 3u, K1 + CHECK_SLOT_MATCH %1, K1, bb_fallback_1 + %3 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %3 + %5 = LOAD_POINTER R2 + %6 = GET_SLOT_NODE_ADDR %5, 6u, K1 + CHECK_SLOT_MATCH %6, K1, bb_fallback_1 + CHECK_READONLY %5, bb_fallback_1 + STORE_TAG R4, tnil + %10 = LOAD_TVALUE R4 + STORE_TVALUE %6, %10, 0i + CHECK_NODE_VALUE %1, bb_fallback_1 + %14 = LOAD_TVALUE %1, 0i + STORE_TVALUE R3, %14 + CHECK_NODE_VALUE %6, bb_fallback_1 + STORE_SPLIT_TVALUE %6, tnumber, 1, 0i + RETURN R3, 2u + +bb_fallback_1: + RETURN R1, 2u + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("Analysis"); @@ -2239,6 +2413,32 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification3") CHECK(build.function.cfg.idoms == std::vector{~0u, 0, 0, 0, 2, 0, 4, 0}); } +// 'Static Single Assignment Book' Figure 4.1 +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification4") +{ + defineCfgTree({{1}, {2, 10}, {3, 7}, {4}, {5}, {4, 6}, {1}, {8}, {5, 9}, {7}, {}}); + + IdfContext ctx; + + computeIteratedDominanceFrontierForDefs(ctx, build.function, {0, 2, 3, 6}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + CHECK(ctx.idf == std::vector{1, 4, 5}); +} + +// 'Static Single Assignment Book' Figure 4.5 +TEST_CASE_FIXTURE(IrBuilderFixture, "DominanceVerification4") +{ + defineCfgTree({{1}, {2}, {3, 7}, {4, 5}, {6}, {6}, {8}, {8}, {9}, {10, 11}, {11}, {9, 12}, {2}}); + + IdfContext ctx; + + computeIteratedDominanceFrontierForDefs(ctx, build.function, {4, 5, 7, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + CHECK(ctx.idf == std::vector{2, 6, 8}); + + // Pruned form, when variable is only live-in in limited set of blocks + computeIteratedDominanceFrontierForDefs(ctx, build.function, {4, 5, 7, 12}, {6, 8, 9}); + CHECK(ctx.idf == std::vector{6, 8}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ValueNumbering"); @@ -2484,4 +2684,138 @@ bb_0: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "TValueLoadToSplitStore") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + + build.beginBlock(entry); + IrOp op1 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(0)); + IrOp op1v2 = build.inst(IrCmd::ADD_NUM, op1, build.constDouble(4.0)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(1), op1v2); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tnumber)); + + // Check that this TValue store will be replaced by a split store + IrOp tv = build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1)); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), tv); + + // Check that tag and value can be extracted from R2 now (removing the fallback) + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(2)); + build.inst(IrCmd::CHECK_TAG, tag2, build.constTag(tnumber), fallback); + IrOp op2 = build.inst(IrCmd::LOAD_DOUBLE, build.vmReg(2)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(3), op2); + build.inst(IrCmd::STORE_TAG, build.vmReg(3), build.constTag(tnumber)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + build.beginBlock(fallback); + build.inst(IrCmd::RETURN, build.vmReg(2), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_DOUBLE R0 + %1 = ADD_NUM %0, 4 + STORE_DOUBLE R1, %1 + STORE_TAG R1, tnumber + STORE_SPLIT_TVALUE R2, tnumber, %1 + STORE_DOUBLE R3, %1 + STORE_TAG R3, tnumber + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagStoreUpdatesValueVersion") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + IrOp op1 = build.inst(IrCmd::LOAD_POINTER, build.vmReg(0)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(1), op1); + build.inst(IrCmd::STORE_TAG, build.vmReg(1), build.constTag(tstring)); + + IrOp str = build.inst(IrCmd::LOAD_POINTER, build.vmReg(1)); + build.inst(IrCmd::STORE_POINTER, build.vmReg(2), str); + build.inst(IrCmd::STORE_TAG, build.vmReg(2), build.constTag(tstring)); + + build.inst(IrCmd::RETURN, build.vmReg(1), build.constInt(1)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + %0 = LOAD_POINTER R0 + STORE_POINTER R1, %0 + STORE_TAG R1, tstring + STORE_POINTER R2, %0 + STORE_TAG R2, tstring + RETURN R1, 1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagStoreUpdatesSetUpval") +{ + IrOp entry = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); + + build.inst(IrCmd::SET_UPVALUE, build.vmUpvalue(0), build.vmReg(0), build.undef()); + + build.inst(IrCmd::RETURN, build.vmReg(0), build.constInt(0)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_TAG R0, tnumber + STORE_DOUBLE R0, 0.5 + SET_UPVALUE U0, R0, tnumber + RETURN R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "TagSelfEqualityCheckRemoval") +{ + ScopedFastFlag luauMergeTagLoads{"LuauMergeTagLoads", true}; + + IrOp entry = build.block(IrBlockKind::Internal); + IrOp trueBlock = build.block(IrBlockKind::Internal); + IrOp falseBlock = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + + IrOp tag1 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + IrOp tag2 = build.inst(IrCmd::LOAD_TAG, build.vmReg(0)); + build.inst(IrCmd::JUMP_EQ_TAG, tag1, tag2, trueBlock, falseBlock); + + build.beginBlock(trueBlock); + build.inst(IrCmd::RETURN, build.constUint(1)); + + build.beginBlock(falseBlock); + build.inst(IrCmd::RETURN, build.constUint(2)); + + updateUseCounts(build.function); + constPropInBlockChains(build, true); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + JUMP bb_1 + +bb_1: + RETURN 1u + +)"); +} + TEST_SUITE_END(); diff --git a/tests/IrCallWrapperX64.test.cpp b/tests/IrCallWrapperX64.test.cpp index ec04e531..1ff22a32 100644 --- a/tests/IrCallWrapperX64.test.cpp +++ b/tests/IrCallWrapperX64.test.cpp @@ -12,7 +12,7 @@ class IrCallWrapperX64Fixture public: IrCallWrapperX64Fixture(ABIX64 abi = ABIX64::Windows) : build(/* logText */ true, abi) - , regs(build, function) + , regs(build, function, nullptr) , callWrap(regs, build, ~0u) { } diff --git a/tests/IrRegAllocX64.test.cpp b/tests/IrRegAllocX64.test.cpp index bbf9c154..b4b63f4b 100644 --- a/tests/IrRegAllocX64.test.cpp +++ b/tests/IrRegAllocX64.test.cpp @@ -11,7 +11,7 @@ class IrRegAllocX64Fixture public: IrRegAllocX64Fixture() : build(/* logText */ true, ABIX64::Windows) - , regs(build, function) + , regs(build, function, nullptr) { } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index e906c224..6c728da9 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1657,8 +1657,6 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored TEST_CASE_FIXTURE(Fixture, "WrongComment") { - ScopedFastFlag sff("LuauLintNativeComment", true); - LintResult result = lint(R"( --!strict --!struct diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 5b1849a7..4f489065 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Clone.h" +#include "Luau/Common.h" #include "Luau/Module.h" #include "Luau/Scope.h" #include "Luau/RecursionCounter.h" @@ -7,11 +8,13 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); +LUAU_FASTFLAG(LuauStacklessTypeClone2) TEST_SUITE_BEGIN("ModuleTests"); @@ -78,7 +81,7 @@ TEST_CASE_FIXTURE(Fixture, "is_within_comment_parse_result") TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; // numberType is persistent. We leave it as-is. TypeId newNumber = clone(builtinTypes->numberType, dest, cloneState); @@ -88,7 +91,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; // Create a new number type that isn't persistent unfreeze(frontend.globals.globalTypes); @@ -129,7 +132,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") TypeId ty = requireType("Cyclic"); TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloneTy = clone(ty, dest, cloneState); TableType* ttv = getMutable(cloneTy); @@ -165,7 +168,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table_2") TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloneTy = clone(tableTy, dest, cloneState); TableType* ctt = getMutable(cloneTy); REQUIRE(ctt); @@ -209,7 +212,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") TEST_CASE_FIXTURE(Fixture, "deepClone_union") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; unfreeze(frontend.globals.globalTypes); TypeId oldUnion = frontend.globals.globalTypes.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); @@ -224,7 +227,7 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") { TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; unfreeze(frontend.globals.globalTypes); TypeId oldIntersection = frontend.globals.globalTypes.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); @@ -251,7 +254,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloned = clone(&exampleClass, dest, cloneState); const ClassType* ctv = get(cloned); @@ -274,12 +277,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_types") TypePackVar freeTp(FreeTypePack{TypeLevel{}}); TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId clonedTy = clone(freeTy, dest, cloneState); CHECK(get(clonedTy)); - cloneState = {}; + cloneState = {builtinTypes}; TypePackId clonedTp = clone(&freeTp, dest, cloneState); CHECK(get(clonedTp)); } @@ -291,7 +294,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_free_tables") ttv->state = TableState::Free; TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloned = clone(&tableTy, dest, cloneState); const TableType* clonedTtv = get(cloned); @@ -332,6 +335,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #else int limit = 400; #endif + + ScopedFastFlag sff{"LuauStacklessTypeClone2", false}; ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; TypeArena src; @@ -348,11 +353,39 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") } TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; CHECK_THROWS_AS(clone(table, dest, cloneState), RecursionLimitException); } +TEST_CASE_FIXTURE(Fixture, "clone_iteration_limit") +{ + ScopedFastFlag sff{"LuauStacklessTypeClone2", true}; + ScopedFastInt sfi{"LuauTypeCloneIterationLimit", 500}; + + TypeArena src; + + TypeId table = src.addType(TableType{}); + TypeId nested = table; + + for (int i = 0; i < 2500; i++) + { + TableType* ttv = getMutable(nested); + ttv->props["a"].setType(src.addType(TableType{})); + nested = ttv->props["a"].type(); + } + + TypeArena dest; + CloneState cloneState{builtinTypes}; + + TypeId ty = clone(table, dest, cloneState); + CHECK(get(ty)); + + // Cloning it again is an important test. + TypeId ty2 = clone(table, dest, cloneState); + CHECK(get(ty2)); +} + // Unions should never be cyclic, but we should clone them correctly even if // they are. TEST_CASE_FIXTURE(Fixture, "clone_cyclic_union") @@ -368,7 +401,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_cyclic_union") uu->options.push_back(u); TypeArena dest; - CloneState cloneState; + CloneState cloneState{builtinTypes}; TypeId cloned = clone(u, dest, cloneState); REQUIRE(cloned); @@ -468,4 +501,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_clone_types_of_reexported_values") CHECK(tableA->props["a"].type() == tableB->props["b"].type()); } +TEST_CASE_FIXTURE(BuiltinsFixture, "clone_table_bound_to_table_bound_to_table") +{ + TypeArena arena; + + TypeId a = arena.addType(TableType{TableState::Free, TypeLevel{}}); + getMutable(a)->name = "a"; + + TypeId b = arena.addType(TableType{TableState::Free, TypeLevel{}}); + getMutable(b)->name = "b"; + + TypeId c = arena.addType(TableType{TableState::Free, TypeLevel{}}); + getMutable(c)->name = "c"; + + getMutable(a)->boundTo = b; + getMutable(b)->boundTo = c; + + TypeArena dest; + CloneState state{builtinTypes}; + TypeId res = clone(a, dest, state); + + REQUIRE(dest.types.size() == 1); + + auto tableA = get(res); + REQUIRE_MESSAGE(tableA, "Expected table, got " << res); + REQUIRE(tableA->name == "c"); + REQUIRE(!tableA->boundTo); +} + TEST_SUITE_END(); diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index afcc08f2..7bed110b 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -516,6 +516,71 @@ struct NormalizeFixture : Fixture TEST_SUITE_BEGIN("Normalize"); +TEST_CASE_FIXTURE(NormalizeFixture, "string_intersection_is_commutative") +{ + auto c4 = toString(normal(R"( + string & (string & Not<"a"> & Not<"b">) +)")); + auto c4Reverse = toString(normal(R"( + (string & Not<"a"> & Not<"b">) & string +)")); + CHECK(c4 == c4Reverse); + CHECK_EQ("string & ~\"a\" & ~\"b\"", c4); + + auto c5 = toString(normal(R"( + (string & Not<"a"> & Not<"b">) & (string & Not<"b"> & Not<"c">) +)")); + auto c5Reverse = toString(normal(R"( + (string & Not<"b"> & Not<"c">) & (string & Not<"a"> & Not<"c">) +)")); + CHECK(c5 == c5Reverse); + CHECK_EQ("string & ~\"a\" & ~\"b\" & ~\"c\"", c5); + + auto c6 = toString(normal(R"( + ("a" | "b") & (string & Not<"b"> & Not<"c">) +)")); + auto c6Reverse = toString(normal(R"( + (string & Not<"b"> & Not<"c">) & ("a" | "b") +)")); + CHECK(c6 == c6Reverse); + CHECK_EQ("\"a\"", c6); + + auto c7 = toString(normal(R"( + string & ("b" | "c") +)")); + auto c7Reverse = toString(normal(R"( + ("b" | "c") & string +)")); + CHECK(c7 == c7Reverse); + CHECK_EQ("\"b\" | \"c\"", c7); + + auto c8 = toString(normal(R"( +(string & Not<"a"> & Not<"b">) & ("b" | "c") +)")); + auto c8Reverse = toString(normal(R"( + ("b" | "c") & (string & Not<"a"> & Not<"b">) +)")); + CHECK(c8 == c8Reverse); + CHECK_EQ("\"c\"", c8); + auto c9 = toString(normal(R"( + ("a" | "b") & ("b" | "c") + )")); + auto c9Reverse = toString(normal(R"( + ("b" | "c") & ("a" | "b") + )")); + CHECK(c9 == c9Reverse); + CHECK_EQ("\"b\"", c9); + + auto l = toString(normal(R"( + (string | number) & ("a" | true) + )")); + auto r = toString(normal(R"( + ("a" | true) & (string | number) + )")); + CHECK(l == r); + CHECK_EQ("\"a\"", l); +} + TEST_CASE_FIXTURE(NormalizeFixture, "negate_string") { CHECK("number" == toString(normal(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 3798082b..2492c4a0 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -900,7 +900,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_begin") } catch (const ParseErrors& e) { - CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + CHECK_EQ("Double braces are not permitted within interpolated strings; did you mean '\\{'?", e.getErrors().front().getMessage()); } } @@ -915,7 +915,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") } catch (const ParseErrors& e) { - CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + CHECK_EQ("Double braces are not permitted within interpolated strings; did you mean '\\{'?", e.getErrors().front().getMessage()); } } @@ -933,7 +933,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") CHECK_EQ(e.getErrors().size(), 1); auto error = e.getErrors().front(); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", error.getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", error.getMessage()); return error.getLocation().begin.column; } }; @@ -956,7 +956,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace_in_table { CHECK_EQ(e.getErrors().size(), 2); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", e.getErrors().front().getMessage()); CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); } } @@ -974,7 +974,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_mid_without_end_brace_in_t { CHECK_EQ(e.getErrors().size(), 2); - CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Malformed interpolated string; did you forget to add a '}'?", e.getErrors().front().getMessage()); CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); } } @@ -1041,6 +1041,36 @@ TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_malformed_escape") +{ + try + { + parse(R"( + local a = `???\xQQ {1}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Interpolated string literal contains malformed escape sequence", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_weird_token") +{ + try + { + parse(R"( + local a = `??? {42 !!}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Malformed interpolated string, got '!'", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try @@ -1569,9 +1599,9 @@ TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") TEST_CASE_FIXTURE(Fixture, "string_literals_broken") { - matchParseError("return \"", "Malformed string"); - matchParseError("return \"\\", "Malformed string"); - matchParseError("return \"\r\r", "Malformed string"); + matchParseError("return \"", "Malformed string; did you forget to finish it?"); + matchParseError("return \"\\", "Malformed string; did you forget to finish it?"); + matchParseError("return \"\r\r", "Malformed string; did you forget to finish it?"); } TEST_CASE_FIXTURE(Fixture, "number_literals") @@ -2530,12 +2560,12 @@ TEST_CASE_FIXTURE(Fixture, "incomplete_method_call_still_yields_an_AstExprIndexN TEST_CASE_FIXTURE(Fixture, "recover_confusables") { // Binary - matchParseError("local a = 4 != 10", "Unexpected '!=', did you mean '~='?"); - matchParseError("local a = true && false", "Unexpected '&&', did you mean 'and'?"); - matchParseError("local a = false || true", "Unexpected '||', did you mean 'or'?"); + matchParseError("local a = 4 != 10", "Unexpected '!='; did you mean '~='?"); + matchParseError("local a = true && false", "Unexpected '&&'; did you mean 'and'?"); + matchParseError("local a = false || true", "Unexpected '||'; did you mean 'or'?"); // Unary - matchParseError("local a = !false", "Unexpected '!', did you mean 'not'?"); + matchParseError("local a = !false", "Unexpected '!'; did you mean 'not'?"); // Check that separate tokens are not considered as a single one matchParseError("local a = 4 ! = 10", "Expected identifier when parsing expression, got '!'"); @@ -2880,4 +2910,100 @@ TEST_CASE_FIXTURE(Fixture, "missing_default_type_pack_argument_after_variadic_ty CHECK_EQ("Expected type pack after '=', got type", result.errors[1].getMessage()); } +TEST_CASE_FIXTURE(Fixture, "table_type_keys_cant_contain_nul") +{ + ParseResult result = tryParse(R"( + type Foo = { ["\0"]: number } + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 21}, {1, 22}}, result.errors[0].getLocation()); + CHECK_EQ("String literal contains malformed escape sequence or \\0", result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "invalid_escape_literals_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + local foo = "\xQQ" + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 20}, {1, 26}}, result.errors[0].getLocation()); + CHECK_EQ("String literal contains malformed escape sequence", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "unfinished_string_literals_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + local foo = "hi + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 20}, {1, 23}}, result.errors[0].getLocation()); + CHECK_EQ("Malformed string; did you forget to finish it?", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "unfinished_string_literal_types_get_reported_but_parsing_continues") +{ + ParseResult result = tryParse(R"( + type Foo = "hi + print(foo) + )"); + + REQUIRE_EQ(1, result.errors.size()); + + CHECK_EQ(Location{{1, 19}, {1, 22}}, result.errors[0].getLocation()); + CHECK_EQ("Malformed string; did you forget to finish it?", result.errors[0].getMessage()); + + REQUIRE(result.root); + CHECK_EQ(result.root->body.size, 2); +} + +TEST_CASE_FIXTURE(Fixture, "do_block_with_no_end") +{ + ParseResult result = tryParse(R"( + do + )"); + + REQUIRE_EQ(1, result.errors.size()); + + AstStatBlock* stat0 = result.root->body.data[0]->as(); + REQUIRE(stat0); + + CHECK(!stat0->hasEnd); +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved") +{ + ScopedFastFlag sff{"LuauLexerLookaheadRemembersBraceType", true}; + + ParseResult result = tryParse(R"( + local x = `{ {y} }` + )"); + + REQUIRE_MESSAGE(result.errors.empty(), result.errors[0].getMessage()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_with_lookahead_involved2") +{ + ScopedFastFlag sff{"LuauLexerLookaheadRemembersBraceType", true}; + + ParseResult result = tryParse(R"( + local x = `{ { y{} } }` + )"); + + REQUIRE_MESSAGE(result.errors.empty(), result.errors[0].getMessage()); +} + TEST_SUITE_END(); diff --git a/tests/RegisterCallbacks.cpp b/tests/RegisterCallbacks.cpp new file mode 100644 index 00000000..9f471933 --- /dev/null +++ b/tests/RegisterCallbacks.cpp @@ -0,0 +1,20 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "RegisterCallbacks.h" + +namespace Luau +{ + +std::unordered_set& getRegisterCallbacks() +{ + static std::unordered_set cbs; + return cbs; +} + +int addTestCallback(RegisterCallback cb) +{ + getRegisterCallbacks().insert(cb); + return 0; +} + +} // namespace Luau diff --git a/tests/RegisterCallbacks.h b/tests/RegisterCallbacks.h new file mode 100644 index 00000000..f62ac0e7 --- /dev/null +++ b/tests/RegisterCallbacks.h @@ -0,0 +1,22 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau +{ + +using RegisterCallback = void (*)(); + +/// Gets a set of callbacks to run immediately before running tests, intended +/// for registering new tests at runtime. +std::unordered_set& getRegisterCallbacks(); + +/// Adds a new callback to be ran immediately before running tests. +/// +/// @param cb the callback to add. +/// @returns a dummy integer to satisfy a doctest internal contract. +int addTestCallback(RegisterCallback cb); + +} // namespace Luau diff --git a/tests/Subtyping.test.cpp b/tests/Subtyping.test.cpp new file mode 100644 index 00000000..ce418b71 --- /dev/null +++ b/tests/Subtyping.test.cpp @@ -0,0 +1,993 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "doctest.h" +#include "Fixture.h" +#include "RegisterCallbacks.h" + +#include "Luau/Normalize.h" +#include "Luau/Subtyping.h" +#include "Luau/Type.h" +#include "Luau/TypePack.h" + +using namespace Luau; + +struct SubtypeFixture : Fixture +{ + TypeArena arena; + InternalErrorReporter iceReporter; + UnifierSharedState sharedState{&ice}; + Normalizer normalizer{&arena, builtinTypes, NotNull{&sharedState}}; + + ScopePtr rootScope{new Scope(builtinTypes->emptyTypePack)}; + ScopePtr moduleScope{new Scope(rootScope)}; + + Subtyping subtyping = mkSubtyping(rootScope); + + Subtyping mkSubtyping(const ScopePtr& scope) + { + return Subtyping{builtinTypes, NotNull{&arena}, NotNull{&normalizer}, NotNull{&iceReporter}, NotNull{scope.get()}}; + } + + TypePackId pack(std::initializer_list tys) + { + return arena.addTypePack(tys); + } + + TypePackId pack(std::initializer_list tys, TypePackVariant tail) + { + return arena.addTypePack(tys, arena.addTypePack(std::move(tail))); + } + + TypeId fn(std::initializer_list args, std::initializer_list rets) + { + return arena.addType(FunctionType{pack(args), pack(rets)}); + } + + TypeId fn(std::initializer_list argHead, TypePackVariant argTail, std::initializer_list rets) + { + return arena.addType(FunctionType{pack(argHead, std::move(argTail)), pack(rets)}); + } + + TypeId fn(std::initializer_list args, std::initializer_list retHead, TypePackVariant retTail) + { + return arena.addType(FunctionType{pack(args), pack(retHead, std::move(retTail))}); + } + + TypeId fn(std::initializer_list argHead, TypePackVariant argTail, std::initializer_list retHead, TypePackVariant retTail) + { + return arena.addType(FunctionType{pack(argHead, std::move(argTail)), pack(retHead, std::move(retTail))}); + } + + TypeId tbl(TableType::Props&& props) + { + return arena.addType(TableType{std::move(props), std::nullopt, {}, TableState::Sealed}); + } + + // `&` + TypeId meet(TypeId a, TypeId b) + { + return arena.addType(IntersectionType{{a, b}}); + } + + // `|` + TypeId join(TypeId a, TypeId b) + { + return arena.addType(UnionType{{a, b}}); + } + + // `~` + TypeId negate(TypeId ty) + { + return arena.addType(NegationType{ty}); + } + + // "literal" + TypeId str(const char* literal) + { + return arena.addType(SingletonType{StringSingleton{literal}}); + } + + TypeId cls(const std::string& name, std::optional parent = std::nullopt) + { + return arena.addType(ClassType{name, {}, parent.value_or(builtinTypes->classType), {}, {}, nullptr, ""}); + } + + TypeId cls(const std::string& name, ClassType::Props&& props) + { + TypeId ty = cls(name); + getMutable(ty)->props = std::move(props); + return ty; + } + + TypeId cyclicTable(std::function&& cb) + { + TypeId res = arena.addType(GenericType{}); + TableType tt{}; + cb(res, &tt); + emplaceType(asMutable(res), std::move(tt)); + return res; + } + + TypeId meta(TableType::Props&& metaProps, TableType::Props&& tableProps = {}) + { + return arena.addType(MetatableType{tbl(std::move(tableProps)), tbl(std::move(metaProps))}); + } + + TypeId genericT = arena.addType(GenericType{moduleScope.get(), "T"}); + TypeId genericU = arena.addType(GenericType{moduleScope.get(), "U"}); + + TypePackId genericAs = arena.addTypePack(GenericTypePack{"A"}); + TypePackId genericBs = arena.addTypePack(GenericTypePack{"B"}); + TypePackId genericCs = arena.addTypePack(GenericTypePack{"C"}); + + SubtypingResult isSubtype(TypeId subTy, TypeId superTy) + { + return subtyping.isSubtype(subTy, superTy); + } + + TypeId helloType = arena.addType(SingletonType{StringSingleton{"hello"}}); + TypeId helloType2 = arena.addType(SingletonType{StringSingleton{"hello"}}); + TypeId worldType = arena.addType(SingletonType{StringSingleton{"world"}}); + + TypeId aType = arena.addType(SingletonType{StringSingleton{"a"}}); + TypeId bType = arena.addType(SingletonType{StringSingleton{"b"}}); + TypeId trueSingleton = arena.addType(SingletonType{BooleanSingleton{true}}); + TypeId falseSingleton = arena.addType(SingletonType{BooleanSingleton{false}}); + TypeId helloOrWorldType = join(helloType, worldType); + TypeId trueOrFalseType = join(builtinTypes->trueType, builtinTypes->falseType); + + TypeId helloAndWorldType = meet(helloType, worldType); + TypeId booleanAndTrueType = meet(builtinTypes->booleanType, builtinTypes->trueType); + + /** + * class + * \- Root + * |- Child + * | |-GrandchildOne + * | \-GrandchildTwo + * \- AnotherChild + * |- AnotherGrandchildOne + * \- AnotherGrandchildTwo + */ + TypeId rootClass = cls("Root"); + TypeId childClass = cls("Child", rootClass); + TypeId grandchildOneClass = cls("GrandchildOne", childClass); + TypeId grandchildTwoClass = cls("GrandchildTwo", childClass); + TypeId anotherChildClass = cls("AnotherChild", rootClass); + TypeId anotherGrandchildOneClass = cls("AnotherGrandchildOne", anotherChildClass); + TypeId anotherGrandchildTwoClass = cls("AnotherGrandchildTwo", anotherChildClass); + + TypeId vec2Class = cls("Vec2", { + {"X", builtinTypes->numberType}, + {"Y", builtinTypes->numberType}, + }); + + // "hello" | "hello" + TypeId helloOrHelloType = arena.addType(UnionType{{helloType, helloType}}); + + // () -> () + const TypeId nothingToNothingType = fn({}, {}); + + // (number) -> string + const TypeId numberToStringType = fn({builtinTypes->numberType}, {builtinTypes->stringType}); + + // (unknown) -> string + const TypeId unknownToStringType = fn({builtinTypes->unknownType}, {builtinTypes->stringType}); + + // (number) -> () + const TypeId numberToNothingType = fn({builtinTypes->numberType}, {}); + + // () -> number + const TypeId nothingToNumberType = fn({}, {builtinTypes->numberType}); + + // (number) -> number + const TypeId numberToNumberType = fn({builtinTypes->numberType}, {builtinTypes->numberType}); + + // (number) -> unknown + const TypeId numberToUnknownType = fn({builtinTypes->numberType}, {builtinTypes->unknownType}); + + // (number) -> (string, string) + const TypeId numberToTwoStringsType = fn({builtinTypes->numberType}, {builtinTypes->stringType, builtinTypes->stringType}); + + // (number) -> (string, unknown) + const TypeId numberToStringAndUnknownType = fn({builtinTypes->numberType}, {builtinTypes->stringType, builtinTypes->unknownType}); + + // (number, number) -> string + const TypeId numberNumberToStringType = fn({builtinTypes->numberType, builtinTypes->numberType}, {builtinTypes->stringType}); + + // (unknown, number) -> string + const TypeId unknownNumberToStringType = fn({builtinTypes->unknownType, builtinTypes->numberType}, {builtinTypes->stringType}); + + // (number, string) -> string + const TypeId numberAndStringToStringType = fn({builtinTypes->numberType, builtinTypes->stringType}, {builtinTypes->stringType}); + + // (number, ...string) -> string + const TypeId numberAndStringsToStringType = + fn({builtinTypes->numberType}, VariadicTypePack{builtinTypes->stringType}, {builtinTypes->stringType}); + + // (number, ...string?) -> string + const TypeId numberAndOptionalStringsToStringType = + fn({builtinTypes->numberType}, VariadicTypePack{builtinTypes->optionalStringType}, {builtinTypes->stringType}); + + // (...number) -> number + const TypeId numbersToNumberType = + arena.addType(FunctionType{arena.addTypePack(VariadicTypePack{builtinTypes->numberType}), arena.addTypePack({builtinTypes->numberType})}); + + // (T) -> () + const TypeId genericTToNothingType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({genericT}), builtinTypes->emptyTypePack}); + + // (T) -> T + const TypeId genericTToTType = arena.addType(FunctionType{{genericT}, {}, arena.addTypePack({genericT}), arena.addTypePack({genericT})}); + + // (U) -> () + const TypeId genericUToNothingType = arena.addType(FunctionType{{genericU}, {}, arena.addTypePack({genericU}), builtinTypes->emptyTypePack}); + + // () -> T + const TypeId genericNothingToTType = arena.addType(FunctionType{{genericT}, {}, builtinTypes->emptyTypePack, arena.addTypePack({genericT})}); + + // (A...) -> A... + const TypeId genericAsToAsType = arena.addType(FunctionType{{}, {genericAs}, genericAs, genericAs}); + + // (A...) -> number + const TypeId genericAsToNumberType = arena.addType(FunctionType{{}, {genericAs}, genericAs, arena.addTypePack({builtinTypes->numberType})}); + + // (B...) -> B... + const TypeId genericBsToBsType = arena.addType(FunctionType{{}, {genericBs}, genericBs, genericBs}); + + // (B...) -> C... + const TypeId genericBsToCsType = arena.addType(FunctionType{{}, {genericBs, genericCs}, genericBs, genericCs}); + + // () -> A... + const TypeId genericNothingToAsType = arena.addType(FunctionType{{}, {genericAs}, builtinTypes->emptyTypePack, genericAs}); + + // { lower : string -> string } + TypeId tableWithLower = tbl(TableType::Props{{"lower", fn({builtinTypes->stringType}, {builtinTypes->stringType})}}); + // { insaneThingNoScalarHas : () -> () } + TypeId tableWithoutScalarProp = tbl(TableType::Props{{"insaneThingNoScalarHas", fn({}, {})}}); +}; + +#define CHECK_IS_SUBTYPE(left, right) \ + do \ + { \ + const auto& leftTy = (left); \ + const auto& rightTy = (right); \ + SubtypingResult result = isSubtype(leftTy, rightTy); \ + CHECK_MESSAGE(result.isSubtype, "Expected " << leftTy << " <: " << rightTy); \ + } while (0) + +#define CHECK_IS_NOT_SUBTYPE(left, right) \ + do \ + { \ + const auto& leftTy = (left); \ + const auto& rightTy = (right); \ + SubtypingResult result = isSubtype(leftTy, rightTy); \ + CHECK_MESSAGE(!result.isSubtype, "Expected " << leftTy << " numberType, builtinTypes->anyType); +TEST_IS_NOT_SUBTYPE(builtinTypes->numberType, builtinTypes->stringType); + +TEST_CASE_FIXTURE(SubtypeFixture, "any anyType, builtinTypes->unknownType); + CHECK(!result.isSubtype); + CHECK(result.isErrorSuppressing); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number? <: unknown") +{ + CHECK_IS_SUBTYPE(builtinTypes->optionalNumberType, builtinTypes->unknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: unknown") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->unknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: number") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "number <: number?") +{ + CHECK_IS_SUBTYPE(builtinTypes->numberType, builtinTypes->optionalNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" <: string") +{ + CHECK_IS_SUBTYPE(helloType, builtinTypes->stringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string stringType, helloType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" <: \"hello\"") +{ + CHECK_IS_SUBTYPE(helloType, helloType2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true <: boolean") +{ + CHECK_IS_SUBTYPE(builtinTypes->trueType, builtinTypes->booleanType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true <: true | false") +{ + CHECK_IS_SUBTYPE(builtinTypes->trueType, trueOrFalseType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true | false trueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true | false <: boolean") +{ + CHECK_IS_SUBTYPE(trueOrFalseType, builtinTypes->booleanType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true | false <: true | false") +{ + CHECK_IS_SUBTYPE(trueOrFalseType, trueOrFalseType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" | \"world\" <: number") +{ + CHECK_IS_NOT_SUBTYPE(helloOrWorldType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string stringType, helloOrHelloType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "true <: boolean & true") +{ + CHECK_IS_SUBTYPE(builtinTypes->trueType, booleanAndTrueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "boolean & true <: true") +{ + CHECK_IS_SUBTYPE(booleanAndTrueType, builtinTypes->trueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "boolean & true <: boolean & true") +{ + CHECK_IS_SUBTYPE(booleanAndTrueType, booleanAndTrueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" & \"world\" <: number") +{ + CHECK_IS_SUBTYPE(helloAndWorldType, builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "false falseType, booleanAndTrueType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(unknown) -> string <: (number) -> string") +{ + CHECK_IS_SUBTYPE(unknownToStringType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, unknownToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberNumberToStringType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, numberNumberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberNumberToStringType, unknownNumberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(unknown, number) -> string <: (number, number) -> string") +{ + CHECK_IS_SUBTYPE(unknownNumberToStringType, numberNumberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> (string, unknown) (string, string)") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringAndUnknownType, numberToTwoStringsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> (string, string) <: (number) -> (string, unknown)") +{ + CHECK_IS_SUBTYPE(numberToTwoStringsType, numberToStringAndUnknownType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> (string, string) string") +{ + CHECK_IS_NOT_SUBTYPE(numberToTwoStringsType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string (string, string)") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, numberToTwoStringsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string) -> string <: (number) -> string") +{ + CHECK_IS_SUBTYPE(numberAndStringsToStringType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, numberAndStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string?) -> string <: (number, ...string) -> string") +{ + CHECK_IS_SUBTYPE(numberAndOptionalStringsToStringType, numberAndStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberAndStringsToStringType, numberAndOptionalStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, ...string) -> string <: (number, string) -> string") +{ + CHECK_IS_SUBTYPE(numberAndStringsToStringType, numberAndStringToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number, string) -> string string") +{ + CHECK_IS_NOT_SUBTYPE(numberAndStringToStringType, numberAndStringsToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> T <: () -> number") +{ + CHECK_IS_SUBTYPE(genericNothingToTType, nothingToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (U) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, genericUToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> number () -> T") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNumberType, genericNothingToTType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (number) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, numberToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> T <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericTToTType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> T string") +{ + CHECK_IS_NOT_SUBTYPE(genericTToTType, numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(T) -> () <: (U) -> ()") +{ + CHECK_IS_SUBTYPE(genericTToNothingType, genericUToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> () (T) -> ()") +{ + CHECK_IS_NOT_SUBTYPE(numberToNothingType, genericTToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> number (A...) -> A...") +{ + CHECK_IS_NOT_SUBTYPE(numberToNumberType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: (B...) -> B...") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, genericBsToBsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(B...) -> C... <: (A...) -> A...") +{ + CHECK_IS_SUBTYPE(genericBsToCsType, genericAsToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... (B...) -> C...") +{ + CHECK_IS_NOT_SUBTYPE(genericAsToAsType, genericBsToCsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> number <: (number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToNumberType, numberToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(number) -> number (A...) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numberToNumberType, genericAsToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> number <: (...number) -> number") +{ + CHECK_IS_SUBTYPE(genericAsToNumberType, numbersToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(...number) -> number (A...) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numbersToNumberType, genericAsToNumberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> A... <: () -> ()") +{ + CHECK_IS_SUBTYPE(genericNothingToAsType, nothingToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> () () -> A...") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNothingType, genericNothingToAsType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(A...) -> A... <: () -> ()") +{ + CHECK_IS_SUBTYPE(genericAsToAsType, nothingToNothingType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "() -> () (A...) -> A...") +{ + CHECK_IS_NOT_SUBTYPE(nothingToNothingType, genericAsToAsType); +} + + +TEST_CASE_FIXTURE(SubtypeFixture, "{} <: {}") +{ + CHECK_IS_SUBTYPE(tbl({}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} <: {}") +{ + CHECK_IS_SUBTYPE(tbl({{"x", builtinTypes->numberType}}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} numberType}}), tbl({{"x", builtinTypes->stringType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number} numberType}}), tbl({{"x", builtinTypes->optionalNumberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: number?} optionalNumberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{x: (T) -> ()} <: {x: (U) -> ()}") +{ + CHECK_IS_SUBTYPE(tbl({{"x", genericTToNothingType}}), tbl({{"x", genericUToNothingType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } <: { @metatable {} }") +{ + CHECK_IS_SUBTYPE(meta({{"x", builtinTypes->numberType}}), meta({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } numberType}}), meta({{"x", builtinTypes->booleanType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable {} } booleanType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable {} } <: {}") +{ + CHECK_IS_SUBTYPE(meta({}), tbl({})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { u: boolean }, x: number } <: { x: number }") +{ + CHECK_IS_SUBTYPE(meta({{"u", builtinTypes->booleanType}}, {{"x", builtinTypes->numberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ @metatable { x: number } } numberType}}), tbl({{"x", builtinTypes->numberType}})); +} + +// Negated subtypes +TEST_IS_NOT_SUBTYPE(negate(builtinTypes->neverType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(builtinTypes->unknownType), builtinTypes->stringType); +TEST_IS_NOT_SUBTYPE(negate(builtinTypes->anyType), builtinTypes->stringType); +TEST_IS_SUBTYPE(negate(meet(builtinTypes->neverType, builtinTypes->unknownType)), builtinTypes->stringType); +TEST_IS_NOT_SUBTYPE(negate(join(builtinTypes->neverType, builtinTypes->unknownType)), builtinTypes->stringType); + +// Negated supertypes: never/unknown/any/error +TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->neverType)); +TEST_IS_SUBTYPE(builtinTypes->neverType, negate(builtinTypes->unknownType)); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->unknownType)); +TEST_IS_SUBTYPE(builtinTypes->numberType, negate(builtinTypes->anyType)); +TEST_IS_SUBTYPE(builtinTypes->unknownType, negate(builtinTypes->anyType)); + +// Negated supertypes: unions +TEST_IS_SUBTYPE(builtinTypes->booleanType, negate(join(builtinTypes->stringType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(rootClass, negate(join(childClass, builtinTypes->numberType))); +TEST_IS_SUBTYPE(str("foo"), negate(join(builtinTypes->numberType, builtinTypes->booleanType))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(join(builtinTypes->stringType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(childClass, negate(join(rootClass, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(numbersToNumberType, negate(join(builtinTypes->functionType, rootClass))); + +// Negated supertypes: intersections +TEST_IS_SUBTYPE(builtinTypes->booleanType, negate(meet(builtinTypes->stringType, str("foo")))); +TEST_IS_SUBTYPE(builtinTypes->trueType, negate(meet(builtinTypes->booleanType, builtinTypes->numberType))); +TEST_IS_SUBTYPE(rootClass, negate(meet(builtinTypes->classType, childClass))); +TEST_IS_SUBTYPE(childClass, negate(meet(builtinTypes->classType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->unknownType, negate(meet(builtinTypes->classType, builtinTypes->numberType))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(meet(builtinTypes->stringType, negate(str("bar"))))); + +// Negated supertypes: tables and metatables +TEST_IS_SUBTYPE(tbl({}), negate(builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(tbl({}), negate(builtinTypes->tableType)); +TEST_IS_SUBTYPE(meta({}), negate(builtinTypes->numberType)); +TEST_IS_NOT_SUBTYPE(meta({}), negate(builtinTypes->tableType)); + +// Negated supertypes: Functions +TEST_IS_SUBTYPE(numberToNumberType, negate(builtinTypes->classType)); +TEST_IS_NOT_SUBTYPE(numberToNumberType, negate(builtinTypes->functionType)); + +// Negated supertypes: Primitives and singletons +TEST_IS_SUBTYPE(builtinTypes->stringType, negate(builtinTypes->numberType)); +TEST_IS_SUBTYPE(str("foo"), meet(builtinTypes->stringType, negate(str("bar")))); +TEST_IS_NOT_SUBTYPE(builtinTypes->trueType, negate(builtinTypes->booleanType)); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(str("foo"))); +TEST_IS_NOT_SUBTYPE(str("foo"), negate(builtinTypes->stringType)); +TEST_IS_SUBTYPE(builtinTypes->falseType, negate(builtinTypes->trueType)); +TEST_IS_SUBTYPE(builtinTypes->falseType, meet(builtinTypes->booleanType, negate(builtinTypes->trueType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, meet(builtinTypes->booleanType, negate(builtinTypes->trueType))); +TEST_IS_NOT_SUBTYPE(builtinTypes->stringType, negate(str("foo"))); +TEST_IS_NOT_SUBTYPE(builtinTypes->booleanType, negate(builtinTypes->falseType)); + +// Negated supertypes: Classes +TEST_IS_SUBTYPE(rootClass, negate(builtinTypes->tableType)); +TEST_IS_NOT_SUBTYPE(rootClass, negate(builtinTypes->classType)); +TEST_IS_NOT_SUBTYPE(childClass, negate(rootClass)); +TEST_IS_NOT_SUBTYPE(childClass, meet(builtinTypes->classType, negate(rootClass))); +TEST_IS_SUBTYPE(anotherChildClass, meet(builtinTypes->classType, negate(childClass))); + +TEST_CASE_FIXTURE(SubtypeFixture, "Root <: class") +{ + CHECK_IS_SUBTYPE(rootClass, builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child | AnotherChild <: class") +{ + CHECK_IS_SUBTYPE(join(childClass, anotherChildClass), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child | AnotherChild <: Child | AnotherChild") +{ + CHECK_IS_SUBTYPE(join(childClass, anotherChildClass), join(childClass, anotherChildClass)); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child | Root <: Root") +{ + CHECK_IS_SUBTYPE(join(childClass, rootClass), rootClass); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & AnotherChild <: class") +{ + CHECK_IS_SUBTYPE(meet(childClass, anotherChildClass), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & Root <: class") +{ + CHECK_IS_SUBTYPE(meet(childClass, rootClass), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & ~Root <: class") +{ + CHECK_IS_SUBTYPE(meet(childClass, negate(rootClass)), builtinTypes->classType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & AnotherChild <: number") +{ + CHECK_IS_SUBTYPE(meet(childClass, anotherChildClass), builtinTypes->numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Child & ~GrandchildOne numberType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} <: t2 where t2 = {trim: (t2) -> string}") +{ + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + CHECK_IS_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> string} t2}") +{ + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) { + tt->props["trim"] = fn({ty}, {ty}); + }); + + CHECK_IS_NOT_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "t1 where t1 = {trim: (t1) -> t1} string}") +{ + TypeId t1 = cyclicTable([&](TypeId ty, TableType* tt) { + tt->props["trim"] = fn({ty}, {ty}); + }); + + TypeId t2 = cyclicTable([&](TypeId ty, TableType* tt) { + tt->props["trim"] = fn({ty}, {builtinTypes->stringType}); + }); + + CHECK_IS_NOT_SUBTYPE(t1, t2); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Vec2 <: { X: number, Y: number }") +{ + TypeId xy = tbl({ + {"X", builtinTypes->numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_SUBTYPE(vec2Class, xy); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Vec2 <: { X: number }") +{ + TypeId x = tbl({ + {"X", builtinTypes->numberType}, + }); + + CHECK_IS_SUBTYPE(vec2Class, x); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ X: number, Y: number } numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(xy, vec2Class); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "{ X: number } numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(x, vec2Class); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "table & { X: number, Y: number } numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(meet(builtinTypes->tableType, x), vec2Class); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "Vec2 numberType}, + {"Y", builtinTypes->numberType}, + }); + + CHECK_IS_NOT_SUBTYPE(vec2Class, meet(builtinTypes->tableType, xy)); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" <: { lower : (string) -> string }") +{ + CHECK_IS_SUBTYPE(helloType, tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"hello\" () }") +{ + CHECK_IS_NOT_SUBTYPE(helloType, tableWithoutScalarProp); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string <: { lower : (string) -> string }") +{ + CHECK_IS_SUBTYPE(builtinTypes->stringType, tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "string () }") +{ + CHECK_IS_NOT_SUBTYPE(builtinTypes->stringType, tableWithoutScalarProp); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "~fun & (string) -> number <: (string) -> number") +{ + CHECK_IS_SUBTYPE(meet(negate(builtinTypes->functionType), numberToStringType), numberToStringType); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(string) -> number <: ~fun & (string) -> number") +{ + CHECK_IS_NOT_SUBTYPE(numberToStringType, meet(negate(builtinTypes->functionType), numberToStringType)); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "~\"a\" & ~\"b\" & string <: { lower : (string) -> ()}") +{ + CHECK_IS_SUBTYPE(meet(meet(negate(aType), negate(bType)), builtinTypes->stringType), tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "\"a\" | (~\"b\" & string) <: { lower : (string) -> ()}") +{ + CHECK_IS_SUBTYPE(join(aType, meet(negate(bType), builtinTypes->stringType)), tableWithLower); +} + +TEST_CASE_FIXTURE(SubtypeFixture, "(string | number) & (\"a\" | true) <: { lower: (string) -> string }") +{ + auto base = meet(join(builtinTypes->stringType, builtinTypes->numberType), join(aType, trueSingleton)); + CHECK_IS_SUBTYPE(base, tableWithLower); +} + +/* + * Within the scope to which a generic belongs, that generic ought to be treated + * as its bounds. + * + * We do not yet support bounded generics, so all generics are considered to be + * bounded by unknown. + */ +TEST_CASE_FIXTURE(SubtypeFixture, "unknown <: X") +{ + ScopePtr childScope{new Scope(rootScope)}; + ScopePtr grandChildScope{new Scope(childScope)}; + + TypeId genericX = arena.addType(GenericType(childScope.get(), "X")); + + SubtypingResult usingGlobalScope = subtyping.isSubtype(builtinTypes->unknownType, genericX); + CHECK_MESSAGE(!usingGlobalScope.isSubtype, "Expected " << builtinTypes->unknownType << " unknownType, genericX); + CHECK_MESSAGE(usingChildScope.isSubtype, "Expected " << builtinTypes->unknownType << " <: " << genericX); + + Subtyping grandChildSubtyping{mkSubtyping(grandChildScope)}; + + SubtypingResult usingGrandChildScope = grandChildSubtyping.isSubtype(builtinTypes->unknownType, genericX); + CHECK_MESSAGE(usingGrandChildScope.isSubtype, "Expected " << builtinTypes->unknownType << " <: " << genericX); +} + +/* + * (A) -> A <: (X) -> X + * A can be bound to X. + * + * (A) -> A (X) -> number + * A can be bound to X, but A number (A) -> A + * Only generics on the left side can be bound. + * number (A, B) -> boolean <: (X, X) -> boolean + * It is ok to bind both A and B to X. + * + * (A, A) -> boolean (X, Y) -> boolean + * A cannot be bound to both X and Y. + */ + +TEST_SUITE_END(); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 9293bfb2..ac01b5f3 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -44,25 +44,34 @@ TEST_SUITE_BEGIN("ToDot"); TEST_CASE_FIXTURE(Fixture, "primitive") { - CheckResult result = check(R"( -local a: nil -local b: number -local c: any -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_NE("nil", toDot(requireType("a"))); + CHECK_EQ(R"(digraph graphname { +n1 [label="nil"]; +})", + toDot(builtinTypes->nilType)); CHECK_EQ(R"(digraph graphname { n1 [label="number"]; })", - toDot(requireType("b"))); + toDot(builtinTypes->numberType)); CHECK_EQ(R"(digraph graphname { n1 [label="any"]; })", - toDot(requireType("c"))); + toDot(builtinTypes->anyType)); + CHECK_EQ(R"(digraph graphname { +n1 [label="unknown"]; +})", + toDot(builtinTypes->unknownType)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="never"]; +})", + toDot(builtinTypes->neverType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_duplicatePrimitives") +{ ToDotOptions opts; opts.showPointers = false; opts.duplicatePrimitives = false; @@ -70,12 +79,22 @@ n1 [label="any"]; CHECK_EQ(R"(digraph graphname { n1 [label="PrimitiveType number"]; })", - toDot(requireType("b"), opts)); + toDot(builtinTypes->numberType, opts)); CHECK_EQ(R"(digraph graphname { n1 [label="AnyType 1"]; })", - toDot(requireType("c"), opts)); + toDot(builtinTypes->anyType, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="UnknownType 1"]; +})", + toDot(builtinTypes->unknownType, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="NeverType 1"]; +})", + toDot(builtinTypes->neverType, opts)); } TEST_CASE_FIXTURE(Fixture, "bound") @@ -283,6 +302,30 @@ n1 [label="FreeType 1"]; toDot(&type, opts)); } +TEST_CASE_FIXTURE(Fixture, "free_with_constraints") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; + + Type type{TypeVariant{FreeType{nullptr, builtinTypes->numberType, builtinTypes->optionalNumberType}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeType 1"]; +n1 -> n2 [label="[lowerBound]"]; +n2 [label="number"]; +n1 -> n3 [label="[upperBound]"]; +n3 [label="UnionType 3"]; +n3 -> n4; +n4 [label="number"]; +n3 -> n5; +n5 [label="nil"]; +})", + toDot(&type, opts)); +} + TEST_CASE_FIXTURE(Fixture, "error") { Type type{TypeVariant{ErrorType{}}}; @@ -440,4 +483,19 @@ n5 [label="SingletonType boolean: false"]; toDot(requireType("x"), opts)); } +TEST_CASE_FIXTURE(Fixture, "negation") +{ + TypeArena arena; + TypeId t = arena.addType(NegationType{builtinTypes->stringType}); + + ToDotOptions opts; + opts.showPointers = false; + + CHECK(R"(digraph graphname { +n1 [label="NegationType 1"]; +n1 -> n2 [label="[negated]"]; +n2 [label="string"]; +})" == toDot(t, opts)); +} + TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d4a25f80..cb1576f8 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -831,4 +831,31 @@ TEST_CASE_FIXTURE(Fixture, "tostring_unsee_ttv_if_array") CHECK(toString(requireType("y")) == "({string}, {string}) -> ()"); } +TEST_CASE_FIXTURE(Fixture, "tostring_error_mismatch") +{ + CheckResult result = check(R"( +--!strict + function f1() : {a : number, b : string, c : { d : number}} + return { a = 1, b = "a", c = {d = "a"}} + end + +)"); + std::string expected = R"(Type + '{ a: number, b: string, c: { d: string } }' +could not be converted into + '{| a: number, b: string, c: {| d: number |} |}' +caused by: + Property 'c' is not compatible. +Type + '{ d: string }' +could not be converted into + '{| d: number |}' +caused by: + Property 'd' is not compatible. +Type 'string' could not be converted into 'number' in an invariant context)"; + std::string actual = toString(result.errors[0]); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(expected == actual); +} TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 87f0b2b8..ae7a925c 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -529,14 +529,17 @@ until c CHECK_EQ(code, transpile(code, {}, true).code); } -TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignmenr") +TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignment") { + ScopedFastFlag sffs{"LuauFloorDivision", true}; + std::string code = R"( local a = 1 a += 2 a -= 3 a *= 4 a /= 5 +a //= 5 a %= 6 a ^= 7 a ..= ' - result' diff --git a/tests/TypeFamily.test.cpp b/tests/TypeFamily.test.cpp index 6e6dba09..c7ab66dd 100644 --- a/tests/TypeFamily.test.cpp +++ b/tests/TypeFamily.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeFamily.h" + +#include "Luau/TxnLog.h" #include "Luau/Type.h" #include "Fixture.h" diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 55f4caec..29dcff83 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -187,8 +187,9 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") TEST_CASE_FIXTURE(Fixture, "generic_aliases") { - ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; - + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; CheckResult result = check(R"( type T = { v: a } local x: T = { v = 123 } @@ -197,18 +198,19 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - const char* expectedError = "Type 'bad' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; - + const std::string expected = R"(Type 'bad' could not be converted into 'T' +caused by: + Property 'v' is not compatible. +Type 'string' could not be converted into 'number' in an invariant context)"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 44}}); - CHECK(toString(result.errors[0]) == expectedError); + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") { - ScopedFastFlag sff_DebugLuauDeferredConstraintResolution{"DebugLuauDeferredConstraintResolution", true}; + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + }; CheckResult result = check(R"( type T = { v: a } @@ -218,15 +220,16 @@ TEST_CASE_FIXTURE(Fixture, "dependent_generic_aliases") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - std::string expectedError = "Type 'bad' could not be converted into 'U'\n" - "caused by:\n" - " Property 't' is not compatible. Type '{ v: string }' could not be converted into 'T'\n" - "caused by:\n" - " Property 'v' is not compatible. Type 'string' could not be converted into 'number' in an invariant context"; + const std::string expected = R"(Type 'bad' could not be converted into 'U' +caused by: + Property 't' is not compatible. +Type '{ v: string }' could not be converted into 'T' +caused by: + Property 'v' is not compatible. +Type 'string' could not be converted into 'number' in an invariant context)"; CHECK(result.errors[0].location == Location{{4, 31}, {4, 52}}); - CHECK(toString(result.errors[0]) == expectedError); + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "mutually_recursive_generic_aliases") @@ -261,7 +264,7 @@ TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") // We had a UAF in this example caused by not cloning type function arguments ModulePtr module = frontend.moduleResolver.getModule("MainModule"); unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes, builtinTypes); freeze(module->interfaceTypes); module->internalTypes.clear(); module->astTypes.clear(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 4e0b7a7e..de29812b 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -132,7 +132,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_predicate") TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") { - ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + ScopedFastFlag sff[] = { + {"LuauAlwaysCommitInferencesOfFunctionCalls", true}, + }; CheckResult result = check(R"( --!strict @@ -142,12 +144,20 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "sort_with_bad_predicate") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(number, number) -> boolean' could not be converted into '((string, string) -> boolean)?' + const std::string expected = R"(Type + '(number, number) -> boolean' +could not be converted into + '((string, string) -> boolean)?' caused by: - None of the union options are compatible. For example: Type '(number, number) -> boolean' could not be converted into '(string, string) -> boolean' + None of the union options are compatible. For example: +Type + '(number, number) -> boolean' +could not be converted into + '(string, string) -> boolean' caused by: - Argument #1 type is not compatible. Type 'string' could not be converted into 'number')", - toString(result.errors[0])); + Argument #1 type is not compatible. +Type 'string' could not be converted into 'number')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "strings_have_methods") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index a4df1be7..412ec1d0 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -367,8 +367,9 @@ b.X = 2 -- real Vector2.X is also read-only TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") { - ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; - + ScopedFastFlag sff[] = { + {"LuauAlwaysCommitInferencesOfFunctionCalls", true}, + }; CheckResult result = check(R"( local function foo(v) return v.X :: number + string.len(v.Y) @@ -380,10 +381,11 @@ b(a) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(toString(result.errors[0]), R"(Type 'Vector2' could not be converted into '{- X: number, Y: string -}' + const std::string expected = R"(Type 'Vector2' could not be converted into '{- X: number, Y: string -}' caused by: - Property 'Y' is not compatible. Type 'number' could not be converted into 'string')"); + Property 'Y' is not compatible. +Type 'number' could not be converted into 'string')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(ClassFixture, "class_type_mismatch_with_name_conflict") @@ -462,9 +464,11 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property 'x' is not compatible. Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"); + Property 'x' is not compatible. +Type 'ChildClass' could not be converted into 'BaseClass' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(ClassFixture, "callable_classes") diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 379ecac6..dbd3f258 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1150,21 +1150,28 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); + std::string expected; if (FFlag::LuauInstantiateInSubtyping) { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + expected = R"(Type + '(number, number, a) -> number' +could not be converted into + '(number, number) -> number' caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)"; } else { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + expected = R"(Type + '(number, number, a) -> number' +could not be converted into + '(number, number) -> number' caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)"; } + CHECK_EQ(expected, toString(result.errors[0])); + // Infer from variadic packs into elements result = check(R"( function f(a: (...number) -> number) return a(1, 2) end @@ -1202,110 +1209,66 @@ f(function(x) return x * 2 end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(BuiltinsFixture, "infer_anonymous_function_arguments") +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_function_function_argument") { - // Simple direct arg to arg propagation CheckResult result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(a) return a.x + a.y end) +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end +return sum(2, 3, function(a, b) return a + b end) )"); LUAU_REQUIRE_NO_ERRORS(result); - // An optional function is accepted, but since we already provide a function, nil can be ignored result = check(R"( -type Table = { x: number, y: number } -local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end -f(function(a) return a.x + a.y end) +local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end +local a = {1, 2, 3} +local r = map(a, function(a) return a + a > 100 end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{boolean}", toString(requireType("r"))); + + check(R"( +local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end +local a = {1, 2, 3} +local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") +{ + CheckResult result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12(1, function(x) return x + x end) +g12(1, 2, function(x, y) return x + y end) )"); LUAU_REQUIRE_NO_ERRORS(result); - // Make sure self calls match correct index result = check(R"( -type Table = { x: number, y: number } -local x = {} -x.b = {x = 1, y = 2} -function x:f(a: (Table) -> number) return a(self.b) end -x:f(function(a) return a.x + a.y end) +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12({x=1}, function(x) return {x=-x.x} end) +g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) )"); LUAU_REQUIRE_NO_ERRORS(result); +} - // Mix inferred and explicit argument types - result = check(R"( -function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end -f(function(a: number, b, c) return c and a + b or b - a end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Anonymous function has a variadic pack - result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(...) return select(1, ...).z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Can't accept more arguments than provided - result = check(R"( -function f(a: (a: number, b: number) -> number) return a(1, 2) end -f(function(a, b, c, ...) return a + b end) - )"); - - LUAU_REQUIRE_ERRORS(result); - - if (FFlag::LuauInstantiateInSubtyping) - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' -caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' -caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } - - // Infer from variadic packs into elements - result = check(R"( -function f(a: (...number) -> number) return a(1, 2) end -f(function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Infer from variadic packs into variadic packs - result = check(R"( -type Table = { x: number, y: number } -function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end -f(function(a, ...) local b = ... return b.z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Return type inference - result = check(R"( -type Table = { x: number, y: number } -function f(a: (number) -> Table) return a(4) end -f(function(x) return x * 2 end) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - - // Return type doesn't inference 'nil' - result = check(R"( - function f(a: (number) -> nil) return a(4) end - f(function(x) print(x) end) +TEST_CASE_FIXTURE(BuiltinsFixture, "infer_generic_lib_function_function_argument") +{ + CheckResult result = check(R"( +local a = {{x=4}, {x=7}, {x=1}} +table.sort(a, function(x, y) return x.x < y.x end) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -1385,9 +1348,13 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' + const std::string expected = R"(Type + '(number, number) -> string' +could not be converted into + '(number) -> string' caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") @@ -1401,9 +1368,14 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' + const std::string expected = R"(Type + '(number, number) -> string' +could not be converted into + '(number, string) -> string' caused by: - Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); + Argument #2 type is not compatible. +Type 'string' could not be converted into 'number')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") @@ -1417,9 +1389,13 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' + const std::string expected = R"(Type + '(number, number) -> number' +could not be converted into + '(number, number) -> (number, boolean)' caused by: - Function only returns 1 value, but 2 are required here)"); + Function only returns 1 value, but 2 are required here)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") @@ -1433,9 +1409,14 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' + const std::string expected = R"(Type + '(number, number) -> string' +could not be converted into + '(number, number) -> number' caused by: - Return type is not compatible. Type 'string' could not be converted into 'number')"); + Return type is not compatible. +Type 'string' could not be converted into 'number')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") @@ -1449,10 +1430,14 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' + const std::string expected = R"(Type + '(number, number) -> (number, string)' +could not be converted into + '(number, number) -> (number, boolean)' caused by: - Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); + Return #2 type is not compatible. +Type 'string' could not be converted into 'boolean')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "function_decl_quantify_right_type") @@ -1575,11 +1560,19 @@ end )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(string) -> string' could not be converted into '((number) -> number)?' + CHECK_EQ(toString(result.errors[0]), R"(Type + '(string) -> string' +could not be converted into + '((number) -> number)?' caused by: - None of the union options are compatible. For example: Type '(string) -> string' could not be converted into '(number) -> number' + None of the union options are compatible. For example: +Type + '(string) -> string' +could not be converted into + '(number) -> number' caused by: - Argument #1 type is not compatible. Type 'number' could not be converted into 'string')"); + Argument #1 type is not compatible. +Type 'number' could not be converted into 'string')"); CHECK_EQ(toString(result.errors[1]), R"(Type 'string' could not be converted into 'number')"); } @@ -1603,7 +1596,10 @@ function t:b() return 2 end -- not OK )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type '(*error-type*) -> number' could not be converted into '() -> number' + CHECK_EQ(R"(Type + '(*error-type*) -> number' +could not be converted into + '() -> number' caused by: Argument count mismatch. Function expects 1 argument, but none are specified)", toString(result.errors[0])); @@ -1832,11 +1828,11 @@ z = y -- Not OK, so the line is colorable )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((\"blue\" | \"red\") -> (\"blue\" | \"red\") -> (\"blue\" | \"red\") -> boolean) & ((\"blue\" | \"red\") -> (\"blue\") -> (\"blue\") " - "-> false) & ((\"blue\" | \"red\") -> (\"red\") -> (\"red\") -> false) & ((\"blue\") -> (\"blue\") -> (\"blue\" | \"red\") -> false) & " - "((\"red\") -> (\"red\") -> (\"blue\" | \"red\") -> false)' could not be converted into '(\"blue\" | \"red\") -> (\"blue\" | \"red\") -> " - "(\"blue\" | \"red\") -> false'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '(("blue" | "red") -> ("blue" | "red") -> ("blue" | "red") -> boolean) & (("blue" | "red") -> ("blue") -> ("blue") -> false) & (("blue" | "red") -> ("red") -> ("red") -> false) & (("blue") -> ("blue") -> ("blue" | "red") -> false) & (("red") -> ("red") -> ("blue" | "red") -> false)' +could not be converted into + '("blue" | "red") -> ("blue" | "red") -> ("blue" | "red") -> false'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "function_is_supertype_of_concrete_functions") @@ -1994,7 +1990,9 @@ TEST_CASE_FIXTURE(Fixture, "function_exprs_are_generalized_at_signature_scope_no TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible") { - ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + ScopedFastFlag sff[] = { + {"LuauAlwaysCommitInferencesOfFunctionCalls", true}, + }; CheckResult result = check(R"( local function foo(x: a, y: a?) @@ -2038,11 +2036,12 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_bu LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '{ x: number }' could not be converted into 'vec2?' + const std::string expected = R"(Type '{ x: number }' could not be converted into 'vec2?' caused by: - None of the union options are compatible. For example: Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y')"); - - CHECK_EQ(toString(result.errors[1]), "Type 'vec2' could not be converted into 'number'"); + None of the union options are compatible. For example: +Table type '{ x: number }' not compatible with type 'vec2' because the former is missing field 'y')"; + CHECK_EQ(expected, toString(result.errors[0])); + CHECK_EQ("Type 'vec2' could not be converted into 'number'", toString(result.errors[1])); } TEST_CASE_FIXTURE(BuiltinsFixture, "param_1_and_2_both_takes_the_same_generic_but_their_arguments_are_incompatible_2") diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 6b933616..ea62f0bf 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -725,12 +725,14 @@ y.a.c = y )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), - R"(Type 'y' could not be converted into 'T' + const std::string expected = R"(Type 'y' could not be converted into 'T' caused by: - Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' + Property 'a' is not compatible. +Type '{ c: T?, d: number }' could not be converted into 'U' caused by: - Property 'd' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property 'd' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 954d9858..9e685881 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -333,9 +333,13 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") )"); LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' + const std::string expected = R"(Type + '(string, number) -> string' +could not be converted into + '(string) -> string' caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"; + CHECK_EQ(expected, toString(result.errors[0])); CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'X & Y'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'X & Y'"); @@ -357,9 +361,14 @@ TEST_CASE_FIXTURE(Fixture, "table_write_sealed_indirect") )"); LUAU_REQUIRE_ERROR_COUNT(4, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(string, number) -> string' could not be converted into '(string) -> string' + const std::string expected = R"(Type + '(string, number) -> string' +could not be converted into + '(string) -> string' caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"; + CHECK_EQ(expected, toString(result.errors[0])); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'z' to table 'XY'"); CHECK_EQ(toString(result.errors[2]), "Type 'number' could not be converted into 'string'"); CHECK_EQ(toString(result.errors[3]), "Cannot add property 'w' to table 'XY'"); @@ -386,9 +395,11 @@ local a: XYZ = 3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' + const std::string expected = R"(Type 'number' could not be converted into 'X & Y & Z' caused by: - Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); + Not all intersection parts are compatible. +Type 'number' could not be converted into 'X')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") @@ -469,12 +480,16 @@ TEST_CASE_FIXTURE(Fixture, "intersect_saturate_overloaded_functions") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> number?) & ((string?) -> string?)' could not be converted into '(number) -> number'; " - "none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number?) -> number?) & ((string?) -> string?)' +could not be converted into + '(number) -> number'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") { + CheckResult result = check(R"( local x : ((number) -> number) & ((string) -> string) local y : ((number | string) -> (number | string)) = x -- OK @@ -482,8 +497,11 @@ TEST_CASE_FIXTURE(Fixture, "union_saturate_overloaded_functions") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number) & ((string) -> string)' could not be converted into '(boolean | number) -> " - "boolean | number'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number) -> number) & ((string) -> string)' +could not be converted into + '(boolean | number) -> boolean | number'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") @@ -495,8 +513,11 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' could not be converted into " - "'{| p: nil |}'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '{| p: number?, q: number?, r: number? |} & {| p: number?, q: string? |}' +could not be converted into + '{| p: nil |}'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") @@ -532,9 +553,11 @@ TEST_CASE_FIXTURE(Fixture, "intersection_of_tables_with_top_properties") else { LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '{| p: number?, q: any |} & {| p: unknown, q: string? |}' could not be converted into '{| p: string?, " - "q: number? |}'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '{| p: number?, q: any |} & {| p: unknown, q: string? |}' +could not be converted into + '{| p: string?, q: number? |}'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } } @@ -558,9 +581,11 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_returning_intersections") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' could not be converted into " - "'(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number?) -> {| p: number |} & {| q: number |}) & ((string?) -> {| p: number |} & {| r: number |})' +could not be converted into + '(number?) -> {| p: number, q: number, r: number |}'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") @@ -574,8 +599,11 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> a | number) & ((string?) -> a | string)' could not be converted into '(number?) -> a'; " - "none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number?) -> a | number) & ((string?) -> a | string)' +could not be converted into + '(number?) -> a'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") @@ -589,8 +617,11 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '((a?) -> a | b) & ((c?) -> b | c)' could not be converted into '(a?) -> (a & c) | b'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((a?) -> a | b) & ((c?) -> b | c)' +could not be converted into + '(a?) -> (a & c) | b'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") @@ -604,8 +635,11 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_functions_mentioning_generic_packs") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' could not be converted " - "into '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number?, a...) -> (number?, b...)) & ((string?, a...) -> (string?, b...))' +could not be converted into + '(nil, b...) -> (nil, a...)'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") @@ -619,8 +653,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_result") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> unknown) & ((number) -> number)' could not be converted into '(number?) -> number?'; none " - "of the intersection parts are compatible"); + const std::string expected = R"(Type + '((nil) -> unknown) & ((number) -> number)' +could not be converted into + '(number?) -> number?'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") @@ -634,8 +671,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_unknown_arguments") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number) -> number?) & ((unknown) -> string?)' could not be converted into '(number?) -> nil'; none " - "of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number) -> number?) & ((unknown) -> string?)' +could not be converted into + '(number?) -> nil'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") @@ -649,8 +689,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_result") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((nil) -> never) & ((number) -> number)' could not be converted into '(number?) -> never'; none of " - "the intersection parts are compatible"); + const std::string expected = R"(Type + '((nil) -> never) & ((number) -> number)' +could not be converted into + '(number?) -> never'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") @@ -664,8 +707,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_never_arguments") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((never) -> string?) & ((number) -> number?)' could not be converted into '(number?) -> nil'; none " - "of the intersection parts are compatible"); + const std::string expected = R"(Type + '((never) -> string?) & ((number) -> number?)' +could not be converted into + '(number?) -> nil'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_variadics") @@ -677,8 +723,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_overlapping_results_and_ )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((number?) -> (...number)) & ((string?) -> number | string)' could not be converted into '(number | " - "string) -> (number, number?)'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '((number?) -> (...number)) & ((string?) -> number | string)' +could not be converted into + '(number | string) -> (number, number?)'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_1") @@ -722,8 +771,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_3") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '(() -> (a...)) & (() -> (number?, a...))' could not be converted into '() -> number'; none of the intersection parts are compatible"); + const std::string expected = R"(Type + '(() -> (a...)) & (() -> (number?, a...))' +could not be converted into + '() -> number'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") @@ -737,8 +789,11 @@ TEST_CASE_FIXTURE(Fixture, "overloadeded_functions_with_weird_typepacks_4") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '((a...) -> ()) & ((number, a...) -> number)' could not be converted into '(number?) -> ()'; none of " - "the intersection parts are compatible"); + const std::string expected = R"(Type + '((a...) -> ()) & ((number, a...) -> number)' +could not be converted into + '(number?) -> ()'; none of the intersection parts are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "intersect_metatables") @@ -897,8 +952,6 @@ local y = x.Bar TEST_CASE_FIXTURE(BuiltinsFixture, "index_property_table_intersection_2") { - ScopedFastFlag sff{"LuauIndexTableIntersectionStringExpr", true}; - CheckResult result = check(R"( type Foo = { Bar: string, @@ -911,4 +964,54 @@ local y = x["Bar"] LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_degenerate_intersections") +{ + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type A = { + x: number?, + } + + type B = { + x: number?, + } + + type C = A & B + local obj: C = { + x = 3, + } + + local x: number = obj.x or 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cli_80596_simplify_more_realistic_intersections") +{ + ScopedFastFlag dcr{"DebugLuauDeferredConstraintResolution", true}; + + CheckResult result = check(R"( + type A = { + x: number?, + y: string?, + } + + type B = { + x: number?, + z: string?, + } + + type C = A & B + local obj: C = { + x = 3, + } + + local x: number = obj.x or 3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index b75f909a..08291b44 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -407,11 +407,12 @@ local b: B.T = a )"; CheckResult result = frontend.check("game/C"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' + const std::string expected = R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property 'x' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "module_type_conflict_instantiated") @@ -441,11 +442,12 @@ local b: B.T = a )"; CheckResult result = frontend.check("game/D"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' + const std::string expected = R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' caused by: - Property 'x' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property 'x' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "constrained_anyification_clone_immutable_types") diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index 08c0f7ca..c589f134 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -9,6 +9,7 @@ #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" using namespace Luau; @@ -404,4 +405,92 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cycle_between_object_constructor_and_alias") CHECK_MESSAGE(get(follow(aliasType)), "Expected metatable type but got: " << toString(aliasType)); } +TEST_CASE_FIXTURE(BuiltinsFixture, "promise_type_error_too_complex") +{ + ScopedFastFlag sff{"LuauStacklessTypeClone2", true}; + + frontend.options.retainFullTypeGraphs = false; + + // Used `luau-reduce` tool to extract a minimal reproduction. + // Credit: https://github.com/evaera/roblox-lua-promise/blob/v4.0.0/lib/init.lua + CheckResult result = check(R"( + --!strict + + local Promise = {} + Promise.prototype = {} + Promise.__index = Promise.prototype + + function Promise._new(traceback, callback, parent) + if parent ~= nil and not Promise.is(parent)then + end + + local self = { + _parent = parent, + } + + parent._consumers[self] = true + setmetatable(self, Promise) + self:_reject() + + return self + end + + function Promise.resolve(...) + return Promise._new(debug.traceback(nil, 2), function(resolve) + end) + end + + function Promise.reject(...) + return Promise._new(debug.traceback(nil, 2), function(_, reject) + end) + end + + function Promise._try(traceback, callback, ...) + return Promise._new(traceback, function(resolve) + end) + end + + function Promise.try(callback, ...) + return Promise._try(debug.traceback(nil, 2), callback, ...) + end + + function Promise._all(traceback, promises, amount) + if #promises == 0 or amount == 0 then + return Promise.resolve({}) + end + return Promise._new(traceback, function(resolve, reject, onCancel) + end) + end + + function Promise.all(promises) + return Promise._all(debug.traceback(nil, 2), promises) + end + + function Promise.allSettled(promises) + return Promise.resolve({}) + end + + function Promise.race(promises) + return Promise._new(debug.traceback(nil, 2), function(resolve, reject, onCancel) + end) + end + + function Promise.each(list, predicate) + return Promise._new(debug.traceback(nil, 2), function(resolve, reject, onCancel) + local predicatePromise = Promise.resolve(predicate(value, index)) + local success, result = predicatePromise:await() + end) + end + + function Promise.is(object) + end + + function Promise.prototype:_reject(...) + self:_finalize() + end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index 0b7a8311..c588eaba 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -12,6 +12,8 @@ #include "doctest.h" +#include "ScopedFlags.h" + using namespace Luau; LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) @@ -143,6 +145,24 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") CHECK_EQ("number", toString(requireType("c"))); } +TEST_CASE_FIXTURE(Fixture, "floor_division_binary_op") +{ + ScopedFastFlag sffs{"LuauFloorDivision", true}; + + CheckResult result = check(R"( + local a = 4 // 8 + local b = -4 // 9 + local c = 9 + c //= -6.5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_overloaded_multiply_that_is_an_intersection") { CheckResult result = check(R"( @@ -1047,8 +1067,9 @@ local w = c and 1 CHECK("false | number" == toString(requireType("z"))); else CHECK("boolean | number" == toString(requireType("z"))); // 'false' widened to boolean + if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK("((false?) & a) | number" == toString(requireType("w"))); + CHECK("((false?) & unknown) | number" == toString(requireType("w"))); else CHECK("(boolean | number)?" == toString(requireType("w"))); } diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index b3d70bd7..d4dbfa91 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -796,11 +796,14 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' + const std::string expected = R"(Type + '{| x: number? |}' +could not be converted into + '{| x: number |}' caused by: - Property 'x' is not compatible. Type 'number?' could not be converted into 'number' in an invariant context)", - toString(result.errors[0])); + Property 'x' is not compatible. +Type 'number?' could not be converted into 'number' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_with_a_singleton_argument") @@ -1003,8 +1006,6 @@ end // We would prefer this unification to be able to complete, but at least it should not crash TEST_CASE_FIXTURE(BuiltinsFixture, "table_unification_infinite_recursion") { - ScopedFastFlag luauTableUnifyRecursionLimit{"LuauTableUnifyRecursionLimit", true}; - #if defined(_NOOPT) || defined(_DEBUG) ScopedFastInt LuauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", 100}; #endif @@ -1044,4 +1045,42 @@ tbl:f3() } } +// Ideally, unification with any will not cause a 2^n normalization of a function overload +TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_limit_in_unify_with_any") +{ + ScopedFastFlag sff[] = { + {"LuauTransitiveSubtyping", true}, + {"DebugLuauDeferredConstraintResolution", true}, + }; + + // With default limit, this test will take 10 seconds in NoOpt + ScopedFastInt luauNormalizeCacheLimit{"LuauNormalizeCacheLimit", 1000}; + + // Build a function type with a large overload set + const int parts = 100; + std::string source; + + for (int i = 0; i < parts; i++) + formatAppend(source, "type T%d = { f%d: number }\n", i, i); + + source += "type Instance = { new: (('s0', extra: Instance?) -> T0)"; + + for (int i = 1; i < parts; i++) + formatAppend(source, " & (('s%d', extra: Instance?) -> T%d)", i, i); + + source += " }\n"; + + source += R"( +local Instance: Instance = {} :: any + +local function foo(a: typeof(Instance.new)) return if a then 2 else 3 end + +foo(1 :: any) +)"; + + CheckResult result = check(source); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index c0dbfce8..c42830aa 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -308,7 +308,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_if_condition_position") else CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number", toString(requireTypeAtPosition({6, 26}))); - } TEST_CASE_FIXTURE(BuiltinsFixture, "typeguard_in_assert_position") diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 84e9fc7c..d085f900 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -325,10 +325,11 @@ local a: Animal = { tag = 'cat', cafood = 'something' } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' + const std::string expected = R"(Type 'a' could not be converted into 'Cat | Dog' caused by: - None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", - toString(result.errors[0])); + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") @@ -342,10 +343,11 @@ local a: Result = { success = false, result = 'something' } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' + const std::string expected = R"(Type 'a' could not be converted into 'Bad | Good' caused by: - None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", - toString(result.errors[0])); + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") @@ -353,7 +355,6 @@ TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") ScopedFastFlag sff[] = { {"DebugLuauDeferredConstraintResolution", true}, }; - CheckResult result = check(R"( type Ok = {success: true, result: T} type Err = {success: false, error: T} @@ -365,10 +366,10 @@ TEST_CASE_FIXTURE(Fixture, "parametric_tagged_union_alias") LUAU_REQUIRE_ERROR_COUNT(1, result); - const std::string expectedError = "Type 'a' could not be converted into 'Err | Ok'\n" - "caused by:\n" - " None of the union options are compatible. For example: Table type 'a'" - " not compatible with type 'Err' because the former is missing field 'error'"; + const std::string expectedError = R"(Type 'a' could not be converted into 'Err | Ok' +caused by: + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'Err' because the former is missing field 'error')"; CHECK(toString(result.errors[0]) == expectedError); } diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 8d93561f..aea38253 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2092,9 +2092,11 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property 'y' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") @@ -2111,11 +2113,14 @@ local b: B = a )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' + Property 'b' is not compatible. +Type 'AS' could not be converted into 'BS' caused by: - Property 'y' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property 'y' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(BuiltinsFixture, "error_detailed_metatable_prop") @@ -2130,28 +2135,61 @@ local b2 = setmetatable({ x = 2, y = 4 }, { __call = function(s, t) end }); local c2: typeof(a2) = b2 )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' + const std::string expected1 = R"(Type 'b1' could not be converted into 'a1' caused by: - Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' + Type + '{ x: number, y: string }' +could not be converted into + '{ x: number, y: number }' caused by: - Property 'y' is not compatible. Type 'string' could not be converted into 'number' in an invariant context)"); + Property 'y' is not compatible. +Type 'string' could not be converted into 'number' in an invariant context)"; + const std::string expected2 = R"(Type 'b2' could not be converted into 'a2' +caused by: + Type + '{ __call: (a, b) -> () }' +could not be converted into + '{ __call: (a) -> () }' +caused by: + Property '__call' is not compatible. +Type + '(a, b) -> ()' +could not be converted into + '(a) -> ()'; different number of generic type parameters)"; + const std::string expected3 = R"(Type 'b2' could not be converted into 'a2' +caused by: + Type + '{ __call: (a, b) -> () }' +could not be converted into + '{ __call: (a) -> () }' +caused by: + Property '__call' is not compatible. +Type + '(a, b) -> ()' +could not be converted into + '(a) -> ()'; different number of generic type parameters)"; + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(expected1, toString(result.errors[0])); if (FFlag::LuauInstantiateInSubtyping) { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' -caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' -caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + CHECK_EQ(expected2, toString(result.errors[1])); } else { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + std::string expected3 = R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' + Type + '{ __call: (a, b) -> () }' +could not be converted into + '{ __call: (a) -> () }' caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + Property '__call' is not compatible. +Type + '(a, b) -> ()' +could not be converted into + '(a) -> ()'; different number of generic type parameters)"; + CHECK_EQ(expected3, toString(result.errors[1])); } } @@ -2166,9 +2204,11 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property '[indexer key]' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") @@ -2182,9 +2222,11 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' + const std::string expected = R"(Type 'A' could not be converted into 'B' caused by: - Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string' in an invariant context)"); + Property '[indexer value]' is not compatible. +Type 'number' could not be converted into 'string' in an invariant context)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2218,9 +2260,11 @@ local y: number = tmp.p.y )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' + const std::string expected = R"(Type 'tmp' could not be converted into 'HasSuper' caused by: - Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); + Property 'p' is not compatible. +Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") @@ -3302,7 +3346,9 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_a_subtype_of_a_compatible_polymorphic_shap TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_shape_type") { - ScopedFastFlag sff{"LuauAlwaysCommitInferencesOfFunctionCalls", true}; + ScopedFastFlag sff[] = { + {"LuauAlwaysCommitInferencesOfFunctionCalls", true}, + }; CheckResult result = check(R"( local function f(s) @@ -3316,22 +3362,32 @@ TEST_CASE_FIXTURE(Fixture, "scalar_is_not_a_subtype_of_a_compatible_polymorphic_ LUAU_REQUIRE_ERROR_COUNT(3, result); - CHECK_EQ(R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + const std::string expected1 = + R"(Type 'string' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected1, toString(result.errors[0])); - CHECK_EQ(R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' -caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[1])); - CHECK_EQ(R"(Type '"bar" | "baz"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + const std::string expected2 = + R"(Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - Not all union options are compatible. Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected2, toString(result.errors[1])); + + const std::string expected3 = R"(Type + '"bar" | "baz"' +could not be converted into + 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[2])); + Not all union options are compatible. +Type '"bar"' could not be converted into 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' +caused by: + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {- absolutely_no_scalar_has_this_method: (t1) -> (a...) -}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + CHECK_EQ(expected3, toString(result.errors[2])); } TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") @@ -3349,6 +3405,7 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compati TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible") { + CheckResult result = check(R"( local function f(s): string local foo = s:absolutely_no_scalar_has_this_method() @@ -3356,11 +3413,13 @@ TEST_CASE_FIXTURE(Fixture, "a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_ end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' + const std::string expected = + R"(Type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' could not be converted into 'string' caused by: - The former's metatable does not satisfy the requirements. Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')", - toString(result.errors[0])); + The former's metatable does not satisfy the requirements. +Table type 'typeof(string)' not compatible with type 't1 where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}' because the former is missing field 'absolutely_no_scalar_has_this_method')"; + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(expected, toString(result.errors[0])); CHECK_EQ("(t1) -> string where t1 = {+ absolutely_no_scalar_has_this_method: (t1) -> (a, b...) +}", toString(requireType("f"))); } @@ -3729,4 +3788,40 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_shifted_tables") LUAU_REQUIRE_NO_ERRORS(result); } + +TEST_CASE_FIXTURE(Fixture, "cli_84607_missing_prop_in_array_or_dict") +{ + ScopedFastFlag sff{"LuauFixIndexerSubtypingOrdering", true}; + + CheckResult result = check(R"( + type Thing = { name: string, prop: boolean } + + local arrayOfThings : {Thing} = { + { name = "a" } + } + + local dictOfThings : {[string]: Thing} = { + a = { name = "a" } + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeError& err1 = result.errors[0]; + MissingProperties* error1 = get(err1); + REQUIRE(error1); + REQUIRE(error1->properties.size() == 1); + + CHECK_EQ("prop", error1->properties[0]); + + TypeError& err2 = result.errors[1]; + TypeMismatch* mismatch = get(err2); + REQUIRE(mismatch); + MissingProperties* error2 = get(*mismatch->error); + REQUIRE(error2); + REQUIRE(error2->properties.size() == 1); + + CHECK_EQ("prop", error2->properties[0]); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 4f40b386..2d34fc7f 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -1028,16 +1028,23 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") // unsound. LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ( - R"(Type 'Policies' from 'MainModule' could not be converted into 'Policies' from 'MainModule' + const std::string expected = R"(Type 'Policies' from 'MainModule' could not be converted into 'Policies' from 'MainModule' caused by: - Property 'getStoreFieldName' is not compatible. Type '(Policies, FieldSpecifier & {| from: number? |}) -> (a, b...)' could not be converted into '(Policies, FieldSpecifier) -> string' + Property 'getStoreFieldName' is not compatible. +Type + '(Policies, FieldSpecifier & {| from: number? |}) -> (a, b...)' +could not be converted into + '(Policies, FieldSpecifier) -> string' caused by: - Argument #2 type is not compatible. Type 'FieldSpecifier' could not be converted into 'FieldSpecifier & {| from: number? |}' + Argument #2 type is not compatible. +Type + 'FieldSpecifier' +could not be converted into + 'FieldSpecifier & {| from: number? |}' caused by: - Not all intersection parts are compatible. Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')", - toString(result.errors[0])); + Not all intersection parts are compatible. +Table type 'FieldSpecifier' not compatible with type '{| from: number? |}' because the former has extra field 'fieldName')"; + CHECK_EQ(expected, toString(result.errors[0])); } else { @@ -1205,8 +1212,6 @@ end TEST_CASE_FIXTURE(BuiltinsFixture, "typechecking_in_type_guards") { - ScopedFastFlag sff{"LuauTypecheckTypeguards", true}; - CheckResult result = check(R"( local a = type(foo) == 'nil' local b = typeof(foo) ~= 'nil' @@ -1399,4 +1404,32 @@ TEST_CASE_FIXTURE(Fixture, "promote_tail_type_packs") LUAU_REQUIRE_NO_ERRORS(result); } +/* + * CLI-49876 + * + * We had a bug where we would not use the correct TxnLog when evaluating a + * variadic overload. We could therefore get into a state where the TxnLog has + * logged that a generic matches to one type, but the variadic tail has already + * been bound to another type outside of that TxnLog. + * + * This caused type checking to succeed when it should have failed. + */ +TEST_CASE_FIXTURE(BuiltinsFixture, "be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload") +{ + ScopedFastFlag sff{"LuauVariadicOverloadFix", true}; + + CheckResult result = check(R"( + local function concat(target: {T}, ...: {T} | T): {T} + return (nil :: any) :: {T} + end + + local res = concat({"alic"}, 1, 2) + )"); + + LUAU_REQUIRE_ERRORS(result); + + for (const auto& e : result.errors) + CHECK(5 == e.location.begin.line); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 5656e871..1f0b04c6 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -373,11 +373,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table state.log.commit(); REQUIRE_EQ(state.errors.size(), 1); - - std::string expected = "Type '{ @metatable {| __index: {| foo: string |} |}, { } }' could not be converted into '{- foo: number -}'\n" - "caused by:\n" - " Type 'number' could not be converted into 'string'"; - CHECK_EQ(toString(state.errors[0]), expected); + const std::string expected = R"(Type + '{ @metatable {| __index: {| foo: string |} |}, { } }' +could not be converted into + '{- foo: number -}' +caused by: + Type 'number' could not be converted into 'string')"; + CHECK_EQ(expected, toString(state.errors[0])); } TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index d2ae166b..92d736f8 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -879,9 +879,13 @@ a = b )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' + const std::string expected = R"(Type + '() -> (number, ...boolean)' +could not be converted into + '() -> (number, ...string)' caused by: - Type 'boolean' could not be converted into 'string')"); + Type 'boolean' could not be converted into 'string')"; + CHECK_EQ(expected, toString(result.errors[0])); } // TODO: File a Jira about this diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 3ab7bebb..b455903c 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -471,9 +471,11 @@ local b: { w: number } = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' + const std::string expected = R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' caused by: - Not all union options are compatible. Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); + Not all union options are compatible. +Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") @@ -501,9 +503,11 @@ local a: X? = { w = 4 } )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' + const std::string expected = R"(Type 'a' could not be converted into 'X?' caused by: - None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); + None of the union options are compatible. For example: +Table type 'a' not compatible with type 'X' because the former is missing field 'x')"; + CHECK_EQ(expected, toString(result.errors[0])); } // We had a bug where a cyclic union caused a stack overflow. @@ -524,6 +528,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") { + CheckResult result = check(R"( type A = { x: number, y: (number) -> string } | { z: number, y: (number) -> string } @@ -540,8 +545,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_union_write_indirect") LUAU_REQUIRE_ERROR_COUNT(1, result); // NOTE: union normalization will improve this message - CHECK_EQ(toString(result.errors[0]), - R"(Type '(string) -> number' could not be converted into '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"); + const std::string expected = R"(Type + '(string) -> number' +could not be converted into + '((number) -> string) | ((number) -> string)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_true_and_false") @@ -615,8 +623,11 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_mentioning_generic_typepacks") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number, a...) -> (number?, a...)' could not be converted into '((number) -> number) | ((number?, " - "a...) -> (number?, a...))'; none of the union options are compatible"); + const std::string expected = R"(Type + '(number, a...) -> (number?, a...)' +could not be converted into + '((number) -> number) | ((number?, a...) -> (number?, a...))'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") @@ -628,12 +639,16 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_arities") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(number) -> number?' could not be converted into '((number) -> nil) | ((number, string?) -> " - "number)'; none of the union options are compatible"); + const std::string expected = R"(Type + '(number) -> number?' +could not be converted into + '((number) -> nil) | ((number, string?) -> number)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") { + CheckResult result = check(R"( local x : () -> (number | string) local y : (() -> number) | (() -> string) = x -- OK @@ -641,12 +656,16 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_arities") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '() -> number | string' could not be converted into '(() -> (string, string)) | (() -> number)'; none " - "of the union options are compatible"); + const std::string expected = R"(Type + '() -> number | string' +could not be converted into + '(() -> (string, string)) | (() -> number)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") { + CheckResult result = check(R"( local x : (...nil) -> (...number?) local y : ((...string?) -> (...number)) | ((...number?) -> nil) = x -- OK @@ -654,12 +673,16 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_variadics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '(...nil) -> (...number?)' could not be converted into '((...string?) -> (...number)) | ((...string?) " - "-> nil)'; none of the union options are compatible"); + const std::string expected = R"(Type + '(...nil) -> (...number?)' +could not be converted into + '((...string?) -> (...number)) | ((...string?) -> nil)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") { + CheckResult result = check(R"( local x : (number) -> () local y : ((number?) -> ()) | ((...number) -> ()) = x -- OK @@ -667,12 +690,16 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_arg_variadics") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - "Type '(number) -> ()' could not be converted into '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible"); + const std::string expected = R"(Type + '(number) -> ()' +could not be converted into + '((...number?) -> ()) | ((number?) -> ())'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics") { + CheckResult result = check(R"( local x : () -> (number?, ...number) local y : (() -> (...number)) | (() -> nil) = x -- OK @@ -680,8 +707,11 @@ TEST_CASE_FIXTURE(Fixture, "union_of_functions_with_mismatching_result_variadics )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Type '() -> (number?, ...number)' could not be converted into '(() -> (...number)) | (() -> number)'; none " - "of the union options are compatible"); + const std::string expected = R"(Type + '() -> (number?, ...number)' +could not be converted into + '(() -> (...number)) | (() -> number)'; none of the union options are compatible)"; + CHECK_EQ(expected, toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "less_greedy_unification_with_union_types") diff --git a/tests/Unifier2.test.cpp b/tests/Unifier2.test.cpp index 363c8109..2e6cf3b6 100644 --- a/tests/Unifier2.test.cpp +++ b/tests/Unifier2.test.cpp @@ -81,18 +81,12 @@ TEST_CASE_FIXTURE(Unifier2Fixture, "T <: U") TEST_CASE_FIXTURE(Unifier2Fixture, "(string) -> () <: (X) -> Y...") { - TypeId stringToUnit = arena.addType(FunctionType{ - arena.addTypePack({builtinTypes.stringType}), - arena.addTypePack({}) - }); + TypeId stringToUnit = arena.addType(FunctionType{arena.addTypePack({builtinTypes.stringType}), arena.addTypePack({})}); auto [x, xFree] = freshType(); TypePackId y = arena.freshTypePack(&scope); - TypeId xToY = arena.addType(FunctionType{ - arena.addTypePack({x}), - y - }); + TypeId xToY = arena.addType(FunctionType{arena.addTypePack({x}), y}); u2.unify(stringToUnit, xToY); @@ -105,4 +99,54 @@ TEST_CASE_FIXTURE(Unifier2Fixture, "(string) -> () <: (X) -> Y...") CHECK(!yPack->tail); } +TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t2generalized = u2.generalize(NotNull{&scope}, t2); + REQUIRE(t2generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t1generalized = u2.generalize(NotNull{&scope}, t1); + REQUIRE(t1generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + +// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type +// except that we generalize the types in the opposite order +TEST_CASE_FIXTURE(Unifier2Fixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order") +{ + auto [t1, ft1] = freshType(); + auto [t2, ft2] = freshType(); + + // t2 <: t1 <: unknown + // unknown <: t2 <: t1 + + ft1->lowerBound = t2; + ft2->upperBound = t1; + ft2->lowerBound = builtinTypes.unknownType; + + auto t1generalized = u2.generalize(NotNull{&scope}, t1); + REQUIRE(t1generalized); + + CHECK(follow(t1) == follow(t2)); + + auto t2generalized = u2.generalize(NotNull{&scope}, t2); + REQUIRE(t2generalized); + + CHECK(builtinTypes.unknownType == follow(t1)); + CHECK(builtinTypes.unknownType == follow(t2)); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 07150ba8..17f4497a 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -91,6 +91,15 @@ assert((function() local a = 1 a = a + 2 return a end)() == 3) assert((function() local a = 1 a = a - 2 return a end)() == -1) assert((function() local a = 1 a = a * 2 return a end)() == 2) assert((function() local a = 1 a = a / 2 return a end)() == 0.5) + +-- floor division should always round towards -Infinity +assert((function() local a = 1 a = a // 2 return a end)() == 0) +assert((function() local a = 3 a = a // 2 return a end)() == 1) +assert((function() local a = 3.5 a = a // 2 return a end)() == 1) +assert((function() local a = -1 a = a // 2 return a end)() == -1) +assert((function() local a = -3 a = a // 2 return a end)() == -2) +assert((function() local a = -3.5 a = a // 2 return a end)() == -2) + assert((function() local a = 5 a = a % 2 return a end)() == 1) assert((function() local a = 3 a = a ^ 2 return a end)() == 9) assert((function() local a = 3 a = a ^ 3 return a end)() == 27) @@ -168,6 +177,33 @@ assert((function() local a = 1 for b=1,9 do a = a * 2 if a == 128 then break els -- make sure internal index is protected against modification assert((function() local a = 1 for b=9,1,-2 do a = a * 2 b = nil end return a end)() == 32) +-- make sure that when step is 0, we treat it as backward iteration (and as such, iterate zero times or indefinitely) +-- this is consistent with Lua 5.1; future Lua versions emit an error when step is 0; LuaJIT instead treats 0 as forward iteration +-- we repeat tests twice, with and without constant folding +local zero = tonumber("0") +assert((function() local c = 0 for i=1,10,0 do c += 1 if c > 10 then break end end return c end)() == 0) +assert((function() local c = 0 for i=10,1,0 do c += 1 if c > 10 then break end end return c end)() == 11) +assert((function() local c = 0 for i=1,10,zero do c += 1 if c > 10 then break end end return c end)() == 0) +assert((function() local c = 0 for i=10,1,zero do c += 1 if c > 10 then break end end return c end)() == 11) + +-- make sure that when limit is nan, we iterate zero times (this is consistent with Lua 5.1; future Lua versions break this) +-- we repeat tests twice, with and without constant folding +local nan = tonumber("nan") +assert((function() local c = 0 for i=1,0/0 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,0/0,-1 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,nan do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=1,nan,-1 do c += 1 end return c end)() == 0) + +-- make sure that when step is nan, we treat it as backward iteration and as such iterate once iff start<=limit +assert((function() local c = 0 for i=1,10,0/0 do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=10,1,0/0 do c += 1 end return c end)() == 1) +assert((function() local c = 0 for i=1,10,nan do c += 1 end return c end)() == 0) +assert((function() local c = 0 for i=10,1,nan do c += 1 end return c end)() == 1) + +-- make sure that when index becomes nan mid-iteration, we correctly exit the loop (this is broken in Lua 5.1; future Lua versions fix this) +assert((function() local c = 0 for i=-math.huge,0,math.huge do c += 1 end return c end)() == 1) +assert((function() local c = 0 for i=math.huge,math.huge,-math.huge do c += 1 end return c end)() == 1) + -- generic for -- ipairs assert((function() local a = '' for k in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123") @@ -277,6 +313,10 @@ assert((function() return result end)() == "ArcticDunesCanyonsWaterMountainsHillsLavaflowPlainsMarsh") +-- table literals may contain duplicate fields; the language doesn't specify assignment order but we currently assign left to right +assert((function() local t = {data = 4, data = nil, data = 42} return t.data end)() == 42) +assert((function() local t = {data = 4, data = nil, data = 42, data = nil} return t.data end)() == nil) + -- multiple returns -- local= assert((function() function foo() return 2, 3, 4 end local a, b, c = foo() return ''..a..b..c end)() == "234") @@ -494,6 +534,7 @@ local function vec3t(x, y, z) __sub = function(l, r) return vec3t(l.x - r.x, l.y - r.y, l.z - r.z) end, __mul = function(l, r) return type(r) == "number" and vec3t(l.x * r, l.y * r, l.z * r) or vec3t(l.x * r.x, l.y * r.y, l.z * r.z) end, __div = function(l, r) return type(r) == "number" and vec3t(l.x / r, l.y / r, l.z / r) or vec3t(l.x / r.x, l.y / r.y, l.z / r.z) end, + __idiv = function(l, r) return type(r) == "number" and vec3t(l.x // r, l.y // r, l.z // r) or vec3t(l.x // r.x, l.y // r.y, l.z // r.z) end, __unm = function(v) return vec3t(-v.x, -v.y, -v.z) end, __tostring = function(v) return string.format("%g, %g, %g", v.x, v.y, v.z) end }) @@ -504,10 +545,13 @@ assert((function() return tostring(vec3t(1,2,3) + vec3t(4,5,6)) end)() == "5, 7, assert((function() return tostring(vec3t(1,2,3) - vec3t(4,5,6)) end)() == "-3, -3, -3") assert((function() return tostring(vec3t(1,2,3) * vec3t(4,5,6)) end)() == "4, 10, 18") assert((function() return tostring(vec3t(1,2,3) / vec3t(2,4,8)) end)() == "0.5, 0.5, 0.375") +assert((function() return tostring(vec3t(1,2,3) // vec3t(2,4,2)) end)() == "0, 0, 1") +assert((function() return tostring(vec3t(1,2,3) // vec3t(-2,-4,-2)) end)() == "-1, -1, -2") -- reg vs constant assert((function() return tostring(vec3t(1,2,3) * 2) end)() == "2, 4, 6") assert((function() return tostring(vec3t(1,2,3) / 2) end)() == "0.5, 1, 1.5") +assert((function() return tostring(vec3t(1,2,3) // 2) end)() == "0, 1, 1") -- unary assert((function() return tostring(-vec3t(1,2,3)) end)() == "-1, -2, -3") diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua index 94314c3f..4ee801f0 100644 --- a/tests/conformance/events.lua +++ b/tests/conformance/events.lua @@ -107,6 +107,7 @@ t.__add = f("add") t.__sub = f("sub") t.__mul = f("mul") t.__div = f("div") +t.__idiv = f("idiv") t.__mod = f("mod") t.__unm = f("unm") t.__pow = f("pow") @@ -128,6 +129,8 @@ assert(a*a == a) assert(cap[0] == "mul" and cap[1] == a and cap[2] == a and cap[3]==nil) assert(a/0 == a) assert(cap[0] == "div" and cap[1] == a and cap[2] == 0 and cap[3]==nil) +assert(a//0 == a) +assert(cap[0] == "idiv" and cap[1] == a and cap[2] == 0 and cap[3]==nil) assert(a%2 == a) assert(cap[0] == "mod" and cap[1] == a and cap[2] == 2 and cap[3]==nil) assert(-a == a) diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 9e9ae384..4027bba0 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -82,6 +82,7 @@ assert(not(1>1) and not(1>2) and (2>1)) assert(not('a'>'a') and not('a'>'b') and ('b'>'a')) assert((1>=1) and not(1>=2) and (2>=1)) assert(('a'>='a') and not('a'>='b') and ('b'>='a')) +assert((unk and unk > 0) == nil) -- validate precedence between and and > -- testing mod operator assert(-4%3 == 2) @@ -188,6 +189,26 @@ do -- testing NaN assert(a[NaN] == nil) end +-- extra NaN tests, hidden in a function +do + function neq(a) return a ~= a end + function eq(a) return a == a end + function lt(a) return a < a end + function le(a) return a <= a end + function gt(a) return a > a end + function ge(a) return a >= a end + + local NaN -- to avoid constant folding + NaN = 10e500 - 10e400 + + assert(neq(NaN)) + assert(not eq(NaN)) + assert(not lt(NaN)) + assert(not le(NaN)) + assert(not gt(NaN)) + assert(not ge(NaN)) +end + -- require "checktable" -- stat(a) diff --git a/tests/conformance/native.lua b/tests/conformance/native.lua index 6c0e0e0e..7a77edec 100644 --- a/tests/conformance/native.lua +++ b/tests/conformance/native.lua @@ -14,6 +14,17 @@ assert((function(x, y) return c, b, t, t1, t2 end)(5, 10) == 50) +assert((function(x) + local oops -- split to prevent inlining + function oops() + end + + -- x is checked to be a number here; we can not execute a reentry from oops() because optimizer assumes this holds until return + local y = math.abs(x) + oops() + return y * x +end)("42") == 1764) + local function fuzzfail1(...) repeat _ = nil @@ -101,4 +112,63 @@ end assert(pcall(fuzzfail10) == false) +local function fuzzfail11(x, ...) + return bit32.arshift(bit32.bnot(x),(...)) +end + +assert(fuzzfail11(0xffff0000, 8) == 0xff) + +local function fuzzfail12() + _,_,_,_,_,_,_,_ = not _, not _, not _, not _, not _, not _, not _, not _ +end + +assert(pcall(fuzzfail12) == true) + +local function fuzzfail13() + _,_,_,_,_,_,_,_,_,_,_,_,_,_,_,_ = not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _, not _ +end + +assert(pcall(fuzzfail13) == true) + +local function arraySizeInv1() + local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} + + table.insert(t, 3) + + return t[10] +end + +assert(arraySizeInv1() == true) + +local function arraySizeInv2() + local t = {1, 2, nil, nil, nil, nil, nil, nil, nil, true} + + local u = {a = t} + table.insert(u.a, 3) -- aliased modifiction of 't' register through other value + + return t[10] +end + +assert(arraySizeInv2() == true) + +local function nilInvalidatesSlot() + local function tabs() + local t = { x=1, y=2, z=3 } + setmetatable(t, { __index = function(t, k) return 42 end }) + return t, t + end + + local t1, t2 = tabs() + + for i=1,2 do + local a = t1.x + t2.x = nil + local b = t1.x + t2.x = 1 + assert(a == 1 and b == 42) + end +end + +nilInvalidatesSlot() + return('OK') diff --git a/tests/conformance/native_types.lua b/tests/conformance/native_types.lua index 67779230..c375ab81 100644 --- a/tests/conformance/native_types.lua +++ b/tests/conformance/native_types.lua @@ -68,5 +68,16 @@ ecall(checkuserdata, 2) call(checkvector, vector(1, 2, 3)) ecall(checkvector, 2) +local function mutation_causes_bad_exit(a: number, count: number, sum: number) + repeat + a = 's' + sum += count + pcall(function() end) + count -= 1 + until count == 0 + return sum +end + +assert(call(mutation_causes_bad_exit, 5, 10, 0) == 55) return('OK') diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.lua index 1ad4dd16..fef077ea 100644 --- a/tests/conformance/tmerror.lua +++ b/tests/conformance/tmerror.lua @@ -12,4 +12,30 @@ pcall(function() testtable.missingmethod() end) +-- +local testtable2 = {} +setmetatable(testtable2, { __index = function() pcall(function() error("Error") end) end }) + +local m2 = testtable2.missingmethod + +pcall(function() + testtable2.missingmethod() +end) + +-- +local testtable3 = {} +setmetatable(testtable3, { __index = function() pcall(error, "Error") end }) + +local m3 = testtable3.missingmethod + +pcall(function() + testtable3.missingmethod() +end) + +-- +local testtable4 = {} +setmetatable(testtable4, { __index = function() pcall(error) end }) + +local m4 = testtable4.missingmember + return('OK') diff --git a/tests/conformance/userdata.lua b/tests/conformance/userdata.lua index 98392e25..32759ad1 100644 --- a/tests/conformance/userdata.lua +++ b/tests/conformance/userdata.lua @@ -30,6 +30,12 @@ assert(int64(4) / 2 == int64(2)) assert(int64(4) % 3 == int64(1)) assert(int64(2) ^ 3 == int64(8)) +-- / and // round in different directions in our test implementation +assert(int64(5) / int64(2) == int64(2)) +assert(int64(5) // int64(2) == int64(2)) +assert(int64(-5) / int64(2) == int64(-2)) +assert(int64(-5) // int64(2) == int64(-3)) + -- tostring assert(tostring(int64(2)) == "2") diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 6164e929..c9cc47aa 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -67,6 +67,18 @@ assert(vector(1, 2, 4) / '8' == vector(1/8, 1/4, 1/2)); assert(-vector(1, 2, 4) == vector(-1, -2, -4)); +-- test floor division +assert(vector(1, 3, 5) // 2 == vector(0, 1, 2)) +assert(vector(1, 3, 5) // val == vector(8, 24, 40)) + +if vector_size == 4 then + assert(10 // vector(1, 2, 3, 4) == vector(10, 5, 3, 2)) + assert(vector(10, 9, 8, 7) // vector(1, 2, 3, 4) == vector(10, 4, 2, 1)) +else + assert(10 // vector(1, 2, 3) == vector(10, 5, 3)) + assert(vector(10, 9, 8) // vector(1, 2, 3) == vector(10, 4, 2)) +end + -- test NaN comparison local nanv = vector(0/0, 0/0, 0/0) assert(nanv ~= nanv); diff --git a/tests/main.cpp b/tests/main.cpp index cfa1f9db..26872196 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -6,6 +6,8 @@ #define DOCTEST_CONFIG_OPTIONS_PREFIX "" #include "doctest.h" +#include "RegisterCallbacks.h" + #ifdef _WIN32 #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN @@ -161,8 +163,10 @@ struct BoostLikeReporter : doctest::IReporter } void log_message(const doctest::MessageData& md) override - { // - printf("%s(%d): ERROR: %s\n", md.m_file, md.m_line, md.m_string.c_str()); + { + const char* severity = (md.m_severity & doctest::assertType::is_warn) ? "WARNING" : "ERROR"; + + printf("%s(%d): %s: %s\n", md.m_file, md.m_line, severity, md.m_string.c_str()); } // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator @@ -325,6 +329,14 @@ int main(int argc, char** argv) } } + // These callbacks register unit tests that need runtime support to be + // correctly set up. Running them here means that all command line flags + // have been parsed, fast flags have been set, and we've potentially already + // exited. Once doctest::Context::run is invoked, the test list will be + // picked up from global state. + for (Luau::RegisterCallback cb : Luau::getRegisterCallbacks()) + cb(); + int result = context.run(); if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) { diff --git a/tools/faillist.txt b/tools/faillist.txt index 390b8b62..69790ba5 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -3,7 +3,6 @@ AnnotationTests.two_type_params AstQuery.last_argument_function_call_type AutocompleteTest.anonymous_autofilled_generic_on_argument_type_pack_vararg AutocompleteTest.anonymous_autofilled_generic_type_pack_vararg -AutocompleteTest.autocomplete_if_else_regression AutocompleteTest.autocomplete_interpolated_string_as_singleton AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_response_perf1 @@ -92,8 +91,6 @@ IntersectionTypes.table_intersection_write_sealed_indirect IntersectionTypes.table_write_sealed_indirect Normalize.negations_of_tables Normalize.specific_functions_cannot_be_negated -ParserTests.parse_nesting_based_end_detection -ParserTests.parse_nesting_based_end_detection_single_line ProvisionalTests.assign_table_with_refined_property_with_a_similar_type_is_illegal ProvisionalTests.do_not_ice_when_trying_to_pick_first_of_generic_type_pack ProvisionalTests.error_on_eq_metamethod_returning_a_type_other_than_boolean @@ -124,13 +121,13 @@ RefinementTest.x_as_any_if_x_is_instance_elseif_x_is_table TableTests.a_free_shape_can_turn_into_a_scalar_if_it_is_compatible TableTests.a_free_shape_cannot_turn_into_a_scalar_if_it_is_not_compatible TableTests.call_method -TableTests.call_method_with_explicit_self_argument TableTests.cannot_augment_sealed_table TableTests.cannot_change_type_of_unsealed_table_prop TableTests.casting_sealed_tables_with_props_into_table_with_indexer TableTests.casting_tables_with_props_into_table_with_indexer4 TableTests.casting_unsealed_tables_with_props_into_table_with_indexer TableTests.checked_prop_too_early +TableTests.cli_84607_missing_prop_in_array_or_dict TableTests.cyclic_shifted_tables TableTests.defining_a_method_for_a_local_sealed_table_must_fail TableTests.defining_a_self_method_for_a_local_sealed_table_must_fail @@ -139,6 +136,7 @@ TableTests.dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar TableTests.dont_extend_unsealed_tables_in_rvalue_position TableTests.dont_hang_when_trying_to_look_up_in_cyclic_metatable_index TableTests.dont_leak_free_table_props +TableTests.dont_quantify_table_that_belongs_to_outer_scope TableTests.dont_suggest_exact_match_keys TableTests.error_detailed_metatable_prop TableTests.explicitly_typed_table @@ -154,19 +152,21 @@ TableTests.inequality_operators_imply_exactly_matching_types TableTests.infer_array_2 TableTests.infer_indexer_from_value_property_in_literal TableTests.infer_type_when_indexing_from_a_table_indexer -TableTests.inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table TableTests.inferred_return_type_of_free_table TableTests.instantiate_table_cloning_3 TableTests.leaking_bad_metatable_errors TableTests.less_exponential_blowup_please TableTests.missing_metatable_for_sealed_tables_do_not_get_inferred TableTests.mixed_tables_with_implicit_numbered_keys +TableTests.ok_to_add_property_to_free_table TableTests.ok_to_provide_a_subtype_during_construction TableTests.okay_to_add_property_to_unsealed_tables_by_assignment -TableTests.okay_to_add_property_to_unsealed_tables_by_function_call -TableTests.only_ascribe_synthetic_names_at_module_scope TableTests.oop_indexer_works TableTests.oop_polymorphic +TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table +TableTests.pass_a_union_of_tables_to_a_function_that_requires_a_table_2 +TableTests.pass_incompatible_union_to_a_generic_table_without_crashing +TableTests.passing_compatible_unions_to_a_generic_table_without_crashing TableTests.quantify_even_that_table_was_never_exported_at_all TableTests.quantify_metatables_of_metatables_of_table TableTests.quantifying_a_bound_var_works @@ -185,7 +185,6 @@ TableTests.table_unification_4 TableTests.type_mismatch_on_massive_table_is_cut_short TableTests.used_colon_instead_of_dot TableTests.used_dot_instead_of_colon -TableTests.used_dot_instead_of_colon_but_correctly TableTests.when_augmenting_an_unsealed_table_with_an_indexer_apply_the_correct_scope_to_the_indexer_type TableTests.wrong_assign_does_hit_indexer ToDot.function @@ -215,6 +214,7 @@ TypeAliases.type_alias_of_an_imported_recursive_generic_type TypeFamilyTests.family_as_fn_arg TypeFamilyTests.table_internal_families TypeFamilyTests.unsolvable_family +TypeInfer.be_sure_to_use_active_txnlog_when_evaluating_a_variadic_overload TypeInfer.bidirectional_checking_of_higher_order_function TypeInfer.check_type_infer_recursion_count TypeInfer.cli_39932_use_unifier_in_ensure_methods @@ -226,21 +226,24 @@ TypeInfer.fuzz_free_table_type_change_during_index_check TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.infer_locals_via_assignment_from_its_call_site TypeInfer.no_stack_overflow_from_isoptional -TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter TypeInfer.recursive_function_that_invokes_itself_with_a_refinement_of_its_parameter_2 TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.type_infer_cache_limit_normalizer TypeInfer.type_infer_recursion_limit_no_ice TypeInfer.type_infer_recursion_limit_normalizer TypeInferAnyError.can_subscript_any +TypeInferAnyError.for_in_loop_iterator_is_any TypeInferAnyError.for_in_loop_iterator_is_any2 +TypeInferAnyError.for_in_loop_iterator_is_error +TypeInferAnyError.for_in_loop_iterator_is_error2 TypeInferAnyError.for_in_loop_iterator_returns_any +TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.intersection_of_any_can_have_props TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.union_of_types_regression_test -TypeInferClasses.can_read_prop_of_base_class_using_string TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.index_instance_property +TypeInferClasses.table_class_unification_reports_sane_errors_for_missing_properties TypeInferFunctions.cannot_hoist_interior_defns_into_signature TypeInferFunctions.dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site @@ -255,6 +258,9 @@ TypeInferFunctions.higher_order_function_4 TypeInferFunctions.improved_function_arg_mismatch_errors TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_anonymous_function_arguments_outside_call +TypeInferFunctions.infer_generic_function_function_argument +TypeInferFunctions.infer_generic_function_function_argument_overloaded +TypeInferFunctions.infer_generic_lib_function_function_argument TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.luau_subtyping_is_np_hard TypeInferFunctions.no_lossy_function_type @@ -266,21 +272,22 @@ TypeInferFunctions.too_few_arguments_variadic_generic2 TypeInferFunctions.too_many_arguments_error_location TypeInferFunctions.too_many_return_values_in_parentheses TypeInferFunctions.too_many_return_values_no_function -TypeInferFunctions.toposort_doesnt_break_mutual_recursion -TypeInferFunctions.vararg_function_is_quantified TypeInferLoops.cli_68448_iterators_need_not_accept_nil TypeInferLoops.dcr_iteration_on_never_gives_never TypeInferLoops.for_in_loop TypeInferLoops.for_in_loop_error_on_factory_not_returning_the_right_amount_of_values +TypeInferLoops.for_in_loop_on_error TypeInferLoops.for_in_loop_with_custom_iterator TypeInferLoops.for_in_loop_with_incompatible_args_to_iterator TypeInferLoops.for_in_loop_with_next TypeInferLoops.ipairs_produces_integral_indices TypeInferLoops.iteration_regression_issue_69967_alt +TypeInferLoops.loop_iter_basic TypeInferLoops.loop_iter_metamethod_nil TypeInferLoops.loop_iter_metamethod_ok_with_inference TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.unreachable_code_after_infinite_loop +TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferModules.do_not_modify_imported_types_5 TypeInferModules.module_type_conflict TypeInferModules.module_type_conflict_instantiated @@ -291,7 +298,6 @@ TypeInferOOP.inferring_hundreds_of_self_calls_should_not_suffocate_memory TypeInferOOP.methods_are_topologically_sorted TypeInferOperators.and_binexps_dont_unify TypeInferOperators.cli_38355_recursive_union -TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.concat_op_on_string_lhs_and_free_rhs TypeInferOperators.disallow_string_and_types_without_metatables_from_arithmetic_binary_ops TypeInferOperators.luau_polyfill_is_array